mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
67bdf4690e
Replace the earlier wrapper (which depended on robocasa.scripts.download + dataset_registry) with a self-contained pipeline that: * downloads each task tarball directly from Box via box_links_ds.json * converts v2.1 -> v3.0 in place using convert_dataset_v21_to_v30 * standardizes camera keys under observation.images.robot0_* and flattens observation.state by concatenating base/EE/gripper subkeys when the source dataset stores them separately * builds per-rank unified shards then aggregates into one dataset Filter: composite_seen task-set restricts discovery to the 16 multi-step target tasks (DeliverStraw, GetToastedBread, ..., WashLettuce). Use --task-set=all to keep every discovered task in the split/source slice; --tasks=... overrides for arbitrary subsets. Defaults sized for hopper-cpu @ 128 cores: 16 workers x 8 cpus-per-task. Adapted from a battle-tested port_robocasa.py reference shared by the user; the only semantic addition is the task-set filter. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1002 lines
37 KiB
Python
1002 lines
37 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""Rebuild the 16 RoboCasa composite_seen tarballs into one unified LeRobot v3 dataset.
|
||
|
||
Self-contained RoboCasa port pipeline that, by default, restricts the
|
||
discovered task set to the 16 ``composite_seen`` tasks (the multi-step subset
|
||
of the official RoboCasa365 target benchmark) so a single command produces the
|
||
exact dataset slice needed for an apples-to-apples pi05 vs pi052 comparison
|
||
on multi-step kitchen manipulation.
|
||
|
||
Per-rank, each datatrove worker:
|
||
|
||
1. Downloads the assigned task tarball(s) directly from Box (resolved via the
|
||
``box_links_ds.json`` bundled with the local ``robocasa`` clone).
|
||
2. Converts the extracted LeRobot v2.1 dataset to v3.0 in place.
|
||
3. Rewrites the per-task data into a per-rank shard with:
|
||
- the canonical RoboCasa task name in ``task``
|
||
- standardized camera keys under ``observation.images.robot0_*``
|
||
- a guaranteed flat ``observation.state`` (concatenation of base / EE /
|
||
gripper sub-keys when the source dataset stores them separately)
|
||
- a standardized ``action`` key
|
||
|
||
A single aggregate worker then merges all shards into one unified dataset.
|
||
|
||
Heavy lifting is parallelized via Datatrove + SLURM on CPU nodes. With
|
||
``--workers=16 --cpus-per-task=8`` on ``hopper-cpu`` you get 128 CPUs total
|
||
across the prepare phase (one task per worker, 8 CPUs each for ffmpeg /
|
||
parquet) and the aggregate phase reuses the same CPU budget on a single node.
|
||
|
||
Typical hopper-cpu invocation::
|
||
|
||
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
|
||
--repo-id=${HF_USER}/robocasa_composite_seen_v3 \\
|
||
--work-dir=/fsx/${USER}/robocasa/datasets/v1.0 \\
|
||
--robocasa-root=/fsx/${USER}/robocasa \\
|
||
--split=target \\
|
||
--source=human \\
|
||
--partition=hopper-cpu \\
|
||
--workers=16 \\
|
||
--cpus-per-task=8 \\
|
||
--mem-per-cpu=4G \\
|
||
--time=24:00:00 \\
|
||
--logs-dir=/fsx/${USER}/logs/robocasa
|
||
|
||
Local debug (no SLURM)::
|
||
|
||
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
|
||
--repo-id=local/robocasa_composite_seen_v3_smoke \\
|
||
--work-dir=/tmp/robocasa_smoke \\
|
||
--robocasa-root=$HOME/robocasa \\
|
||
--slurm=0 --workers=1 \\
|
||
--tasks PrepareCoffee
|
||
|
||
If ``robocasa`` is already importable in the runtime environment, you can omit
|
||
``--robocasa-root``; the box-links manifest will be located from the package.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
from pathlib import Path
|
||
|
||
from datatrove.executor import LocalPipelineExecutor
|
||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||
from datatrove.pipeline.base import PipelineStep
|
||
|
||
# Defaults used by task discovery and by the command-line parser.
DEFAULT_SPLIT = "target"
DEFAULT_SOURCE = "human"
DEFAULT_ROBOT_TYPE = "robocasa"

# Names of the 16 composite_seen tasks (RoboCasa365 target benchmark, multi-step subset),
# kept in the order used by the official RoboCasa documentation.
COMPOSITE_SEEN_TASKS: list[str] = [
    "DeliverStraw", "GetToastedBread", "KettleBoiling", "LoadDishwasher",
    "PackIdenticalLunches", "PreSoakPan", "PrepareCoffee", "RinseSinkBasin",
    "ScrubCuttingBoard", "SearingMeat", "SetUpCuttingStation", "StackBowlsCabinet",
    "SteamInMicrowave", "StirVegetables", "StoreLeftoversInBowl", "WashLettuce",
]  # fmt: skip

# Predefined task groupings selectable through ``--task-set``.
# The empty list under "all" is only a placeholder: "all" means "apply no filter".
TASK_SETS: dict[str, list[str]] = {
    "composite_seen": COMPOSITE_SEEN_TASKS,
    "all": [],
}
|
||
|
||
|
||
def _task_name_from_tar_key(tar_key: str) -> str:
    """Return the task name encoded in a box-links tar key.

    The name is the third ``/``-separated component of ``tar_key`` with a trailing
    ``.tar`` stripped, so both ``split/source/Task.tar`` and
    ``split/group/Task/date/lerobot.tar`` yield ``Task``.

    Raises:
        ValueError: if ``tar_key`` has fewer than three ``/``-separated components.
    """
    segments = tar_key.split("/")
    if len(segments) >= 3:
        return segments[2].removesuffix(".tar")
    raise ValueError(f"Unexpected RoboCasa tar key: {tar_key}")
|
||
|
||
|
||
def _resolve_box_links_json(
    box_links_json: Path | None,
    robocasa_root: Path | None,
) -> Path:
    """Locate the ``box_links_ds.json`` manifest.

    Resolution order:

    1. ``box_links_json`` when given (it must exist).
    2. ``robocasa_root`` when given: ``<root>/models/assets/box_links/`` is tried first,
       then ``<root>/robocasa/models/assets/box_links/``.
    3. The importable ``robocasa`` package directory.

    Raises:
        FileNotFoundError: if the selected strategy does not yield an existing file, or if
            neither argument is given and ``robocasa`` cannot be imported.
    """
    manifest_suffix = Path("models") / "assets" / "box_links" / "box_links_ds.json"

    # An explicit manifest path always wins.
    if box_links_json is not None:
        if box_links_json.exists():
            return box_links_json
        raise FileNotFoundError(f"--box-links-json does not exist: {box_links_json}")

    # Otherwise probe the two supported layouts under the given root.
    if robocasa_root is not None:
        for base_dir in (robocasa_root, robocasa_root / "robocasa"):
            located = base_dir / manifest_suffix
            if located.exists():
                return located
        raise FileNotFoundError(
            f"Could not find box_links_ds.json under --robocasa-root={robocasa_root}"
        )

    # Last resort: the installed package. Imported lazily so the script works without it.
    try:
        import robocasa  # noqa: PLC0415
    except ModuleNotFoundError as exc:
        raise FileNotFoundError(
            "Could not resolve RoboCasa box links. Pass --robocasa-root or --box-links-json, "
            "or run in an environment where `robocasa` is importable."
        ) from exc

    located = Path(robocasa.__path__[0]) / manifest_suffix
    if located.exists():
        return located
    raise FileNotFoundError(f"Resolved RoboCasa package, but box links file is missing: {located}")
|
||
|
||
|
||
def _discover_tasks(
    box_links_json: Path,
    split: str = DEFAULT_SPLIT,
    source: str | None = DEFAULT_SOURCE,
) -> list[dict[str, str]]:
    """List the task tarballs of one split (and optionally one source) from the manifest.

    Args:
        box_links_json: JSON file mapping tar keys to shared URLs.
        split: Only tar keys whose first ``/``-separated component equals this are kept.
        source: When not ``None``, only entries whose source equals this are kept.

    Returns:
        One dict per kept tar key, in sorted tar-key order, with the keys ``task_name``,
        ``tar_key``, ``source``, ``rel_path`` (the tar key without its ``.tar`` suffix) and
        ``shared_url``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(box_links_json, encoding="utf-8") as f:
        box_links: dict[str, str] = json.load(f)

    tasks: list[dict[str, str]] = []
    for tar_key in sorted(box_links):
        parts = tar_key.split("/")
        if len(parts) < 3 or parts[0] != split:
            continue

        # RoboCasa registries can appear in two layouts:
        #   new: split/<atomic|composite>/<task>/<date>/lerobot.tar
        #   old: split/<human|mimicgen>/<task>.tar
        # Only the old layout names the source; everything else is treated as "human".
        tar_source = parts[1] if parts[1] in {"human", "mimicgen"} else "human"

        if source is not None and tar_source != source:
            continue

        tasks.append(
            {
                "task_name": _task_name_from_tar_key(tar_key),
                "tar_key": tar_key,
                "source": tar_source,
                "rel_path": tar_key.removesuffix(".tar"),
                "shared_url": box_links[tar_key],
            }
        )
    return tasks
|
||
|
||
|
||
class PrepareRoboCasaUnifiedShards(PipelineStep):
    """Build per-rank unified shards from RoboCasa task tarballs.

    Each rank handles the strided slice ``tasks[rank::world_size]``. For every assigned task it
    downloads and extracts the tarball, converts the dataset to v3.0 when ``meta/info.json``
    does not already report ``codebase_version == "v3.0"``, and appends the episodes to one
    shard dataset located at
    ``<work_dir>/shards/<output_repo_id with "/" -> "__">/world_<world_size>/rank_<rank>``.

    Args:
        tasks: Task descriptors; ``run`` reads the keys ``task_name``, ``tar_key``,
            ``rel_path`` and ``shared_url``.
        output_repo_id: Repo id of the final dataset; used to derive shard ids and paths.
        work_dir: Root directory for extracted task datasets and shards.
        split: Stored on the step; not read by ``run``.
        robot_type: Passed to ``LeRobotDataset.create`` for the shard.
        overwrite: Rebuild the shard even when it already looks complete.
        cleanup_temp: Delete each extracted task dataset (and its ``_old`` sibling) once its
            episodes have been appended.
        max_episodes_per_task: Optional cap on the number of episodes copied per task.
        vcodec: Video codec passed to ``LeRobotDataset.create``.
    """

    def __init__(
        self,
        tasks: list[dict[str, str]],
        output_repo_id: str,
        work_dir: str,
        split: str,
        robot_type: str,
        overwrite: bool = False,
        cleanup_temp: bool = False,
        max_episodes_per_task: int | None = None,
        vcodec: str = "libsvtav1",
    ):
        super().__init__()
        self.tasks = tasks
        self.output_repo_id = output_repo_id
        self.work_dir = Path(work_dir)
        self.split = split
        self.robot_type = robot_type
        self.overwrite = overwrite
        self.cleanup_temp = cleanup_temp
        self.max_episodes_per_task = max_episodes_per_task
        self.vcodec = vcodec

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Build (or skip) the shard for ``rank``.

        Returns early when the rank has no tasks, or when a complete shard already exists and
        ``overwrite`` is false. An existing incomplete shard directory is removed first.

        Raises:
            RuntimeError: if a tarball cannot be downloaded after three attempts.
            ValueError: if a task dataset is empty, required keys cannot be resolved, or a
                task's fps / state size / action shape differs from the shard's first task.
        """
        # Imports are local to ``run``, as in the rest of this file's pipeline steps.
        import copy
        import json
        import logging
        import shutil
        import tarfile
        import urllib.request

        import numpy as np
        from PIL import Image

        from lerobot.datasets.lerobot_dataset import LeRobotDataset
        from lerobot.datasets.utils import DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB
        from lerobot.scripts.convert_dataset_v21_to_v30 import (
            convert_data,
            convert_episodes_metadata,
            convert_info,
            convert_tasks,
            convert_videos,
            validate_local_dataset_version,
        )
        from lerobot.utils.utils import init_logging

        init_logging()

        # Target camera key -> accepted source aliases, tried in order.
        target_image_keys = {
            "observation.images.robot0_agentview_left": [
                "observation.images.robot0_agentview_left",
                "left_image",
                "observation.images.left_image",
            ],
            "observation.images.robot0_agentview_right": [
                "observation.images.robot0_agentview_right",
                "right_image",
                "observation.images.right_image",
            ],
            "observation.images.robot0_eye_in_hand": [
                "observation.images.robot0_eye_in_hand",
                "wrist_image",
                "observation.images.wrist_image",
            ],
        }
        # A single already-flat state key, if the source has one.
        direct_state_keys = [
            "observation.state",
            "state",
        ]
        # Known split-state layouts; the listed order is the concatenation order.
        explicit_state_groups = [
            [
                "observation.state.base_position",
                "observation.state.base_rotation",
                "observation.state.end_effector_position_relative",
                "observation.state.end_effector_rotation_relative",
                "observation.state.gripper_qpos",
            ],
            [
                "state.base_position",
                "state.base_rotation",
                "state.end_effector_position_relative",
                "state.end_effector_rotation_relative",
                "state.gripper_qpos",
            ],
        ]

        my_tasks = self.tasks[rank::world_size]
        logging.info(
            "Rank %s/%s: rebuilding %s of %s tasks",
            rank,
            world_size,
            len(my_tasks),
            len(self.tasks),
        )
        if not my_tasks:
            return

        shard_repo_id = f"{self.output_repo_id}_world_{world_size}_rank_{rank}"
        shard_root = (
            self.work_dir
            / "shards"
            / self.output_repo_id.replace("/", "__")
            / f"world_{world_size}"
            / f"rank_{rank}"
        )

        def shard_is_complete(root: Path) -> bool:
            """Return True when ``root`` has metadata, data, videos and non-zero totals."""
            info_path = root / "meta" / "info.json"
            tasks_path = root / "meta" / "tasks.parquet"
            stats_path = root / "meta" / "stats.json"
            if not (info_path.exists() and tasks_path.exists() and stats_path.exists()):
                return False

            episodes_dir = root / "meta" / "episodes"
            data_dir = root / "data"
            videos_dir = root / "videos"
            if not episodes_dir.exists() or not data_dir.exists() or not videos_dir.exists():
                return False
            if not any(episodes_dir.rglob("*.parquet")):
                return False
            if not any(data_dir.rglob("*.parquet")):
                return False
            if not any(videos_dir.rglob("*.mp4")):
                return False

            with open(info_path) as f:
                info = json.load(f)
            return info.get("total_episodes", 0) > 0 and info.get("total_frames", 0) > 0

        if shard_is_complete(shard_root) and not self.overwrite:
            logging.info("Shard already complete, skipping rank %s: %s", rank, shard_root)
            return
        if shard_root.exists():
            if self.overwrite:
                logging.warning("Removing existing shard root (--overwrite): %s", shard_root)
            else:
                logging.warning("Removing incomplete shard root before rebuild: %s", shard_root)
            shutil.rmtree(shard_root)

        def direct_download_url(shared_url: str) -> str:
            """Turn a ``<base>/s/<id>`` shared link into ``<base>/shared/static/<id>.tar``."""
            shared_id = shared_url.rstrip("/").split("/")[-1]
            base = shared_url.split("/s/")[0]
            return f"{base}/shared/static/{shared_id}.tar"

        def restore_v21_root_if_needed(dataset_root: Path) -> None:
            """Move ``<root>_old`` back to ``<root>`` when only the ``_old`` copy exists."""
            old_root = dataset_root.parent / f"{dataset_root.name}_old"
            if not dataset_root.exists() and old_root.exists():
                shutil.move(str(old_root), str(dataset_root))

        def download_and_extract(shared_url: str, destination: Path) -> None:
            """Download the tarball and extract it into ``destination.parent``.

            Does nothing when ``destination/meta/info.json`` already exists.
            """
            url = direct_download_url(shared_url)
            extract_dir = destination.parent
            extract_dir.mkdir(parents=True, exist_ok=True)
            tar_path = extract_dir / f"{destination.name}.tar"

            if destination.exists() and (destination / "meta" / "info.json").exists():
                logging.info(" Already extracted: %s", destination)
                return

            last_exc: Exception | None = None
            for attempt in range(3):
                try:
                    logging.info(" Downloading (attempt %s) -> %s", attempt + 1, tar_path)
                    urllib.request.urlretrieve(url, str(tar_path))
                    break
                except Exception as exc:
                    # Deliberately broad: any failure is retried and a partial file is dropped.
                    last_exc = exc
                    logging.warning(" Download attempt %s failed: %s", attempt + 1, exc)
                    if tar_path.exists():
                        tar_path.unlink()
            else:
                # Chain the last failure so the root cause is visible in the traceback.
                raise RuntimeError(f"Failed to download {url} after 3 attempts") from last_exc

            logging.info(" Extracting to %s", extract_dir)
            with tarfile.open(tar_path, "r") as tar:
                # The archive comes from the network: when the interpreter supports extraction
                # filters, refuse members that would escape ``extract_dir``.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=extract_dir, filter="data")
                else:
                    tar.extractall(path=extract_dir)
            tar_path.unlink()

        def is_v30(dataset_root: Path) -> bool:
            """Return True when ``meta/info.json`` exists and reports ``codebase_version`` v3.0."""
            info_path = dataset_root / "meta" / "info.json"
            if not info_path.exists():
                return False
            with open(info_path) as f:
                info = json.load(f)
            return info.get("codebase_version") == "v3.0"

        def convert_v21_to_v30(dataset_root: Path) -> None:
            """Convert into ``<root>_v30``, then swap it in place and keep the input as ``<root>_old``."""
            data_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
            video_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB

            validate_local_dataset_version(dataset_root)

            new_root = dataset_root.parent / f"{dataset_root.name}_v30"
            if new_root.exists():
                shutil.rmtree(new_root)

            convert_info(dataset_root, new_root, data_mb, video_mb)
            convert_tasks(dataset_root, new_root)
            episodes_metadata = convert_data(dataset_root, new_root, data_mb)
            episodes_video_metadata = convert_videos(dataset_root, new_root, video_mb)
            convert_episodes_metadata(
                dataset_root,
                new_root,
                episodes_metadata,
                episodes_video_metadata,
            )

            old_root = dataset_root.parent / f"{dataset_root.name}_old"
            if old_root.exists():
                shutil.rmtree(old_root)
            shutil.move(str(dataset_root), str(old_root))
            shutil.move(str(new_root), str(dataset_root))
            logging.info(" Conversion complete: %s", dataset_root)

        def as_float32_vector(value) -> np.ndarray:
            """Return ``value`` as a flat float32 NumPy array (torch values go through CPU)."""
            if value.__class__.__module__.startswith("torch"):
                arr = value.detach().cpu().numpy()
            else:
                arr = np.asarray(value)
            return arr.astype(np.float32).reshape(-1)

        def to_pil_image(value) -> Image.Image:
            """Convert a rank-3 torch/NumPy image (or a PIL image) to a PIL image.

            Channel-first inputs are moved to channel-last. Floating-point inputs whose maximum
            is <= 1.0 are scaled by 255 before being clipped to ``[0, 255]`` and cast to uint8.
            """
            if isinstance(value, Image.Image):
                return value
            if value.__class__.__module__.startswith("torch"):
                arr = value.detach().cpu()
                if arr.ndim != 3:
                    raise ValueError(f"Expected rank-3 image tensor, got shape {tuple(arr.shape)}")
                if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                    arr = arr.permute(1, 2, 0)
                if getattr(arr.dtype, "is_floating_point", False):
                    if float(arr.max()) <= 1.0:
                        arr = arr * 255.0
                    arr = arr.clamp(0, 255).byte()
                else:
                    arr = arr.byte()
                return Image.fromarray(arr.numpy())

            arr = np.asarray(value)
            if arr.ndim != 3:
                raise ValueError(f"Expected rank-3 image array, got shape {arr.shape}")
            if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if np.issubdtype(arr.dtype, np.floating):
                if float(arr.max()) <= 1.0:
                    arr = arr * 255.0
                arr = np.clip(arr, 0, 255).astype(np.uint8)
            elif arr.dtype != np.uint8:
                arr = arr.astype(np.uint8)
            return Image.fromarray(arr)

        def normalize_name(name: str) -> str:
            """Lower-case ``name`` and replace ``/`` and ``_`` with ``.``."""
            return name.replace("/", ".").replace("_", ".").lower()

        def choose_one(available_keys: list[str], aliases: list[str], label: str) -> str:
            """Return the first alias present in ``available_keys``; raise ValueError if none is."""
            for alias in aliases:
                if alias in available_keys:
                    return alias
            raise ValueError(f"Could not resolve {label}. Available keys: {available_keys}")

        def resolve_image_key_map(available_keys: list[str]) -> dict[str, str]:
            """Map every target camera key to the source key that provides it."""
            return {
                target_key: choose_one(available_keys, aliases, target_key)
                for target_key, aliases in target_image_keys.items()
            }

        def resolve_action_key(available_keys: list[str]) -> str:
            """Return the source action key (``action`` preferred over ``actions``)."""
            return choose_one(available_keys, ["action", "actions"], "action")

        def state_sort_key(name: str) -> tuple[int, str]:
            """Sort key ordering state parts: base pos, base rot, EE pos, EE rot, gripper, rest."""
            normalized = normalize_name(name)
            if "base.position" in normalized:
                return (0, normalized)
            if "base.rotation" in normalized or "base.quat" in normalized:
                return (1, normalized)
            if "end.effector.position" in normalized or "eef.pos" in normalized:
                return (2, normalized)
            if "end.effector.rotation" in normalized or "eef.quat" in normalized or "eef.rot" in normalized:
                return (3, normalized)
            if "gripper" in normalized:
                return (4, normalized)
            return (5, normalized)

        def resolve_state_keys(available_keys: list[str]) -> list[str]:
            """Pick the source keys that make up the flat state, most specific strategy first.

            Tries, in order: a single flat state key, a known explicit group, any key under a
            ``observation.state.`` / ``state.`` prefix, then any proprioception-looking key.
            """
            for key in direct_state_keys:
                if key in available_keys:
                    return [key]

            for group in explicit_state_groups:
                if all(key in available_keys for key in group):
                    return group

            prefix_keys = [
                key
                for key in available_keys
                if key.startswith("observation.state.") or key.startswith("state.")
            ]
            if prefix_keys:
                return sorted(prefix_keys, key=state_sort_key)

            proprio_like = [
                key
                for key in available_keys
                if any(
                    token in normalize_name(key)
                    for token in ["base.position", "base.rotation", "end.effector", "eef", "gripper"]
                )
            ]
            if proprio_like:
                return sorted(set(proprio_like), key=state_sort_key)

            raise ValueError(f"Could not resolve RoboCasa proprioception keys. Available keys: {available_keys}")

        def build_state(item: dict, state_keys: list[str]) -> np.ndarray:
            """Return the flat float32 state for ``item``, concatenating parts when needed."""
            if len(state_keys) == 1:
                return as_float32_vector(item[state_keys[0]])
            parts = [as_float32_vector(item[key]) for key in state_keys]
            return np.concatenate(parts, axis=0).astype(np.float32)

        def infer_target_features(
            src_dataset: LeRobotDataset,
            image_key_map: dict[str, str],
            action_key: str,
            state_dim: int,
        ) -> tuple[dict, bool]:
            """Build the shard feature spec from the source metadata.

            Returns the feature dict and whether any mapped camera feature has dtype ``video``.
            """
            features = {}
            use_videos = False

            for target_key, source_key in image_key_map.items():
                feature_info = copy.deepcopy(src_dataset.meta.features[source_key])
                if "fps" not in feature_info and feature_info.get("dtype") != "video":
                    feature_info["fps"] = int(src_dataset.meta.fps)
                use_videos = use_videos or feature_info.get("dtype") == "video"
                features[target_key] = feature_info

            action_info = copy.deepcopy(src_dataset.meta.features[action_key])
            action_info["dtype"] = "float32"
            action_info["fps"] = int(src_dataset.meta.fps)
            features["action"] = action_info

            features["observation.state"] = {
                "dtype": "float32",
                "shape": (state_dim,),
                "names": [f"state_{i}" for i in range(state_dim)],
                "fps": int(src_dataset.meta.fps),
            }

            return features, use_videos

        def task_root(task_meta: dict[str, str]) -> Path:
            """Return the directory the task's tarball is expected to extract to."""
            return self.work_dir / task_meta["rel_path"]

        def cleanup_task_root(dataset_root: Path) -> None:
            """Delete the extracted task dataset and its ``_old`` sibling, if present."""
            old_root = dataset_root.parent / f"{dataset_root.name}_old"
            if dataset_root.exists():
                shutil.rmtree(dataset_root)
            if old_root.exists():
                shutil.rmtree(old_root)

        # The shard is created lazily from the first task; later tasks must match its layout.
        shard_dataset = None
        shard_meta: dict[str, int | tuple[int, ...]] | None = None

        for task_meta in my_tasks:
            task_name = task_meta["task_name"]
            dataset_root = task_root(task_meta)

            logging.info("--- %s (%s) ---", task_name, task_meta["tar_key"])
            restore_v21_root_if_needed(dataset_root)

            download_and_extract(task_meta["shared_url"], dataset_root)
            if not is_v30(dataset_root):
                convert_v21_to_v30(dataset_root)

            src_dataset = LeRobotDataset(repo_id=task_name, root=dataset_root)
            available_keys = list(src_dataset.meta.features.keys())
            image_key_map = resolve_image_key_map(available_keys)
            action_key = resolve_action_key(available_keys)
            state_keys = resolve_state_keys(available_keys)

            if len(src_dataset) == 0:
                raise ValueError(f"Task dataset is empty: {dataset_root}")

            # The first frame fixes the state size and action shape for this task.
            first_item = src_dataset[0]
            first_state = build_state(first_item, state_keys)
            first_action = as_float32_vector(first_item[action_key])

            if shard_dataset is None:
                target_features, use_videos = infer_target_features(
                    src_dataset=src_dataset,
                    image_key_map=image_key_map,
                    action_key=action_key,
                    state_dim=int(first_state.size),
                )
                shard_dataset = LeRobotDataset.create(
                    repo_id=shard_repo_id,
                    root=shard_root,
                    fps=int(src_dataset.meta.fps),
                    robot_type=self.robot_type,
                    features=target_features,
                    use_videos=use_videos,
                    vcodec=self.vcodec,
                    batch_encoding_size=1,
                )
                shard_meta = {
                    "fps": int(src_dataset.meta.fps),
                    "state_dim": int(first_state.size),
                    "action_shape": tuple(first_action.shape),
                }
            else:
                assert shard_meta is not None
                if int(src_dataset.meta.fps) != shard_meta["fps"]:
                    raise ValueError(
                        f"FPS mismatch for {task_name}: {src_dataset.meta.fps} != {shard_meta['fps']}"
                    )
                if int(first_state.size) != shard_meta["state_dim"]:
                    raise ValueError(
                        f"State dim mismatch for {task_name}: {first_state.size} != {shard_meta['state_dim']}"
                    )
                if tuple(first_action.shape) != shard_meta["action_shape"]:
                    raise ValueError(
                        f"Action shape mismatch for {task_name}: {tuple(first_action.shape)} != "
                        f"{shard_meta['action_shape']}"
                    )

            num_episodes = src_dataset.num_episodes
            if self.max_episodes_per_task is not None:
                num_episodes = min(num_episodes, self.max_episodes_per_task)

            logging.info(" Appending %s episodes into shard %s", num_episodes, shard_root)
            for episode_idx in range(num_episodes):
                start = int(src_dataset.meta.episodes["dataset_from_index"][episode_idx])
                end = int(src_dataset.meta.episodes["dataset_to_index"][episode_idx])

                for frame_idx in range(start, end):
                    item = src_dataset[frame_idx]
                    frame = {
                        "task": task_name,
                        "observation.state": build_state(item, state_keys),
                        "action": as_float32_vector(item[action_key]),
                    }
                    for target_key, source_key in image_key_map.items():
                        frame[target_key] = to_pil_image(item[source_key])
                    shard_dataset.add_frame(frame)

                shard_dataset.save_episode()

            if self.cleanup_temp:
                cleanup_task_root(dataset_root)

        if shard_dataset is None:
            logging.warning("Rank %s produced no shard dataset", rank)
            return

        shard_dataset.finalize()
        logging.info("Rank %s finalized shard at %s", rank, shard_root)
|
||
|
||
|
||
class AggregateRoboCasaUnifiedShards(PipelineStep):
    """Merge the per-rank shard datasets into the final unified dataset.

    Only rank 0 does any work. Every shard root must look complete before aggregation starts;
    when ``push`` is true the aggregated dataset is pushed to the Hub afterwards.
    """

    def __init__(
        self,
        output_repo_id: str,
        shard_roots: list[str],
        output_root: str,
        push: bool = True,
        overwrite: bool = False,
    ):
        super().__init__()
        self.output_repo_id = output_repo_id
        self.output_root = Path(output_root)
        self.shard_roots = [Path(shard_root) for shard_root in shard_roots]
        self.overwrite = overwrite
        self.push = push

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Validate the shards, aggregate them into ``output_root`` and optionally push.

        Raises:
            FileNotFoundError: if any shard root fails the completeness check.
        """
        import json
        import logging
        import shutil

        from lerobot.datasets.aggregate import aggregate_datasets
        from lerobot.datasets.lerobot_dataset import LeRobotDataset
        from lerobot.utils.utils import init_logging

        init_logging()

        if rank != 0:
            logging.info("Rank %s: only rank 0 aggregates", rank)
            return

        def shard_is_complete(root: Path) -> bool:
            """Return True when ``root`` has metadata, data, videos and non-zero totals."""
            meta_dir = root / "meta"
            info_path = meta_dir / "info.json"
            required_files = (info_path, meta_dir / "tasks.parquet", meta_dir / "stats.json")
            if not all(path.exists() for path in required_files):
                return False

            # (directory, pattern that must match at least one file somewhere below it)
            populated_dirs = (
                (meta_dir / "episodes", "*.parquet"),
                (root / "data", "*.parquet"),
                (root / "videos", "*.mp4"),
            )
            if not all(directory.exists() for directory, _ in populated_dirs):
                return False
            for directory, pattern in populated_dirs:
                if not any(directory.rglob(pattern)):
                    return False

            with open(info_path) as f:
                info = json.load(f)
            return info.get("total_episodes", 0) > 0 and info.get("total_frames", 0) > 0

        missing = [shard_root for shard_root in self.shard_roots if not shard_is_complete(shard_root)]
        if missing:
            raise FileNotFoundError(f"Missing shard datasets: {missing}")

        if self.overwrite and self.output_root.exists():
            logging.warning("Removing existing unified output (--overwrite): %s", self.output_root)
            shutil.rmtree(self.output_root)

        # One synthetic repo id per shard root, numbered by position.
        shard_repo_ids = [f"{self.output_repo_id}_shard_{idx}" for idx, _ in enumerate(self.shard_roots)]
        logging.info("Aggregating %s shards into %s", len(self.shard_roots), self.output_root)
        aggregate_datasets(
            repo_ids=shard_repo_ids,
            roots=self.shard_roots,
            aggr_repo_id=self.output_repo_id,
            aggr_root=self.output_root,
        )

        if not self.push:
            return

        dataset = LeRobotDataset(repo_id=self.output_repo_id, root=self.output_root)
        dataset.push_to_hub(
            tags=["lerobot", "robocasa", "composite_seen", "unified"],
            private=False,
        )
        logging.info("Pushed to https://huggingface.co/datasets/%s", self.output_repo_id)
|
||
|
||
|
||
def make_prepare_executor(
    *,
    tasks: list[dict[str, str]],
    output_repo_id: str,
    work_dir: Path,
    split: str,
    robot_type: str,
    overwrite: bool,
    cleanup_temp: bool,
    max_episodes_per_task: int | None,
    vcodec: str,
    job_name: str,
    logs_dir: Path,
    workers: int,
    partition: str,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    slurm: bool,
):
    """Create the executor that runs the shard-preparation step.

    Returns a ``SlurmPipelineExecutor`` with ``workers`` tasks and ``workers`` concurrent
    workers when ``slurm`` is true; otherwise a ``LocalPipelineExecutor`` that runs the same
    ``workers`` tasks with a single worker. Logs go to ``logs_dir / job_name`` in both cases.
    """
    prepare_step = PrepareRoboCasaUnifiedShards(
        tasks=tasks,
        output_repo_id=output_repo_id,
        work_dir=str(work_dir),
        split=split,
        robot_type=robot_type,
        overwrite=overwrite,
        cleanup_temp=cleanup_temp,
        max_episodes_per_task=max_episodes_per_task,
        vcodec=vcodec,
    )
    logging_dir = str(logs_dir / job_name)

    if not slurm:
        return LocalPipelineExecutor(
            pipeline=[prepare_step],
            logging_dir=logging_dir,
            tasks=workers,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=[prepare_step],
        logging_dir=logging_dir,
        job_name=job_name,
        tasks=workers,
        workers=workers,
        time=time_limit,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
    )
|
||
|
||
|
||
def make_aggregate_executor(
    *,
    output_repo_id: str,
    shard_roots: list[Path],
    output_root: Path,
    push: bool,
    overwrite: bool,
    job_name: str,
    logs_dir: Path,
    partition: str,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    slurm: bool,
    depends: SlurmPipelineExecutor | None = None,
):
    """Create the single-task executor that runs the aggregation step.

    Returns a ``SlurmPipelineExecutor`` (receiving ``depends``) when ``slurm`` is true;
    otherwise a ``LocalPipelineExecutor``, in which case ``depends`` is not used.
    Logs go to ``logs_dir / job_name`` in both cases.
    """
    aggregate_step = AggregateRoboCasaUnifiedShards(
        output_repo_id=output_repo_id,
        shard_roots=[str(shard_root) for shard_root in shard_roots],
        output_root=str(output_root),
        push=push,
        overwrite=overwrite,
    )
    logging_dir = str(logs_dir / job_name)

    if not slurm:
        return LocalPipelineExecutor(
            pipeline=[aggregate_step],
            logging_dir=logging_dir,
            tasks=1,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=[aggregate_step],
        logging_dir=logging_dir,
        job_name=job_name,
        tasks=1,
        workers=1,
        time=time_limit,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
        depends=depends,
    )
|
||
|
||
|
||
def resolve_repo_id(args: argparse.Namespace) -> str:
    """Return the final dataset repo id chosen on the command line.

    ``args.repo_id`` is used when set. Otherwise the id is derived from ``args.hf_user``,
    ``args.split`` and ``args.source``.

    Raises:
        ValueError: if neither ``repo_id`` nor ``hf_user`` is set.
    """
    explicit_id = args.repo_id
    if explicit_id:
        return explicit_id

    user = args.hf_user
    if not user:
        raise ValueError("Pass either --repo-id or --hf-user.")
    return f"{user}/robocasa_composite_seen_{args.split}_{args.source}_unified_v3"
|
||
|
||
|
||
def main():
    """Parse the command line, select the tasks, and run the prepare / aggregate executors.

    Task selection: ``--tasks`` wins when given; otherwise ``--task-set`` is applied unless it
    is ``all``. Names are matched case-insensitively against the discovered tasks. Requested
    names that are not found in the manifest slice are reported with a warning instead of
    being dropped silently.

    Raises:
        ValueError: if the selection is empty, or no repo id can be resolved.
        SystemExit: on invalid command-line arguments (including ``--workers`` < 1).
    """
    parser = argparse.ArgumentParser(
        description="Rebuild the 16 RoboCasa composite_seen tarballs into one unified LeRobot v3 dataset."
    )
    parser.add_argument("--repo-id", type=str, default=None, help="Final unified dataset repo id.")
    parser.add_argument(
        "--hf-user",
        type=str,
        default=None,
        help="Optional shorthand. If set and --repo-id is omitted, derive "
        "<hf_user>/robocasa_composite_seen_<split>_<source>_unified_v3.",
    )
    parser.add_argument("--work-dir", type=Path, required=True)
    parser.add_argument("--split", type=str, default=DEFAULT_SPLIT, choices=["target", "pretrain"])
    parser.add_argument("--source", type=str, default=DEFAULT_SOURCE)
    parser.add_argument(
        "--mode",
        type=str,
        default="all",
        choices=["all", "prepare", "aggregate"],
        help="prepare = build shards, aggregate = merge existing shards, all = do both.",
    )
    parser.add_argument(
        "--task-set",
        type=str,
        default="composite_seen",
        choices=sorted(TASK_SETS.keys()),
        help="Predefined task set to restrict discovery to. Default "
        "``composite_seen`` (the 16 multi-step composite_seen tasks). Use "
        "``all`` to keep every discovered task in the split/source slice. "
        "``--tasks`` overrides this when provided.",
    )
    parser.add_argument("--robocasa-root", type=Path, default=None)
    parser.add_argument("--box-links-json", type=Path, default=None)
    parser.add_argument("--robot-type", type=str, default=DEFAULT_ROBOT_TYPE)
    parser.add_argument("--vcodec", type=str, default="libsvtav1")
    parser.add_argument("--max-episodes-per-task", type=int, default=None)
    parser.add_argument("--cleanup-temp", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--no-push", action="store_true")
    parser.add_argument("--logs-dir", type=Path, default=Path("logs"))
    parser.add_argument("--job-name", type=str, default="port_robocasa_composite_seen")
    parser.add_argument("--slurm", type=int, default=1, help="1 = Slurm executor, 0 = local debug.")
    parser.add_argument(
        "--workers",
        type=int,
        default=16,
        help="Number of SLURM workers. Default 16 = one per composite_seen task.",
    )
    parser.add_argument("--partition", type=str, default="hopper-cpu")
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="CPUs per worker. 16 workers × 8 cpus = 128 cpus total on hopper-cpu.",
    )
    parser.add_argument("--mem-per-cpu", type=str, default="4G")
    parser.add_argument("--time", type=str, default="24:00:00")
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="*",
        default=None,
        help="Explicit task names. Overrides --task-set when provided.",
    )
    parser.add_argument("--dryrun", action="store_true")
    args = parser.parse_args()

    # The worker count is the rank stride and names the shard directories; with a value
    # below 1 no rank is active and nothing would be built or aggregated.
    if args.workers < 1:
        parser.error("--workers must be >= 1")

    box_links_json = _resolve_box_links_json(args.box_links_json, args.robocasa_root)
    all_tasks = _discover_tasks(box_links_json, split=args.split, source=args.source)

    # Filter: explicit --tasks wins; otherwise apply --task-set ("all" means no filter).
    if args.tasks:
        requested = list(args.tasks)
    elif args.task_set != "all":
        requested = list(TASK_SETS[args.task_set])
    else:
        requested = None

    if requested is not None:
        wanted = {name.lower() for name in requested}
        all_tasks = [task for task in all_tasks if task["task_name"].lower() in wanted]

    if not all_tasks:
        raise ValueError(
            f"No RoboCasa tasks selected for split={args.split!r}, source={args.source!r}, "
            f"task_set={args.task_set!r}, tasks={args.tasks!r}"
        )

    if requested is not None:
        # Surface requested names that matched nothing so a partial dataset is never a surprise.
        found = {task["task_name"].lower() for task in all_tasks}
        not_found = sorted({name for name in requested if name.lower() not in found})
        if not_found:
            print(
                f"WARNING: {len(not_found)} requested task(s) not found for "
                f"split={args.split!r}, source={args.source!r}: {not_found}"
            )

    print(f"Tasks to rebuild ({len(all_tasks)}):")
    for task in all_tasks:
        print(f" {task['task_name']} ({task['tar_key']})")
    if args.dryrun:
        return

    output_repo_id = resolve_repo_id(args)
    output_root = args.work_dir / "unified" / output_repo_id
    # Only ranks that receive at least one task produce a shard.
    active_ranks = [rank for rank in range(args.workers) if all_tasks[rank :: args.workers]]
    shard_roots = [
        args.work_dir
        / "shards"
        / output_repo_id.replace("/", "__")
        / f"world_{args.workers}"
        / f"rank_{rank}"
        for rank in active_ranks
    ]

    prepare_executor = None
    if args.mode in {"all", "prepare"}:
        prepare_executor = make_prepare_executor(
            tasks=all_tasks,
            output_repo_id=output_repo_id,
            work_dir=args.work_dir,
            split=args.split,
            robot_type=args.robot_type,
            overwrite=args.overwrite,
            cleanup_temp=args.cleanup_temp,
            max_episodes_per_task=args.max_episodes_per_task,
            vcodec=args.vcodec,
            job_name=args.job_name,
            logs_dir=args.logs_dir,
            workers=args.workers,
            partition=args.partition,
            cpus_per_task=args.cpus_per_task,
            mem_per_cpu=args.mem_per_cpu,
            time_limit=args.time,
            slurm=args.slurm == 1,
        )
        if args.mode == "prepare":
            prepare_executor.run()

    if args.mode in {"all", "aggregate"}:
        aggregate_executor = make_aggregate_executor(
            output_repo_id=output_repo_id,
            shard_roots=shard_roots,
            output_root=output_root,
            push=not args.no_push,
            overwrite=args.overwrite,
            job_name=f"{args.job_name}_aggregate",
            logs_dir=args.logs_dir,
            partition=args.partition,
            cpus_per_task=args.cpus_per_task,
            mem_per_cpu=args.mem_per_cpu,
            time_limit=args.time,
            slurm=args.slurm == 1,
            depends=prepare_executor if args.mode == "all" and args.slurm == 1 else None,
        )
        if args.mode == "all" and args.slurm == 1:
            # SLURM: submitting the aggregate executor with depends=prepare_executor
            # transitively submits prepare too, with the right --dependency=afterok.
            aggregate_executor.run()
        elif args.mode == "all":
            # Local: run sequentially.
            assert prepare_executor is not None
            prepare_executor.run()
            aggregate_executor.run()
        else:
            aggregate_executor.run()


if __name__ == "__main__":
    main()
|