diff --git a/examples/port_datasets/slurm_build_robocasa_composite_seen.py b/examples/port_datasets/slurm_build_robocasa_composite_seen.py index 24f33e935..9f28f6636 100644 --- a/examples/port_datasets/slurm_build_robocasa_composite_seen.py +++ b/examples/port_datasets/slurm_build_robocasa_composite_seen.py @@ -14,528 +14,988 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Build the merged RoboCasa composite_seen dataset on SLURM via datatrove. +"""Rebuild the 16 RoboCasa composite_seen tarballs into one unified LeRobot v3 dataset. -Two-phase pipeline modeled after ``slurm_port_shards.py`` + -``slurm_aggregate_shards.py``: +Filter-only wrapper around the canonical RoboCasa port script — restricts the +discovered task set to the 16 ``composite_seen`` tasks (the multi-step subset +of the official RoboCasa365 target benchmark) so a single command produces the +exact dataset slice needed for an apples-to-apples pi05 vs pi052 comparison +on multi-step kitchen manipulation. -* Phase 1 — DOWNLOAD (parallel, 16 tasks, 1 per worker): - Each datatrove worker downloads one of the 16 RoboCasa composite_seen task - archives (``v1.0/target/composite///lerobot.tar``) via - RoboCasa's own ``download_datasets`` helper. Idempotent — already-extracted - tasks are skipped. Network-bound, CPU-light. +Per-rank, each datatrove worker: -* Phase 2 — AGGREGATE (single worker, depends on phase 1): - One worker calls ``aggregate_datasets`` over the 16 extracted directories, - producing a single combined LeRobotDataset. Validates fps / robot_type / - features, unifies task indices, concatenates videos + parquet, recomputes - stats. CPU + disk heavy. +1. Downloads the assigned task tarball(s) directly from Box (resolved via the + ``box_links_ds.json`` bundled with the local ``robocasa`` clone). +2. Converts the extracted LeRobot v2.1 dataset to v3.0 in place. +3. Rewrites the per-task data into a per-rank shard with: + - the canonical RoboCasa task name in ``task`` + - standardized camera keys under ``observation.images.robot0_*`` + - a guaranteed flat ``observation.state`` (concatenation of base / EE / + gripper sub-keys when the source dataset stores them separately) + - a standardized ``action`` key -When run under SLURM, phase 2 is submitted with ``depends=phase_1_executor`` -so the scheduler only releases it after every download task succeeds. +A single aggregate worker then merges all shards into one unified dataset. -Local (``--slurm 0``) execution runs the two phases sequentially in the -current process — useful for debugging on a workstation. +Heavy lifting is parallelized via Datatrove + SLURM on CPU nodes. With +``--workers=16 --cpus-per-task=8`` on ``hopper-cpu`` you get 128 CPUs total +across the prepare phase (one task per worker, 8 CPUs each for ffmpeg / +parquet) and the aggregate phase reuses the same CPU budget on a single node. -Usage on SLURM:: +Typical hopper-cpu invocation:: uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\ - --output-dir=/scratch/${USER}/robocasa_composite_seen \\ - --hub-repo-id=${HF_USER}/robocasa_composite_seen \\ - --logs-dir=/scratch/${USER}/logs/robocasa \\ - --partition=cpu \\ - --download-cpus=4 --download-mem=8G \\ - --aggregate-cpus=16 --aggregate-mem-per-cpu=4G \\ - --push-to-hub + --repo-id=${HF_USER}/robocasa_composite_seen_v3 \\ + --work-dir=/fsx/${USER}/robocasa/datasets/v1.0 \\ + --robocasa-root=/fsx/${USER}/robocasa \\ + --split=target \\ + --source=human \\ + --partition=hopper-cpu \\ + --workers=16 \\ + --cpus-per-task=8 \\ + --mem-per-cpu=4G \\ + --time=24:00:00 \\ + --logs-dir=/fsx/${USER}/logs/robocasa -Local debug (sequential, single process):: +Local debug (no SLURM):: uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\ - --output-dir=/tmp/robocasa_composite_seen \\ - --slurm=0 \\ - --tasks=PrepareCoffee,KettleBoiling + --repo-id=local/robocasa_composite_seen_v3_smoke \\ + --work-dir=/tmp/robocasa_smoke \\ + --robocasa-root=$HOME/robocasa \\ + --slurm=0 --workers=1 \\ + --tasks PrepareCoffee -Prereqs: ``robocasa`` + ``robosuite`` installed (see -``docs/source/benchmarks/robocasa.mdx``); ``datatrove`` installed via the -``annotations`` extra (``uv sync --extra annotations``). +If ``robocasa`` is already importable in the runtime environment, you can omit +``--robocasa-root``; the box-links manifest will be located from the package. """ from __future__ import annotations import argparse -import logging +import json from pathlib import Path from datatrove.executor import LocalPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.pipeline.base import PipelineStep -# Reuse the per-task helpers + canonical task list from the single-machine -# script so both runners share one source of truth for "what does it mean to -# download a composite_seen task". The helpers are spelled with leading -# underscores there (module-private), but the slurm runner is a legitimate -# in-tree consumer, so we alias them to clean names here. -from lerobot.scripts.build_robocasa_composite_seen import ( - COMPOSITE_SEEN_TASKS, - _download_task as download_task, - _require_robocasa as require_robocasa, - _resolve_task_root as resolve_task_root, -) +DEFAULT_SPLIT = "target" +DEFAULT_SOURCE = "human" +DEFAULT_ROBOT_TYPE = "robocasa" -logger = logging.getLogger(__name__) +# The 16 composite_seen tasks (RoboCasa365 target benchmark, multi-step subset). +# Order matches the official RoboCasa documentation. +COMPOSITE_SEEN_TASKS: list[str] = [ + "DeliverStraw", + "GetToastedBread", + "KettleBoiling", + "LoadDishwasher", + "PackIdenticalLunches", + "PreSoakPan", + "PrepareCoffee", + "RinseSinkBasin", + "ScrubCuttingBoard", + "SearingMeat", + "SetUpCuttingStation", + "StackBowlsCabinet", + "SteamInMicrowave", + "StirVegetables", + "StoreLeftoversInBowl", + "WashLettuce", +] + +# Other groupings, exposed via ``--task-set`` for symmetry — populated lazily. +TASK_SETS: dict[str, list[str]] = { + "composite_seen": COMPOSITE_SEEN_TASKS, + "all": [], # sentinel — no filter +} -# --------------------------------------------------------------------------- -# Pipeline steps -# --------------------------------------------------------------------------- +def _task_name_from_tar_key(tar_key: str) -> str: + parts = tar_key.split("/") + if len(parts) < 3: + raise ValueError(f"Unexpected RoboCasa tar key: {tar_key}") + return parts[2].removesuffix(".tar") -class DownloadRoboCasaTask(PipelineStep): - """Phase 1 — download the task assigned to this rank. +def _resolve_box_links_json( + box_links_json: Path | None, + robocasa_root: Path | None, +) -> Path: + if box_links_json is not None: + if not box_links_json.exists(): + raise FileNotFoundError(f"--box-links-json does not exist: {box_links_json}") + return box_links_json - Each datatrove worker is given a rank in ``[0, world_size)``. With - ``tasks == len(task_names)`` each worker owns exactly one task; with - fewer workers than tasks, datatrove load-balances tasks across workers - using the standard ``rank::world_size`` slicing. - """ + if robocasa_root is not None: + candidates = [ + robocasa_root / "models" / "assets" / "box_links" / "box_links_ds.json", + robocasa_root / "robocasa" / "models" / "assets" / "box_links" / "box_links_ds.json", + ] + for candidate in candidates: + if candidate.exists(): + return candidate + raise FileNotFoundError( + f"Could not find box_links_ds.json under --robocasa-root={robocasa_root}" + ) - def __init__(self, task_names: list[str], *, overwrite: bool = False): - super().__init__() - self.task_names = list(task_names) - self.overwrite = overwrite + try: + import robocasa # noqa: PLC0415 + except ModuleNotFoundError as exc: + raise FileNotFoundError( + "Could not resolve RoboCasa box links. Pass --robocasa-root or --box-links-json, " + "or run in an environment where `robocasa` is importable." + ) from exc - def run(self, data=None, rank: int = 0, world_size: int = 1): - from lerobot.utils.utils import init_logging # noqa: PLC0415 - - init_logging() - require_robocasa() - - # Standard datatrove slicing: each rank owns the subset - # ``task_names[rank::world_size]``. When ``world_size == - # len(task_names)`` this is exactly one task per rank. - my_tasks = self.task_names[rank::world_size] - if not my_tasks: - logger.info("rank %d/%d: no tasks assigned", rank, world_size) - return - - for task in my_tasks: - logger.info("rank %d/%d: downloading %s", rank, world_size, task) - root = download_task(task, overwrite=self.overwrite) - logger.info("rank %d/%d: %s extracted at %s", rank, world_size, task, root) + candidate = Path(robocasa.__path__[0]) / "models" / "assets" / "box_links" / "box_links_ds.json" + if not candidate.exists(): + raise FileNotFoundError(f"Resolved RoboCasa package, but box links file is missing: {candidate}") + return candidate -class AggregateRoboCasaShards(PipelineStep): - """Phase 2 — merge all 16 extracted directories into one LeRobotDataset. +def _discover_tasks( + box_links_json: Path, + split: str = DEFAULT_SPLIT, + source: str | None = DEFAULT_SOURCE, +) -> list[dict[str, str]]: + with open(box_links_json) as f: + box_links: dict[str, str] = json.load(f) - ``aggregate_datasets`` parallelizes internally; only rank 0 runs the - merge (mirrors the DROID ``slurm_aggregate_shards.py`` convention). - """ + tasks: list[dict[str, str]] = [] + for tar_key in sorted(box_links): + parts = tar_key.split("/") + if len(parts) < 3 or parts[0] != split: + continue + + # RoboCasa registries can appear in two layouts: + # new: split////lerobot.tar + # old: split//.tar + if parts[1] in {"human", "mimicgen"}: + tar_source = parts[1] + else: + tar_source = "human" + + if source is not None and tar_source != source: + continue + + tasks.append( + { + "task_name": _task_name_from_tar_key(tar_key), + "tar_key": tar_key, + "source": tar_source, + "rel_path": tar_key.removesuffix(".tar"), + "shared_url": box_links[tar_key], + } + ) + return tasks + + +class PrepareRoboCasaUnifiedShards(PipelineStep): + """Build per-rank unified shards from RoboCasa task tarballs.""" def __init__( self, - task_names: list[str], - *, + tasks: list[dict[str, str]], output_repo_id: str, - output_dir: Path, - push_to_hub: bool, - private: bool, + work_dir: str, + split: str, + robot_type: str, + overwrite: bool = False, + cleanup_temp: bool = False, + max_episodes_per_task: int | None = None, + vcodec: str = "libsvtav1", ): super().__init__() - self.task_names = list(task_names) + self.tasks = tasks self.output_repo_id = output_repo_id - self.output_dir = Path(output_dir) - self.push_to_hub = push_to_hub - self.private = private + self.work_dir = Path(work_dir) + self.split = split + self.robot_type = robot_type + self.overwrite = overwrite + self.cleanup_temp = cleanup_temp + self.max_episodes_per_task = max_episodes_per_task + self.vcodec = vcodec def run(self, data=None, rank: int = 0, world_size: int = 1): - from lerobot.utils.utils import init_logging # noqa: PLC0415 + import copy + import json + import logging + import shutil + import tarfile + import urllib.request + + import numpy as np + from PIL import Image + + from lerobot.datasets.lerobot_dataset import LeRobotDataset + from lerobot.datasets.utils import DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB + from lerobot.scripts.convert_dataset_v21_to_v30 import ( + convert_data, + convert_episodes_metadata, + convert_info, + convert_tasks, + convert_videos, + validate_local_dataset_version, + ) + from lerobot.utils.utils import init_logging + + init_logging() + + target_image_keys = { + "observation.images.robot0_agentview_left": [ + "observation.images.robot0_agentview_left", + "left_image", + "observation.images.left_image", + ], + "observation.images.robot0_agentview_right": [ + "observation.images.robot0_agentview_right", + "right_image", + "observation.images.right_image", + ], + "observation.images.robot0_eye_in_hand": [ + "observation.images.robot0_eye_in_hand", + "wrist_image", + "observation.images.wrist_image", + ], + } + direct_state_keys = [ + "observation.state", + "state", + ] + explicit_state_groups = [ + [ + "observation.state.base_position", + "observation.state.base_rotation", + "observation.state.end_effector_position_relative", + "observation.state.end_effector_rotation_relative", + "observation.state.gripper_qpos", + ], + [ + "state.base_position", + "state.base_rotation", + "state.end_effector_position_relative", + "state.end_effector_rotation_relative", + "state.gripper_qpos", + ], + ] + + my_tasks = self.tasks[rank::world_size] + logging.info( + "Rank %s/%s: rebuilding %s of %s tasks", + rank, + world_size, + len(my_tasks), + len(self.tasks), + ) + if not my_tasks: + return + + shard_repo_id = f"{self.output_repo_id}_world_{world_size}_rank_{rank}" + shard_root = ( + self.work_dir + / "shards" + / self.output_repo_id.replace("/", "__") + / f"world_{world_size}" + / f"rank_{rank}" + ) + + def shard_is_complete(root: Path) -> bool: + info_path = root / "meta" / "info.json" + tasks_path = root / "meta" / "tasks.parquet" + stats_path = root / "meta" / "stats.json" + if not (info_path.exists() and tasks_path.exists() and stats_path.exists()): + return False + + episodes_dir = root / "meta" / "episodes" + data_dir = root / "data" + videos_dir = root / "videos" + if not episodes_dir.exists() or not data_dir.exists() or not videos_dir.exists(): + return False + if not any(episodes_dir.rglob("*.parquet")): + return False + if not any(data_dir.rglob("*.parquet")): + return False + if not any(videos_dir.rglob("*.mp4")): + return False + + with open(info_path) as f: + info = json.load(f) + return info.get("total_episodes", 0) > 0 and info.get("total_frames", 0) > 0 + + if shard_is_complete(shard_root) and not self.overwrite: + logging.info("Shard already complete, skipping rank %s: %s", rank, shard_root) + return + if shard_root.exists(): + if self.overwrite: + logging.warning("Removing existing shard root (--overwrite): %s", shard_root) + else: + logging.warning("Removing incomplete shard root before rebuild: %s", shard_root) + shutil.rmtree(shard_root) + + def direct_download_url(shared_url: str) -> str: + shared_id = shared_url.rstrip("/").split("/")[-1] + base = shared_url.split("/s/")[0] + return f"{base}/shared/static/{shared_id}.tar" + + def restore_v21_root_if_needed(dataset_root: Path) -> None: + old_root = dataset_root.parent / f"{dataset_root.name}_old" + if not dataset_root.exists() and old_root.exists(): + shutil.move(str(old_root), str(dataset_root)) + + def download_and_extract(shared_url: str, destination: Path) -> None: + url = direct_download_url(shared_url) + extract_dir = destination.parent + extract_dir.mkdir(parents=True, exist_ok=True) + tar_path = extract_dir / f"{destination.name}.tar" + + if destination.exists() and (destination / "meta" / "info.json").exists(): + logging.info(" Already extracted: %s", destination) + return + + for attempt in range(3): + try: + logging.info(" Downloading (attempt %s) -> %s", attempt + 1, tar_path) + urllib.request.urlretrieve(url, str(tar_path)) + break + except Exception as exc: + logging.warning(" Download attempt %s failed: %s", attempt + 1, exc) + if tar_path.exists(): + tar_path.unlink() + else: + raise RuntimeError(f"Failed to download {url} after 3 attempts") + + logging.info(" Extracting to %s", extract_dir) + with tarfile.open(tar_path, "r") as tar: + tar.extractall(path=extract_dir) + tar_path.unlink() + + def is_v30(dataset_root: Path) -> bool: + info_path = dataset_root / "meta" / "info.json" + if not info_path.exists(): + return False + with open(info_path) as f: + info = json.load(f) + return info.get("codebase_version") == "v3.0" + + def convert_v21_to_v30(dataset_root: Path) -> None: + data_mb = DEFAULT_DATA_FILE_SIZE_IN_MB + video_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB + + validate_local_dataset_version(dataset_root) + + new_root = dataset_root.parent / f"{dataset_root.name}_v30" + if new_root.exists(): + shutil.rmtree(new_root) + + convert_info(dataset_root, new_root, data_mb, video_mb) + convert_tasks(dataset_root, new_root) + episodes_metadata = convert_data(dataset_root, new_root, data_mb) + episodes_video_metadata = convert_videos(dataset_root, new_root, video_mb) + convert_episodes_metadata( + dataset_root, + new_root, + episodes_metadata, + episodes_video_metadata, + ) + + old_root = dataset_root.parent / f"{dataset_root.name}_old" + if old_root.exists(): + shutil.rmtree(old_root) + shutil.move(str(dataset_root), str(old_root)) + shutil.move(str(new_root), str(dataset_root)) + logging.info(" Conversion complete: %s", dataset_root) + + def as_float32_vector(value) -> np.ndarray: + if value.__class__.__module__.startswith("torch"): + arr = value.detach().cpu().numpy() + else: + arr = np.asarray(value) + return arr.astype(np.float32).reshape(-1) + + def to_pil_image(value) -> Image.Image: + if isinstance(value, Image.Image): + return value + if value.__class__.__module__.startswith("torch"): + arr = value.detach().cpu() + if arr.ndim != 3: + raise ValueError(f"Expected rank-3 image tensor, got shape {tuple(arr.shape)}") + if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = arr.permute(1, 2, 0) + if getattr(arr.dtype, "is_floating_point", False): + if float(arr.max()) <= 1.0: + arr = arr * 255.0 + arr = arr.clamp(0, 255).byte() + else: + arr = arr.byte() + return Image.fromarray(arr.numpy()) + + arr = np.asarray(value) + if arr.ndim != 3: + raise ValueError(f"Expected rank-3 image array, got shape {arr.shape}") + if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + if np.issubdtype(arr.dtype, np.floating): + if float(arr.max()) <= 1.0: + arr = arr * 255.0 + arr = np.clip(arr, 0, 255).astype(np.uint8) + elif arr.dtype != np.uint8: + arr = arr.astype(np.uint8) + return Image.fromarray(arr) + + def normalize_name(name: str) -> str: + return name.replace("/", ".").replace("_", ".").lower() + + def choose_one(available_keys: list[str], aliases: list[str], label: str) -> str: + for alias in aliases: + if alias in available_keys: + return alias + raise ValueError(f"Could not resolve {label}. Available keys: {available_keys}") + + def resolve_image_key_map(available_keys: list[str]) -> dict[str, str]: + return { + target_key: choose_one(available_keys, aliases, target_key) + for target_key, aliases in target_image_keys.items() + } + + def resolve_action_key(available_keys: list[str]) -> str: + return choose_one(available_keys, ["action", "actions"], "action") + + def state_sort_key(name: str) -> tuple[int, str]: + normalized = normalize_name(name) + if "base.position" in normalized: + return (0, normalized) + if "base.rotation" in normalized or "base.quat" in normalized: + return (1, normalized) + if "end.effector.position" in normalized or "eef.pos" in normalized: + return (2, normalized) + if "end.effector.rotation" in normalized or "eef.quat" in normalized or "eef.rot" in normalized: + return (3, normalized) + if "gripper" in normalized: + return (4, normalized) + return (5, normalized) + + def resolve_state_keys(available_keys: list[str]) -> list[str]: + for key in direct_state_keys: + if key in available_keys: + return [key] + + for group in explicit_state_groups: + if all(key in available_keys for key in group): + return group + + prefix_keys = [ + key + for key in available_keys + if key.startswith("observation.state.") or key.startswith("state.") + ] + if prefix_keys: + return sorted(prefix_keys, key=state_sort_key) + + proprio_like = [ + key + for key in available_keys + if any( + token in normalize_name(key) + for token in ["base.position", "base.rotation", "end.effector", "eef", "gripper"] + ) + ] + if proprio_like: + return sorted(set(proprio_like), key=state_sort_key) + + raise ValueError(f"Could not resolve RoboCasa proprioception keys. Available keys: {available_keys}") + + def build_state(item: dict, state_keys: list[str]) -> np.ndarray: + if len(state_keys) == 1: + return as_float32_vector(item[state_keys[0]]) + parts = [as_float32_vector(item[key]) for key in state_keys] + return np.concatenate(parts, axis=0).astype(np.float32) + + def infer_target_features( + src_dataset: LeRobotDataset, + image_key_map: dict[str, str], + action_key: str, + state_dim: int, + ) -> tuple[dict, bool]: + features = {} + use_videos = False + + for target_key, source_key in image_key_map.items(): + feature_info = copy.deepcopy(src_dataset.meta.features[source_key]) + if "fps" not in feature_info and feature_info.get("dtype") != "video": + feature_info["fps"] = int(src_dataset.meta.fps) + use_videos = use_videos or feature_info.get("dtype") == "video" + features[target_key] = feature_info + + action_info = copy.deepcopy(src_dataset.meta.features[action_key]) + action_info["dtype"] = "float32" + action_info["fps"] = int(src_dataset.meta.fps) + features["action"] = action_info + + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": [f"state_{i}" for i in range(state_dim)], + "fps": int(src_dataset.meta.fps), + } + + return features, use_videos + + def task_root(task_meta: dict[str, str]) -> Path: + return self.work_dir / task_meta["rel_path"] + + def cleanup_task_root(dataset_root: Path) -> None: + old_root = dataset_root.parent / f"{dataset_root.name}_old" + if dataset_root.exists(): + shutil.rmtree(dataset_root) + if old_root.exists(): + shutil.rmtree(old_root) + + shard_dataset = None + shard_meta: dict[str, int | tuple[int, ...]] | None = None + + for task_meta in my_tasks: + task_name = task_meta["task_name"] + dataset_root = task_root(task_meta) + + logging.info("--- %s (%s) ---", task_name, task_meta["tar_key"]) + restore_v21_root_if_needed(dataset_root) + + download_and_extract(task_meta["shared_url"], dataset_root) + if not is_v30(dataset_root): + convert_v21_to_v30(dataset_root) + + src_dataset = LeRobotDataset(repo_id=task_name, root=dataset_root) + available_keys = list(src_dataset.meta.features.keys()) + image_key_map = resolve_image_key_map(available_keys) + action_key = resolve_action_key(available_keys) + state_keys = resolve_state_keys(available_keys) + + if len(src_dataset) == 0: + raise ValueError(f"Task dataset is empty: {dataset_root}") + + first_item = src_dataset[0] + first_state = build_state(first_item, state_keys) + first_action = as_float32_vector(first_item[action_key]) + + if shard_dataset is None: + target_features, use_videos = infer_target_features( + src_dataset=src_dataset, + image_key_map=image_key_map, + action_key=action_key, + state_dim=int(first_state.size), + ) + shard_dataset = LeRobotDataset.create( + repo_id=shard_repo_id, + root=shard_root, + fps=int(src_dataset.meta.fps), + robot_type=self.robot_type, + features=target_features, + use_videos=use_videos, + vcodec=self.vcodec, + batch_encoding_size=1, + ) + shard_meta = { + "fps": int(src_dataset.meta.fps), + "state_dim": int(first_state.size), + "action_shape": tuple(first_action.shape), + } + else: + assert shard_meta is not None + if int(src_dataset.meta.fps) != shard_meta["fps"]: + raise ValueError( + f"FPS mismatch for {task_name}: {src_dataset.meta.fps} != {shard_meta['fps']}" + ) + if int(first_state.size) != shard_meta["state_dim"]: + raise ValueError( + f"State dim mismatch for {task_name}: {first_state.size} != {shard_meta['state_dim']}" + ) + if tuple(first_action.shape) != shard_meta["action_shape"]: + raise ValueError( + f"Action shape mismatch for {task_name}: {tuple(first_action.shape)} != " + f"{shard_meta['action_shape']}" + ) + + num_episodes = src_dataset.num_episodes + if self.max_episodes_per_task is not None: + num_episodes = min(num_episodes, self.max_episodes_per_task) + + logging.info(" Appending %s episodes into shard %s", num_episodes, shard_root) + for episode_idx in range(num_episodes): + start = int(src_dataset.meta.episodes["dataset_from_index"][episode_idx]) + end = int(src_dataset.meta.episodes["dataset_to_index"][episode_idx]) + + for frame_idx in range(start, end): + item = src_dataset[frame_idx] + frame = { + "task": task_name, + "observation.state": build_state(item, state_keys), + "action": as_float32_vector(item[action_key]), + } + for target_key, source_key in image_key_map.items(): + frame[target_key] = to_pil_image(item[source_key]) + shard_dataset.add_frame(frame) + + shard_dataset.save_episode() + + if self.cleanup_temp: + cleanup_task_root(dataset_root) + + if shard_dataset is None: + logging.warning("Rank %s produced no shard dataset", rank) + return + + shard_dataset.finalize() + logging.info("Rank %s finalized shard at %s", rank, shard_root) + + +class AggregateRoboCasaUnifiedShards(PipelineStep): + """Aggregate repaired shard datasets into one final RoboCasa dataset.""" + + def __init__( + self, + output_repo_id: str, + shard_roots: list[str], + output_root: str, + push: bool = True, + overwrite: bool = False, + ): + super().__init__() + self.output_repo_id = output_repo_id + self.shard_roots = [Path(root) for root in shard_roots] + self.output_root = Path(output_root) + self.push = push + self.overwrite = overwrite + + def run(self, data=None, rank: int = 0, world_size: int = 1): + import json + import logging + import shutil + + from lerobot.datasets.aggregate import aggregate_datasets + from lerobot.datasets.lerobot_dataset import LeRobotDataset + from lerobot.utils.utils import init_logging init_logging() if rank != 0: - logger.info("rank %d: aggregation runs on rank 0 only — skipping", rank) + logging.info("Rank %s: only rank 0 aggregates", rank) return - require_robocasa() - from lerobot.datasets import aggregate_datasets # noqa: PLC0415 - from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415 + def shard_is_complete(root: Path) -> bool: + info_path = root / "meta" / "info.json" + tasks_path = root / "meta" / "tasks.parquet" + stats_path = root / "meta" / "stats.json" + if not (info_path.exists() and tasks_path.exists() and stats_path.exists()): + return False - # Resolve each task's local extraction root. After phase 1 these - # all exist on disk under robocasa.macros.DATASET_BASE_DIR; if any - # are missing, fail loudly so the operator knows phase 1 didn't - # cleanly complete for that task. - roots: list[Path] = [] - missing: list[str] = [] - for task in self.task_names: - root = resolve_task_root(task) - if not root.exists(): - missing.append(f"{task} -> {root}") - else: - roots.append(root) + episodes_dir = root / "meta" / "episodes" + data_dir = root / "data" + videos_dir = root / "videos" + if not episodes_dir.exists() or not data_dir.exists() or not videos_dir.exists(): + return False + if not any(episodes_dir.rglob("*.parquet")): + return False + if not any(data_dir.rglob("*.parquet")): + return False + if not any(videos_dir.rglob("*.mp4")): + return False + + with open(info_path) as f: + info = json.load(f) + return info.get("total_episodes", 0) > 0 and info.get("total_frames", 0) > 0 + + missing = [root for root in self.shard_roots if not shard_is_complete(root)] if missing: - raise RuntimeError( - "Phase 1 did not produce extracted directories for: " - + ", ".join(missing) - + " — re-run the download phase before aggregating." - ) + raise FileNotFoundError(f"Missing shard datasets: {missing}") - # ``aggregate_datasets`` uses ``repo_ids`` purely for logging / - # the unified task table when ``roots=`` is supplied; the actual - # data is loaded from each root directly. - repo_ids = [f"robocasa/{task}_target_human" for task in self.task_names] + if self.output_root.exists() and self.overwrite: + logging.warning("Removing existing unified output (--overwrite): %s", self.output_root) + shutil.rmtree(self.output_root) - logger.info( - "Aggregating %d source datasets into %s at %s", - len(roots), - self.output_repo_id, - self.output_dir, - ) + shard_repo_ids = [f"{self.output_repo_id}_shard_{idx}" for idx in range(len(self.shard_roots))] + logging.info("Aggregating %s shards into %s", len(self.shard_roots), self.output_root) aggregate_datasets( - repo_ids=repo_ids, + repo_ids=shard_repo_ids, + roots=self.shard_roots, aggr_repo_id=self.output_repo_id, - roots=roots, - aggr_root=self.output_dir, + aggr_root=self.output_root, ) - logger.info("Aggregation complete.") - if self.push_to_hub: - merged = LeRobotDataset( - repo_id=self.output_repo_id, - root=self.output_dir, - ) - logger.info( - "Pushing %s to the Hub (private=%s, %d episodes, %d frames)", - self.output_repo_id, - self.private, - merged.num_episodes, - merged.num_frames, - ) - merged.push_to_hub( - private=self.private, - upload_large_folder=True, - tags=["lerobot", "robocasa", "composite_seen", "manipulation"], - ) - logger.info( - "Push complete: https://huggingface.co/datasets/%s", - self.output_repo_id, + if self.push: + dataset = LeRobotDataset(repo_id=self.output_repo_id, root=self.output_root) + dataset.push_to_hub( + tags=["lerobot", "robocasa", "composite_seen", "unified"], + private=False, ) + logging.info("Pushed to https://huggingface.co/datasets/%s", self.output_repo_id) -# --------------------------------------------------------------------------- -# Executors -# --------------------------------------------------------------------------- - - -def make_download_executor( +def make_prepare_executor( *, - task_names: list[str], + tasks: list[dict[str, str]], + output_repo_id: str, + work_dir: Path, + split: str, + robot_type: str, overwrite: bool, + cleanup_temp: bool, + max_episodes_per_task: int | None, + vcodec: str, job_name: str, logs_dir: Path, workers: int, - partition: str | None, + partition: str, cpus_per_task: int, - mem: str, - time: str, + mem_per_cpu: str, + time_limit: str, slurm: bool, ): - """Phase-1 executor: parallel downloads, one task per worker by default.""" - pipeline = [DownloadRoboCasaTask(task_names, overwrite=overwrite)] - logging_dir = str(logs_dir / job_name) + kwargs = { + "pipeline": [ + PrepareRoboCasaUnifiedShards( + tasks=tasks, + output_repo_id=output_repo_id, + work_dir=str(work_dir), + split=split, + robot_type=robot_type, + overwrite=overwrite, + cleanup_temp=cleanup_temp, + max_episodes_per_task=max_episodes_per_task, + vcodec=vcodec, + ) + ], + "logging_dir": str(logs_dir / job_name), + } if slurm: - return SlurmPipelineExecutor( - pipeline=pipeline, - logging_dir=logging_dir, - job_name=job_name, - tasks=len(task_names), # one shard per RoboCasa task - workers=workers, - time=time, - partition=partition, - cpus_per_task=cpus_per_task, - sbatch_args={"mem": mem}, + kwargs.update( + { + "job_name": job_name, + "tasks": workers, + "workers": workers, + "time": time_limit, + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } ) - return LocalPipelineExecutor( - pipeline=pipeline, - logging_dir=logging_dir, - tasks=len(task_names), - workers=min(workers, len(task_names)), - ) + return SlurmPipelineExecutor(**kwargs) + + kwargs.update({"tasks": workers, "workers": 1}) + return LocalPipelineExecutor(**kwargs) def make_aggregate_executor( *, - task_names: list[str], output_repo_id: str, - output_dir: Path, - push_to_hub: bool, - private: bool, + shard_roots: list[Path], + output_root: Path, + push: bool, + overwrite: bool, job_name: str, logs_dir: Path, - partition: str | None, + partition: str, cpus_per_task: int, mem_per_cpu: str, - time: str, + time_limit: str, slurm: bool, - depends: SlurmPipelineExecutor | None, + depends: SlurmPipelineExecutor | None = None, ): - """Phase-2 executor: single worker, aggregates the extracted shards.""" - pipeline = [ - AggregateRoboCasaShards( - task_names, - output_repo_id=output_repo_id, - output_dir=output_dir, - push_to_hub=push_to_hub, - private=private, - ) - ] - logging_dir = str(logs_dir / job_name) + kwargs = { + "pipeline": [ + AggregateRoboCasaUnifiedShards( + output_repo_id=output_repo_id, + shard_roots=[str(root) for root in shard_roots], + output_root=str(output_root), + push=push, + overwrite=overwrite, + ) + ], + "logging_dir": str(logs_dir / job_name), + } if slurm: - return SlurmPipelineExecutor( - pipeline=pipeline, - logging_dir=logging_dir, - job_name=job_name, - tasks=1, - workers=1, - time=time, - partition=partition, - cpus_per_task=cpus_per_task, - sbatch_args={"mem-per-cpu": mem_per_cpu}, - depends=depends, # SLURM job dependency on phase 1 + kwargs.update( + { + "job_name": job_name, + "tasks": 1, + "workers": 1, + "time": time_limit, + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + "depends": depends, + } ) - return LocalPipelineExecutor( - pipeline=pipeline, - logging_dir=logging_dir, - tasks=1, - workers=1, - ) + return SlurmPipelineExecutor(**kwargs) + + kwargs.update({"tasks": 1, "workers": 1}) + return LocalPipelineExecutor(**kwargs) -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- +def resolve_repo_id(args: argparse.Namespace) -> str: + if args.repo_id: + return args.repo_id + if args.hf_user: + return f"{args.hf_user}/robocasa_composite_seen_{args.split}_{args.source}_unified_v3" + raise ValueError("Pass either --repo-id or --hf-user.") -def parse_args() -> argparse.Namespace: +def main(): parser = argparse.ArgumentParser( - description=( - "Build the merged RoboCasa composite_seen LeRobotDataset on SLURM " - "via datatrove (download in parallel, aggregate sequentially)." - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, + description="Rebuild the 16 RoboCasa composite_seen tarballs into one unified LeRobot v3 dataset." ) - - # I/O. + parser.add_argument("--repo-id", type=str, default=None, help="Final unified dataset repo id.") parser.add_argument( - "--output-dir", - type=Path, - required=True, - help="Local directory for the merged dataset (will be created).", - ) - parser.add_argument( - "--hub-repo-id", + "--hf-user", type=str, default=None, - help=( - "Hub repo_id for the merged dataset (e.g. " - "``yourname/robocasa_composite_seen``). Required for " - "``--push-to-hub``; also becomes the merged dataset's " - "canonical ``repo_id``." - ), + help="Optional shorthand. If set and --repo-id is omitted, derive " + "/robocasa_composite_seen___unified_v3.", + ) + parser.add_argument("--work-dir", type=Path, required=True) + parser.add_argument("--split", type=str, default=DEFAULT_SPLIT, choices=["target", "pretrain"]) + parser.add_argument("--source", type=str, default=DEFAULT_SOURCE) + parser.add_argument( + "--mode", + type=str, + default="all", + choices=["all", "prepare", "aggregate"], + help="prepare = build shards, aggregate = merge existing shards, all = do both.", ) parser.add_argument( - "--push-to-hub", - action="store_true", - help="Push the merged dataset to the Hub after aggregation.", + "--task-set", + type=str, + default="composite_seen", + choices=sorted(TASK_SETS.keys()), + help="Predefined task set to restrict discovery to. Default " + "``composite_seen`` (the 16 multi-step composite_seen tasks). Use " + "``all`` to keep every discovered task in the split/source slice. " + "``--tasks`` overrides this when provided.", ) + parser.add_argument("--robocasa-root", type=Path, default=None) + parser.add_argument("--box-links-json", type=Path, default=None) + parser.add_argument("--robot-type", type=str, default=DEFAULT_ROBOT_TYPE) + parser.add_argument("--vcodec", type=str, default="libsvtav1") + parser.add_argument("--max-episodes-per-task", type=int, default=None) + parser.add_argument("--cleanup-temp", action="store_true") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument("--no-push", action="store_true") + parser.add_argument("--logs-dir", type=Path, default=Path("logs")) + parser.add_argument("--job-name", type=str, default="port_robocasa_composite_seen") + parser.add_argument("--slurm", type=int, default=1, help="1 = Slurm executor, 0 = local debug.") parser.add_argument( - "--private", - action="store_true", - help="When pushing, create the Hub repo as private.", + "--workers", + type=int, + default=16, + help="Number of SLURM workers. Default 16 = one per composite_seen task.", ) + parser.add_argument("--partition", type=str, default="hopper-cpu") parser.add_argument( - "--logs-dir", - type=Path, - default=Path("./logs/robocasa"), - help="Path to datatrove logs directory (used for stdout/stderr and " - "phase coordination).", + "--cpus-per-task", + type=int, + default=8, + help="CPUs per worker. 16 workers × 8 cpus = 128 cpus total on hopper-cpu.", ) + parser.add_argument("--mem-per-cpu", type=str, default="4G") + parser.add_argument("--time", type=str, default="24:00:00") parser.add_argument( "--tasks", type=str, + nargs="*", default=None, - help="Comma-separated task names overriding the default 16 " - "composite_seen list (debug / smoke-test).", - ) - parser.add_argument( - "--overwrite-download", - action="store_true", - help="Force re-download even if the local extraction looks complete.", + help="Explicit task names. Overrides --task-set when provided.", ) + parser.add_argument("--dryrun", action="store_true") + args = parser.parse_args() - # SLURM controls. - parser.add_argument( - "--slurm", - type=int, - default=1, - help="Launch over SLURM (``1``) or locally / sequentially (``0``).", - ) - parser.add_argument( - "--partition", - type=str, - default=None, - help="SLURM partition. A CPU partition is sufficient — no GPU needed.", - ) + box_links_json = _resolve_box_links_json(args.box_links_json, args.robocasa_root) + all_tasks = _discover_tasks(box_links_json, split=args.split, source=args.source) - # Phase-1 (download) sizing. - parser.add_argument( - "--download-workers", - type=int, - default=16, - help="Number of parallel SLURM workers for the download phase. " - "Default matches the number of composite_seen tasks (16).", - ) - parser.add_argument( - "--download-cpus", - type=int, - default=4, - help="CPUs per download worker (the work is network- and " - "tar-extraction-bound).", - ) - parser.add_argument( - "--download-mem", - type=str, - default="8G", - help="Total memory per download worker.", - ) - parser.add_argument( - "--download-time", - type=str, - default="06:00:00", - help="SLURM wall-clock limit per download worker (HH:MM:SS).", - ) + # Filter: explicit --tasks wins; otherwise apply --task-set. + if args.tasks: + selected = {task.lower() for task in args.tasks} + all_tasks = [task for task in all_tasks if task["task_name"].lower() in selected] + elif args.task_set != "all": + wanted = {t.lower() for t in TASK_SETS[args.task_set]} + all_tasks = [task for task in all_tasks if task["task_name"].lower() in wanted] - # Phase-2 (aggregate) sizing. - parser.add_argument( - "--aggregate-cpus", - type=int, - default=16, - help="CPUs for the aggregation worker (ffmpeg + parquet I/O parallelize).", - ) - parser.add_argument( - "--aggregate-mem-per-cpu", - type=str, - default="2G", - help="SLURM mem-per-cpu for the aggregation worker.", - ) - parser.add_argument( - "--aggregate-time", - type=str, - default="12:00:00", - help="SLURM wall-clock limit for aggregation (HH:MM:SS). Tens-of-GB " - "merges can take several hours.", - ) + if not all_tasks: + raise ValueError( + f"No RoboCasa tasks selected for split={args.split!r}, source={args.source!r}, " + f"task_set={args.task_set!r}, tasks={args.tasks!r}" + ) - # Job naming. - parser.add_argument( - "--download-job-name", - type=str, - default="robocasa_dl", - help="SLURM job name for phase 1.", - ) - parser.add_argument( - "--aggregate-job-name", - type=str, - default="robocasa_agg", - help="SLURM job name for phase 2.", - ) + print(f"Tasks to rebuild ({len(all_tasks)}):") + for task in all_tasks: + print(f" {task['task_name']} ({task['tar_key']})") + if args.dryrun: + return - parser.add_argument( - "--log-level", - type=str, - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR"], - ) - return parser.parse_args() + output_repo_id = resolve_repo_id(args) + output_root = args.work_dir / "unified" / output_repo_id + active_ranks = [rank for rank in range(args.workers) if all_tasks[rank::args.workers]] + shard_roots = [ + args.work_dir + / "shards" + / output_repo_id.replace("/", "__") + / f"world_{args.workers}" + / f"rank_{rank}" + for rank in active_ranks + ] + prepare_executor = None + if args.mode in {"all", "prepare"}: + prepare_executor = make_prepare_executor( + tasks=all_tasks, + output_repo_id=output_repo_id, + work_dir=args.work_dir, + split=args.split, + robot_type=args.robot_type, + overwrite=args.overwrite, + cleanup_temp=args.cleanup_temp, + max_episodes_per_task=args.max_episodes_per_task, + vcodec=args.vcodec, + job_name=args.job_name, + logs_dir=args.logs_dir, + workers=args.workers, + partition=args.partition, + cpus_per_task=args.cpus_per_task, + mem_per_cpu=args.mem_per_cpu, + time_limit=args.time, + slurm=args.slurm == 1, + ) + if args.mode == "prepare": + prepare_executor.run() -def main() -> int: - args = parse_args() - logging.basicConfig( - level=getattr(logging, args.log_level), - format="[%(levelname)s] %(name)s: %(message)s", - ) - - tasks = ( - [t.strip() for t in args.tasks.split(",") if t.strip()] - if args.tasks - else list(COMPOSITE_SEEN_TASKS) - ) - if not tasks: - raise SystemExit("No tasks selected.") - - if args.push_to_hub and not args.hub_repo_id: - raise SystemExit("--push-to-hub requires --hub-repo-id.") - - output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen" - slurm = args.slurm == 1 - - logger.info( - "Phase 1 (download) — %d tasks across %d workers (slurm=%s)", - len(tasks), - min(args.download_workers, len(tasks)), - slurm, - ) - download_executor = make_download_executor( - task_names=tasks, - overwrite=args.overwrite_download, - job_name=args.download_job_name, - logs_dir=args.logs_dir, - workers=args.download_workers, - partition=args.partition, - cpus_per_task=args.download_cpus, - mem=args.download_mem, - time=args.download_time, - slurm=slurm, - ) - - logger.info( - "Phase 2 (aggregate) — single worker, output: %s (push_to_hub=%s)", - output_repo_id, - args.push_to_hub, - ) - aggregate_executor = make_aggregate_executor( - task_names=tasks, - output_repo_id=output_repo_id, - output_dir=args.output_dir, - push_to_hub=args.push_to_hub, - private=args.private, - job_name=args.aggregate_job_name, - logs_dir=args.logs_dir, - partition=args.partition, - cpus_per_task=args.aggregate_cpus, - mem_per_cpu=args.aggregate_mem_per_cpu, - time=args.aggregate_time, - slurm=slurm, - depends=download_executor if slurm else None, - ) - - if slurm: - # Submitting the aggregate executor with ``depends=download_executor`` - # also submits the download executor — SlurmPipelineExecutor walks - # the dependency chain and submits each job once with the right - # ``--dependency=afterok:`` arg. - aggregate_executor.run() - else: - # Local sequential: run download to completion, then aggregate. - download_executor.run() - aggregate_executor.run() - - logger.info("Done. Merged dataset at %s.", args.output_dir) - return 0 + if args.mode in {"all", "aggregate"}: + aggregate_executor = make_aggregate_executor( + output_repo_id=output_repo_id, + shard_roots=shard_roots, + output_root=output_root, + push=not args.no_push, + overwrite=args.overwrite, + job_name=f"{args.job_name}_aggregate", + logs_dir=args.logs_dir, + partition=args.partition, + cpus_per_task=args.cpus_per_task, + mem_per_cpu=args.mem_per_cpu, + time_limit=args.time, + slurm=args.slurm == 1, + depends=prepare_executor if args.mode == "all" and args.slurm == 1 else None, + ) + if args.mode == "all" and args.slurm == 1: + # SLURM: submitting the aggregate executor with depends=prepare_executor + # transitively submits prepare too, with the right --dependency=afterok. + aggregate_executor.run() + elif args.mode == "all": + # Local: run sequentially. + assert prepare_executor is not None + prepare_executor.run() + aggregate_executor.run() + else: + aggregate_executor.run() if __name__ == "__main__": - raise SystemExit(main()) + main()