fix(datasets)

This commit is contained in:
CarolinePascal
2026-03-07 00:18:57 +01:00
parent b883328e6c
commit 07931b1101
2 changed files with 12 additions and 4 deletions
+2 -2
View File
@@ -632,7 +632,7 @@ def cycle(iterable):
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
def create_branch(repo_id, *, branch: str, repo_type: str | None = None, revision: str | None = None) -> None:
"""Create a branch on an existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
@@ -644,7 +644,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
if ref in refs:
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch, revision=revision)
def create_lerobot_dataset_card(
@@ -205,7 +205,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
robot_config = parse_robot_config(robot_config) if robot_config else None
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
@@ -455,7 +455,7 @@ def convert_dataset(
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset", revision=v1)
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
@@ -564,6 +564,12 @@ def convert_dataset(
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
info = load_json(v20_dir / INFO_PATH)
if "language_instruction" in info.get("features", {}):
del info["features"]["language_instruction"]
write_json(info, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
@@ -677,6 +683,8 @@ def main():
if args.robot is not None:
robot_config = make_robot_config(args.robot)
else:
robot_config = None
del args.robot