mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
in place
This commit is contained in:
@@ -15,17 +15,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unify all tasks in a dataset to a single task.
|
||||
Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
This script:
|
||||
1. Loads a dataset
|
||||
2. Sets all task_index to 0 and task description to "fold"
|
||||
3. Updates tasks.parquet and task_index in data files
|
||||
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
|
||||
|
||||
Usage:
|
||||
python examples/openarms/unify_task.py \
|
||||
--input-repo-id lerobot-data-collection/level1_rac1 \
|
||||
--output-repo-id lerobot-data-collection/level1_rac1
|
||||
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -41,7 +39,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
@@ -52,28 +49,24 @@ UNIFIED_TASK = "fold"
|
||||
|
||||
|
||||
def unify_dataset_tasks(
|
||||
input_repo_id: str,
|
||||
output_repo_id: str,
|
||||
input_root: Path | None = None,
|
||||
output_root: Path | None = None,
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
push_to_hub: bool = False,
|
||||
) -> None:
|
||||
"""Unify all tasks in a dataset to a single task.
|
||||
"""Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
Args:
|
||||
input_repo_id: Source dataset repository ID.
|
||||
output_repo_id: Output dataset repository ID.
|
||||
input_root: Optional root path for input dataset.
|
||||
output_root: Optional root path for output dataset.
|
||||
repo_id: Dataset repository ID.
|
||||
root: Optional root path for dataset.
|
||||
push_to_hub: Whether to push the result to HuggingFace Hub.
|
||||
"""
|
||||
logging.info(f"Loading metadata from {input_repo_id}")
|
||||
input_root = root if root else HF_LEROBOT_HOME / repo_id
|
||||
input_repo_id = repo_id
|
||||
|
||||
input_root = input_root if input_root else HF_LEROBOT_HOME / input_repo_id
|
||||
output_root = output_root if output_root else HF_LEROBOT_HOME / output_repo_id
|
||||
logging.info(f"Loading metadata from {repo_id}")
|
||||
|
||||
# Load source metadata
|
||||
src_meta = LeRobotDatasetMetadata(input_repo_id, root=input_root)
|
||||
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
|
||||
|
||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||
@@ -123,32 +116,20 @@ def unify_dataset_tasks(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unify all tasks in a dataset to a single task 'fold'."
|
||||
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input-repo-id",
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot-data-collection/full_folding_2025-11-30",
|
||||
help="Input dataset repository ID",
|
||||
required=True,
|
||||
help="Dataset repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo-id",
|
||||
type=str,
|
||||
default="lerobot-data-collection/folding_2025-11-30",
|
||||
help="Output dataset repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-root",
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional input root path (defaults to HF_LEROBOT_HOME/input_repo_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional output root path (defaults to HF_LEROBOT_HOME/output_repo_id)",
|
||||
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
@@ -161,10 +142,8 @@ def main():
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
unify_dataset_tasks(
|
||||
input_repo_id=args.input_repo_id,
|
||||
output_repo_id=args.output_repo_id,
|
||||
input_root=args.input_root,
|
||||
output_root=args.output_root,
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user