conversion dest

This commit is contained in:
Pepijn
2025-09-01 11:01:27 +02:00
parent ce5b27d255
commit d35ed3fd83
@@ -13,20 +13,24 @@
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
This script converts a LeRobot dataset already pushed to the Hub from codebase version 2.0 to 2.1.
It downloads metadata from a SOURCE dataset repo, computes/validates per-episode stats, updates
the codebase version in `info.json`, and uploads the result to a DESTINATION dataset repo.
It will:
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
- Push this new version to the destination repo/branch and tag it with the current codebase version.
Usage:
```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
--repo-id=aliberts/koch_tutorial
--source-repo-id=namespace/source_dataset \
--dest-repo-id=namespace/destination_dataset \
--branch=main
```
"""
@@ -54,48 +58,67 @@ class SuppressWarnings:
def convert_dataset(
repo_id: str,
source_repo_id: str,
dest_repo_id: str,
branch: str | None = None,
num_workers: int = 4,
):
# Download metadata from the source repo at v2.0
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
dataset = LeRobotDataset(source_repo_id, revision=V20, force_cache_sync=True)
# Ensure we recompute fresh episodes stats
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
# Compute and validate stats
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
# Update codebase version in info.json
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
# Remove deprecated stats.json locally so it won't be uploaded
if (dataset.root / STATS_PATH).is_file():
(dataset.root / STATS_PATH).unlink()
# Push only meta/ to destination repo
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_repo(repo_id=dest_repo_id, private=False, repo_type="dataset", exist_ok=True)
if branch:
hub_api.create_branch(repo_id=dest_repo_id, branch=branch, repo_type="dataset", exist_ok=True)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.upload_folder(
repo_id=dest_repo_id,
folder_path=str(dataset.root),
repo_type="dataset",
revision=branch,
allow_patterns="meta/",
)
# Ensure old stats.json is deleted on destination
if hub_api.file_exists(repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"):
hub_api.delete_file(path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset")
# Tag destination with current codebase version
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
"--source-repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
help="Source dataset repo id to download from (must be v2.0).",
)
parser.add_argument(
"--dest-repo-id",
type=str,
required=True,
help="Destination dataset repo id to upload the converted metadata to.",
)
parser.add_argument(
"--branch",