mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
fix(datasets)
This commit is contained in:
@@ -632,7 +632,7 @@ def cycle(iterable):
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None, revision: str | None = None) -> None:
|
||||
"""Create a branch on an existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
@@ -644,7 +644,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
if ref in refs:
|
||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch, revision=revision)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
|
||||
@@ -205,7 +205,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
robot_config = parse_robot_config(robot_config) if robot_config else None
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -455,7 +455,7 @@ def convert_dataset(
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
branch = test_branch
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset", revision=v1)
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
@@ -564,6 +564,12 @@ def convert_dataset(
|
||||
"features": features,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
|
||||
info = load_json(v20_dir / INFO_PATH)
|
||||
if "language_instruction" in info.get("features", {}):
|
||||
del info["features"]["language_instruction"]
|
||||
write_json(info, v20_dir / INFO_PATH)
|
||||
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
@@ -677,6 +683,8 @@ def main():
|
||||
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
else:
|
||||
robot_config = None
|
||||
|
||||
del args.robot
|
||||
|
||||
|
||||
Reference in New Issue
Block a user