mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
in place
This commit is contained in:
@@ -15,17 +15,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unify all tasks in a dataset to a single task.
|
Unify all tasks in a dataset to a single task (modifies in-place).
|
||||||
|
|
||||||
This script:
|
This script:
|
||||||
1. Loads a dataset
|
1. Loads a dataset
|
||||||
2. Sets all task_index to 0 and task description to "fold"
|
2. Sets all task_index to 0 and task description to "fold"
|
||||||
3. Updates tasks.parquet and task_index in data files
|
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python examples/openarms/unify_task.py \
|
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
|
||||||
--input-repo-id lerobot-data-collection/level1_rac1 \
|
|
||||||
--output-repo-id lerobot-data-collection/level1_rac1
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -41,7 +39,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
|
||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
@@ -52,28 +49,24 @@ UNIFIED_TASK = "fold"
|
|||||||
|
|
||||||
|
|
||||||
def unify_dataset_tasks(
|
def unify_dataset_tasks(
|
||||||
input_repo_id: str,
|
repo_id: str,
|
||||||
output_repo_id: str,
|
root: Path | None = None,
|
||||||
input_root: Path | None = None,
|
|
||||||
output_root: Path | None = None,
|
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Unify all tasks in a dataset to a single task.
|
"""Unify all tasks in a dataset to a single task (modifies in-place).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_repo_id: Source dataset repository ID.
|
repo_id: Dataset repository ID.
|
||||||
output_repo_id: Output dataset repository ID.
|
root: Optional root path for dataset.
|
||||||
input_root: Optional root path for input dataset.
|
|
||||||
output_root: Optional root path for output dataset.
|
|
||||||
push_to_hub: Whether to push the result to HuggingFace Hub.
|
push_to_hub: Whether to push the result to HuggingFace Hub.
|
||||||
"""
|
"""
|
||||||
logging.info(f"Loading metadata from {input_repo_id}")
|
input_root = root if root else HF_LEROBOT_HOME / repo_id
|
||||||
|
input_repo_id = repo_id
|
||||||
|
|
||||||
input_root = input_root if input_root else HF_LEROBOT_HOME / input_repo_id
|
logging.info(f"Loading metadata from {repo_id}")
|
||||||
output_root = output_root if output_root else HF_LEROBOT_HOME / output_repo_id
|
|
||||||
|
|
||||||
# Load source metadata
|
# Load source metadata
|
||||||
src_meta = LeRobotDatasetMetadata(input_repo_id, root=input_root)
|
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
|
||||||
|
|
||||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||||
@@ -123,32 +116,20 @@ def unify_dataset_tasks(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Unify all tasks in a dataset to a single task 'fold'."
|
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-repo-id",
|
"--repo-id",
|
||||||
type=str,
|
type=str,
|
||||||
default="lerobot-data-collection/full_folding_2025-11-30",
|
required=True,
|
||||||
help="Input dataset repository ID",
|
help="Dataset repository ID",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-repo-id",
|
"--root",
|
||||||
type=str,
|
|
||||||
default="lerobot-data-collection/folding_2025-11-30",
|
|
||||||
help="Output dataset repository ID",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-root",
|
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Optional input root path (defaults to HF_LEROBOT_HOME/input_repo_id)",
|
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-root",
|
|
||||||
type=Path,
|
|
||||||
default=None,
|
|
||||||
help="Optional output root path (defaults to HF_LEROBOT_HOME/output_repo_id)",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--push-to-hub",
|
"--push-to-hub",
|
||||||
@@ -161,10 +142,8 @@ def main():
|
|||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
unify_dataset_tasks(
|
unify_dataset_tasks(
|
||||||
input_repo_id=args.input_repo_id,
|
repo_id=args.repo_id,
|
||||||
output_repo_id=args.output_repo_id,
|
root=args.root,
|
||||||
input_root=args.input_root,
|
|
||||||
output_root=args.output_root,
|
|
||||||
push_to_hub=args.push_to_hub,
|
push_to_hub=args.push_to_hub,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user