mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 23:57:24 +00:00
conversion dest
This commit is contained in:
@@ -13,20 +13,24 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
This script converts a LeRobot dataset already pushed to the Hub from codebase version 2.0 to 2.1.
|
||||
It downloads metadata from a SOURCE dataset repo, computes/validates per-episode stats, updates
|
||||
the codebase version in `info.json`, and uploads the result to a DESTINATION dataset repo.
|
||||
It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
- Push this new version to the destination repo/branch and tag it with the current codebase version.
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
--source-repo-id=namespace/source_dataset \
|
||||
--dest-repo-id=namespace/destination_dataset \
|
||||
--branch=main
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -54,48 +58,67 @@ class SuppressWarnings:
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
source_repo_id: str,
|
||||
dest_repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
# Download metadata from the source repo at v2.0
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
dataset = LeRobotDataset(source_repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
# Ensure we recompute fresh episodes stats
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
# Compute and validate stats
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
# Update codebase version in info.json
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
# Remove deprecated stats.json locally so it won't be uploaded
|
||||
if (dataset.root / STATS_PATH).is_file():
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
# Push only meta/ to destination repo
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
hub_api.create_repo(repo_id=dest_repo_id, private=False, repo_type="dataset", exist_ok=True)
|
||||
if branch:
|
||||
hub_api.create_branch(repo_id=dest_repo_id, branch=branch, repo_type="dataset", exist_ok=True)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
hub_api.upload_folder(
|
||||
repo_id=dest_repo_id,
|
||||
folder_path=str(dataset.root),
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
allow_patterns="meta/",
|
||||
)
|
||||
|
||||
# Ensure old stats.json is deleted on destination
|
||||
if hub_api.file_exists(repo_id=dest_repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"):
|
||||
hub_api.delete_file(path_in_repo=STATS_PATH, repo_id=dest_repo_id, revision=branch, repo_type="dataset")
|
||||
|
||||
# Tag destination with current codebase version
|
||||
hub_api.create_tag(dest_repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
"--source-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
help="Source dataset repo id to download from (must be v2.0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Destination dataset repo id to upload the converted metadata to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
Reference in New Issue
Block a user