From c85f1692d63e672d79824d63e692f3e22720f97e Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 3 Jan 2026 22:12:22 +0100 Subject: [PATCH] in place --- examples/openarms/unify_task.py | 61 +++++++++++---------------------- 1 file changed, 20 insertions(+), 41 deletions(-) diff --git a/examples/openarms/unify_task.py b/examples/openarms/unify_task.py index 81fe93f95..177f45f02 100644 --- a/examples/openarms/unify_task.py +++ b/examples/openarms/unify_task.py @@ -15,17 +15,15 @@ # limitations under the License. """ -Unify all tasks in a dataset to a single task. +Unify all tasks in a dataset to a single task (modifies in-place). This script: 1. Loads a dataset 2. Sets all task_index to 0 and task description to "fold" -3. Updates tasks.parquet and task_index in data files +3. Updates tasks.parquet and task_index in data files (in-place, no copying) Usage: - python examples/openarms/unify_task.py \ - --input-repo-id lerobot-data-collection/level1_rac1 \ - --output-repo-id lerobot-data-collection/level1_rac1 + python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1 """ from __future__ import annotations @@ -41,7 +39,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.utils import ( DATA_DIR, write_info, - write_stats, write_tasks, ) from lerobot.utils.constants import HF_LEROBOT_HOME @@ -52,28 +49,24 @@ UNIFIED_TASK = "fold" def unify_dataset_tasks( - input_repo_id: str, - output_repo_id: str, - input_root: Path | None = None, - output_root: Path | None = None, + repo_id: str, + root: Path | None = None, push_to_hub: bool = False, ) -> None: - """Unify all tasks in a dataset to a single task. + """Unify all tasks in a dataset to a single task (modifies in-place). Args: - input_repo_id: Source dataset repository ID. - output_repo_id: Output dataset repository ID. - input_root: Optional root path for input dataset. - output_root: Optional root path for output dataset. + repo_id: Dataset repository ID. + root: Optional root path for dataset. push_to_hub: Whether to push the result to HuggingFace Hub. """ - logging.info(f"Loading metadata from {input_repo_id}") + input_root = root if root else HF_LEROBOT_HOME / repo_id + input_repo_id = repo_id - input_root = input_root if input_root else HF_LEROBOT_HOME / input_repo_id - output_root = output_root if output_root else HF_LEROBOT_HOME / output_repo_id + logging.info(f"Loading metadata from {repo_id}") # Load source metadata - src_meta = LeRobotDatasetMetadata(input_repo_id, root=input_root) + src_meta = LeRobotDatasetMetadata(repo_id, root=input_root) logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames") logging.info(f"Original tasks: {len(src_meta.tasks)}") @@ -123,32 +116,20 @@ def unify_dataset_tasks( def main(): parser = argparse.ArgumentParser( - description="Unify all tasks in a dataset to a single task 'fold'." + description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)." ) parser.add_argument( - "--input-repo-id", + "--repo-id", type=str, - default="lerobot-data-collection/full_folding_2025-11-30", - help="Input dataset repository ID", + required=True, + help="Dataset repository ID", ) parser.add_argument( - "--output-repo-id", - type=str, - default="lerobot-data-collection/folding_2025-11-30", - help="Output dataset repository ID", - ) - parser.add_argument( - "--input-root", + "--root", type=Path, default=None, - help="Optional input root path (defaults to HF_LEROBOT_HOME/input_repo_id)", - ) - parser.add_argument( - "--output-root", - type=Path, - default=None, - help="Optional output root path (defaults to HF_LEROBOT_HOME/output_repo_id)", + help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)", ) parser.add_argument( "--push-to-hub", @@ -161,10 +142,8 @@ def main(): logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") unify_dataset_tasks( - input_repo_id=args.input_repo_id, - output_repo_id=args.output_repo_id, - input_root=args.input_root, - output_root=args.output_root, + repo_id=args.repo_id, + root=args.root, push_to_hub=args.push_to_hub, )