This commit is contained in:
Pepijn
2026-01-03 22:12:22 +01:00
parent 9fd329713a
commit c85f1692d6
+20 -41
View File
@@ -15,17 +15,15 @@
# limitations under the License. # limitations under the License.
""" """
Unify all tasks in a dataset to a single task. Unify all tasks in a dataset to a single task (modifies in-place).
This script: This script:
1. Loads a dataset 1. Loads a dataset
2. Sets all task_index to 0 and task description to "fold" 2. Sets all task_index to 0 and task description to "fold"
3. Updates tasks.parquet and task_index in data files 3. Updates tasks.parquet and task_index in data files (in-place, no copying)
Usage: Usage:
python examples/openarms/unify_task.py \ python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
--input-repo-id lerobot-data-collection/level1_rac1 \
--output-repo-id lerobot-data-collection/level1_rac1
""" """
from __future__ import annotations from __future__ import annotations
@@ -41,7 +39,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import ( from lerobot.datasets.utils import (
DATA_DIR, DATA_DIR,
write_info, write_info,
write_stats,
write_tasks, write_tasks,
) )
from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -52,28 +49,24 @@ UNIFIED_TASK = "fold"
def unify_dataset_tasks( def unify_dataset_tasks(
input_repo_id: str, repo_id: str,
output_repo_id: str, root: Path | None = None,
input_root: Path | None = None,
output_root: Path | None = None,
push_to_hub: bool = False, push_to_hub: bool = False,
) -> None: ) -> None:
"""Unify all tasks in a dataset to a single task. """Unify all tasks in a dataset to a single task (modifies in-place).
Args: Args:
input_repo_id: Source dataset repository ID. repo_id: Dataset repository ID.
output_repo_id: Output dataset repository ID. root: Optional root path for dataset.
input_root: Optional root path for input dataset.
output_root: Optional root path for output dataset.
push_to_hub: Whether to push the result to HuggingFace Hub. push_to_hub: Whether to push the result to HuggingFace Hub.
""" """
logging.info(f"Loading metadata from {input_repo_id}") input_root = root if root else HF_LEROBOT_HOME / repo_id
input_repo_id = repo_id
input_root = input_root if input_root else HF_LEROBOT_HOME / input_repo_id logging.info(f"Loading metadata from {repo_id}")
output_root = output_root if output_root else HF_LEROBOT_HOME / output_repo_id
# Load source metadata # Load source metadata
src_meta = LeRobotDatasetMetadata(input_repo_id, root=input_root) src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames") logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
logging.info(f"Original tasks: {len(src_meta.tasks)}") logging.info(f"Original tasks: {len(src_meta.tasks)}")
@@ -123,32 +116,20 @@ def unify_dataset_tasks(
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Unify all tasks in a dataset to a single task 'fold'." description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
) )
parser.add_argument( parser.add_argument(
"--input-repo-id", "--repo-id",
type=str, type=str,
default="lerobot-data-collection/full_folding_2025-11-30", required=True,
help="Input dataset repository ID", help="Dataset repository ID",
) )
parser.add_argument( parser.add_argument(
"--output-repo-id", "--root",
type=str,
default="lerobot-data-collection/folding_2025-11-30",
help="Output dataset repository ID",
)
parser.add_argument(
"--input-root",
type=Path, type=Path,
default=None, default=None,
help="Optional input root path (defaults to HF_LEROBOT_HOME/input_repo_id)", help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
)
parser.add_argument(
"--output-root",
type=Path,
default=None,
help="Optional output root path (defaults to HF_LEROBOT_HOME/output_repo_id)",
) )
parser.add_argument( parser.add_argument(
"--push-to-hub", "--push-to-hub",
@@ -161,10 +142,8 @@ def main():
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
unify_dataset_tasks( unify_dataset_tasks(
input_repo_id=args.input_repo_id, repo_id=args.repo_id,
output_repo_id=args.output_repo_id, root=args.root,
input_root=args.input_root,
output_root=args.output_root,
push_to_hub=args.push_to_hub, push_to_hub=args.push_to_hub,
) )