mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Add chained SLURM mirror-and-double dataset script.
Provide a standalone DataTrove workflow that mirrors bimanual shards, aggregates mirrored output, builds a doubled dataset, and optionally pushes the final dataset to the Hub. Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,726 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual dataset in parallel with DataTrove + SLURM, then double it.
|
||||
|
||||
Workflow:
|
||||
1) Split source episodes across `num_shards` ranks and mirror each shard in parallel.
|
||||
2) Aggregate mirrored shards into one mirrored dataset.
|
||||
3) Aggregate [original, mirrored] into a final doubled dataset.
|
||||
|
||||
Example:
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id=pepijn/openarm_bimanual \
|
||||
--output-repo-id=pepijn/openarm_bimanual_doubled \
|
||||
--partition=hopper-cpu \
|
||||
--num-shards=256 \
|
||||
--workers=64 \
|
||||
--cpus-per-task=8 \
|
||||
--mem-per-cpu=4G
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OPENARM_MIRRORING_MASK = {
|
||||
"joint_1": -1,
|
||||
"joint_2": -1,
|
||||
"joint_3": -1,
|
||||
"joint_4": 1,
|
||||
"joint_5": -1,
|
||||
"joint_6": -1,
|
||||
"joint_7": -1,
|
||||
"gripper": 1,
|
||||
}
|
||||
|
||||
|
||||
def get_mirroring_mask(robot_type: str | None) -> dict[str, int]:
|
||||
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
|
||||
return OPENARM_MIRRORING_MASK
|
||||
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
value = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
value = value.replace("right_", "left_")
|
||||
value = value.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return value
|
||||
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
|
||||
mirrored_names = [swap_left_right_name(n) for n in names]
|
||||
old_to_new_idx = {}
|
||||
for old_idx, old_name in enumerate(names):
|
||||
new_name = swap_left_right_name(old_name)
|
||||
new_idx = mirrored_names.index(new_name)
|
||||
old_to_new_idx[old_idx] = new_idx
|
||||
return mirrored_names, old_to_new_idx
|
||||
|
||||
|
||||
def _get_axis_names(feature: dict[str, Any]) -> list[str] | None:
|
||||
names = feature.get("names")
|
||||
if isinstance(names, list):
|
||||
return names
|
||||
if isinstance(names, dict):
|
||||
axes = names.get("axes")
|
||||
if isinstance(axes, list):
|
||||
return axes
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy(value: Any) -> Any:
|
||||
if isinstance(value, np.ndarray):
|
||||
return value
|
||||
if hasattr(value, "detach"):
|
||||
return value.detach().cpu().numpy()
|
||||
if hasattr(value, "cpu") and hasattr(value, "numpy"):
|
||||
return value.cpu().numpy()
|
||||
if hasattr(value, "numpy"):
|
||||
return value.numpy()
|
||||
return value
|
||||
|
||||
|
||||
def apply_mirroring_mask(value: float, axis_name: str, mirroring_mask: dict[str, int]) -> float:
|
||||
if axis_name.startswith("left_") or axis_name.startswith("right_"):
|
||||
axis_name = axis_name.split("_", 1)[1]
|
||||
joint_name = axis_name.split(".")[0]
|
||||
return value * mirroring_mask.get(joint_name, 1)
|
||||
|
||||
|
||||
def mirror_vector_feature(
|
||||
value: Any,
|
||||
feature: dict[str, Any],
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> Any:
|
||||
array = _to_numpy(value)
|
||||
if not isinstance(array, np.ndarray) or array.ndim != 1:
|
||||
return array
|
||||
|
||||
names = _get_axis_names(feature)
|
||||
if names is None or len(names) != len(array):
|
||||
return array
|
||||
|
||||
mirrored_names, index_mapping = mirror_feature_names(names)
|
||||
mirrored = np.zeros_like(array)
|
||||
for old_idx, new_idx in index_mapping.items():
|
||||
mirrored[new_idx] = apply_mirroring_mask(array[old_idx], mirrored_names[new_idx], mirroring_mask)
|
||||
return mirrored
|
||||
|
||||
|
||||
def flip_horizontal(value: Any, expected_shape: list[int] | tuple[int, ...]) -> Any:
|
||||
array = _to_numpy(value)
|
||||
if not isinstance(array, np.ndarray) or array.ndim != 3:
|
||||
return array
|
||||
|
||||
expected_shape = tuple(expected_shape)
|
||||
if array.shape == expected_shape:
|
||||
return np.flip(array, axis=1).copy() # HWC
|
||||
|
||||
if len(expected_shape) == 3:
|
||||
c, h, w = expected_shape
|
||||
if array.shape == (c, h, w):
|
||||
return np.flip(array, axis=2).copy() # CHW
|
||||
|
||||
# Conservative fallback for unexpected layouts.
|
||||
return np.flip(array, axis=-1).copy()
|
||||
|
||||
|
||||
def build_mirrored_features(features: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
mirrored = {}
|
||||
for key, feature in features.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
new_feature = copy.deepcopy(feature)
|
||||
names = new_feature.get("names")
|
||||
if isinstance(names, list):
|
||||
new_feature["names"] = [swap_left_right_name(name) for name in names]
|
||||
elif isinstance(names, dict) and isinstance(names.get("axes"), list):
|
||||
new_feature["names"]["axes"] = [swap_left_right_name(name) for name in names["axes"]]
|
||||
mirrored[new_key] = new_feature
|
||||
return mirrored
|
||||
|
||||
|
||||
def build_mirrored_frame(
|
||||
item: dict[str, Any],
|
||||
source_features: dict[str, dict[str, Any]],
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> dict[str, Any]:
|
||||
frame = {}
|
||||
for key, feature in source_features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
|
||||
value = item[key]
|
||||
if key in {"action", "observation.state"}:
|
||||
value = mirror_vector_feature(value, feature, mirroring_mask)
|
||||
elif feature["dtype"] in {"video", "image"}:
|
||||
value = flip_horizontal(value, feature["shape"])
|
||||
else:
|
||||
value = _to_numpy(value)
|
||||
|
||||
frame[swap_left_right_name(key)] = value
|
||||
|
||||
frame["task"] = item["task"]
|
||||
if "timestamp" in item:
|
||||
ts = _to_numpy(item["timestamp"])
|
||||
frame["timestamp"] = float(ts.item() if hasattr(ts, "item") else ts)
|
||||
return frame
|
||||
|
||||
|
||||
def _resolve_source_root(repo_id: str, root: Path | None) -> Path:
|
||||
source_meta = LeRobotDatasetMetadata(repo_id=repo_id, root=root)
|
||||
return source_meta.root
|
||||
|
||||
|
||||
def _get_work_dir(output_repo_id: str, work_dir: Path | None) -> Path:
|
||||
if work_dir is not None:
|
||||
return work_dir
|
||||
safe_name = output_repo_id.replace("/", "__")
|
||||
return HF_LEROBOT_HOME / "_mirror_work" / safe_name
|
||||
|
||||
|
||||
def _get_shard_root(work_dir: Path, world_size: int, rank: int) -> Path:
|
||||
return work_dir / "mirrored_shards" / f"world_{world_size}_rank_{rank}"
|
||||
|
||||
|
||||
def _is_valid_dataset_root(root: Path) -> bool:
|
||||
return (root / "meta" / "info.json").exists()
|
||||
|
||||
|
||||
def mirror_shard(
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
shard_root: Path,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
) -> None:
|
||||
source_dataset = LeRobotDataset(repo_id=repo_id, root=source_root)
|
||||
selected_episodes = list(range(rank, source_dataset.meta.total_episodes, world_size))
|
||||
|
||||
if len(selected_episodes) == 0:
|
||||
logger.info("Rank %s has no episodes assigned. Skipping.", rank)
|
||||
return
|
||||
|
||||
if shard_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(shard_root)
|
||||
elif _is_valid_dataset_root(shard_root):
|
||||
logger.info("Rank %s shard already exists at %s. Skipping.", rank, shard_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Shard root {shard_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
mirroring_mask = get_mirroring_mask(source_dataset.meta.robot_type)
|
||||
mirrored_features = build_mirrored_features(source_dataset.meta.features)
|
||||
|
||||
shard_repo_name = f"{mirrored_repo_id}_world_{world_size}_rank_{rank}"
|
||||
mirrored_dataset = LeRobotDataset.create(
|
||||
repo_id=shard_repo_name,
|
||||
root=shard_root,
|
||||
fps=source_dataset.meta.fps,
|
||||
features=mirrored_features,
|
||||
robot_type=source_dataset.meta.robot_type,
|
||||
use_videos=len(source_dataset.meta.video_keys) > 0,
|
||||
vcodec=vcodec,
|
||||
)
|
||||
mirrored_dataset.meta.update_chunk_settings(
|
||||
chunks_size=source_dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=source_dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=source_dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Rank %s processing %s episodes into shard %s",
|
||||
rank,
|
||||
len(selected_episodes),
|
||||
shard_root,
|
||||
)
|
||||
for source_ep_idx in selected_episodes:
|
||||
episode = source_dataset.meta.episodes[source_ep_idx]
|
||||
start_idx = int(episode["dataset_from_index"])
|
||||
end_idx = int(episode["dataset_to_index"])
|
||||
|
||||
for frame_idx in range(start_idx, end_idx):
|
||||
item = source_dataset[frame_idx]
|
||||
mirrored_frame = build_mirrored_frame(
|
||||
item=item,
|
||||
source_features=source_dataset.meta.features,
|
||||
mirroring_mask=mirroring_mask,
|
||||
)
|
||||
mirrored_dataset.add_frame(mirrored_frame)
|
||||
|
||||
mirrored_dataset.save_episode()
|
||||
|
||||
mirrored_dataset.finalize()
|
||||
|
||||
|
||||
class MirrorDatasetShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
work_dir: Path,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.source_root = source_root
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.work_dir = work_dir
|
||||
self.vcodec = vcodec
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
shard_root = _get_shard_root(self.work_dir, world_size, rank)
|
||||
mirror_shard(
|
||||
repo_id=self.repo_id,
|
||||
source_root=self.source_root,
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
shard_root=shard_root,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
vcodec=self.vcodec,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
def make_mirror_executor(
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
work_dir: Path,
|
||||
logs_dir: Path,
|
||||
job_name: str,
|
||||
num_shards: int,
|
||||
workers: int,
|
||||
partition: str,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time_limit: str,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
slurm: bool,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
MirrorDatasetShards(
|
||||
repo_id=repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
work_dir=work_dir,
|
||||
vcodec=vcodec,
|
||||
overwrite=overwrite,
|
||||
),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
if partition is None:
|
||||
raise ValueError("`--partition` is required when `--slurm 1`.")
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": num_shards,
|
||||
"workers": workers,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": num_shards, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
class AggregateMirroredShardsStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
work_dir: Path,
|
||||
num_shards: int,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.mirrored_root = mirrored_root
|
||||
self.work_dir = work_dir
|
||||
self.num_shards = num_shards
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for aggregate mirrored step", rank)
|
||||
return
|
||||
aggregate_mirrored_shards(
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
mirrored_root=self.mirrored_root,
|
||||
work_dir=self.work_dir,
|
||||
num_shards=self.num_shards,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
class BuildDoubledDatasetStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
source_repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.source_repo_id = source_repo_id
|
||||
self.source_root = source_root
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.mirrored_root = mirrored_root
|
||||
self.output_repo_id = output_repo_id
|
||||
self.output_root = output_root
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for build doubled step", rank)
|
||||
return
|
||||
build_doubled_dataset(
|
||||
source_repo_id=self.source_repo_id,
|
||||
source_root=self.source_root,
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
mirrored_root=self.mirrored_root,
|
||||
output_repo_id=self.output_repo_id,
|
||||
output_root=self.output_root,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
class PushDoubledDatasetStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_repo_id = output_repo_id
|
||||
self.output_root = output_root
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for push step", rank)
|
||||
return
|
||||
logger.info("Pushing doubled dataset to hub: %s", self.output_repo_id)
|
||||
LeRobotDataset(self.output_repo_id, root=self.output_root).push_to_hub()
|
||||
|
||||
|
||||
def make_single_task_executor(
|
||||
step: PipelineStep,
|
||||
logs_dir: Path,
|
||||
job_name: str,
|
||||
partition: str | None,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time_limit: str,
|
||||
slurm: bool,
|
||||
depends: SlurmPipelineExecutor | None = None,
|
||||
):
|
||||
kwargs = {"pipeline": [step], "logging_dir": str(logs_dir / job_name)}
|
||||
if slurm:
|
||||
if partition is None:
|
||||
raise ValueError("`--partition` is required when `--slurm 1`.")
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
"depends": depends,
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def aggregate_mirrored_shards(
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
work_dir: Path,
|
||||
num_shards: int,
|
||||
overwrite: bool,
|
||||
):
|
||||
if mirrored_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(mirrored_root)
|
||||
elif _is_valid_dataset_root(mirrored_root):
|
||||
logger.info("Mirrored dataset already exists at %s. Skipping aggregation.", mirrored_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Mirrored root {mirrored_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
shard_repo_ids = []
|
||||
shard_roots = []
|
||||
for rank in range(num_shards):
|
||||
shard_root = _get_shard_root(work_dir, num_shards, rank)
|
||||
if _is_valid_dataset_root(shard_root):
|
||||
shard_repo_ids.append(f"{mirrored_repo_id}_world_{num_shards}_rank_{rank}")
|
||||
shard_roots.append(shard_root)
|
||||
|
||||
if len(shard_repo_ids) == 0:
|
||||
raise RuntimeError("No mirrored shards were produced. Nothing to aggregate.")
|
||||
|
||||
logger.info("Aggregating %s mirrored shards into %s", len(shard_repo_ids), mirrored_root)
|
||||
aggregate_datasets(
|
||||
repo_ids=shard_repo_ids,
|
||||
roots=shard_roots,
|
||||
aggr_repo_id=mirrored_repo_id,
|
||||
aggr_root=mirrored_root,
|
||||
)
|
||||
|
||||
|
||||
def build_doubled_dataset(
|
||||
source_repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
overwrite: bool,
|
||||
):
|
||||
if output_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(output_root)
|
||||
elif _is_valid_dataset_root(output_root):
|
||||
logger.info("Doubled dataset already exists at %s. Skipping final aggregation.", output_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Output root {output_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
logger.info("Aggregating source + mirrored into doubled dataset at %s", output_root)
|
||||
aggregate_datasets(
|
||||
repo_ids=[source_repo_id, mirrored_repo_id],
|
||||
roots=[source_root, mirrored_root],
|
||||
aggr_repo_id=output_repo_id,
|
||||
aggr_root=output_root,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo id.")
|
||||
parser.add_argument("--output-repo-id", type=str, required=True, help="Final doubled dataset repo id.")
|
||||
parser.add_argument("--root", type=Path, default=None, help="Root path of source dataset.")
|
||||
parser.add_argument(
|
||||
"--output-root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root path where final doubled dataset is written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--work-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Intermediate directory for mirrored shards and mirrored aggregate dataset.",
|
||||
)
|
||||
parser.add_argument("--logs-dir", type=Path, required=True, help="DataTrove logs path.")
|
||||
parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name.")
|
||||
parser.add_argument("--num-shards", type=int, default=256, help="Number of DataTrove tasks/ranks.")
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Max concurrent DataTrove workers on SLURM.",
|
||||
)
|
||||
parser.add_argument("--partition", type=str, default=None, help="SLURM partition (e.g. hopper-cpu).")
|
||||
parser.add_argument("--cpus-per-task", type=int, default=8, help="CPU count per SLURM task.")
|
||||
parser.add_argument("--mem-per-cpu", type=str, default="4G", help="Memory per CPU for SLURM task.")
|
||||
parser.add_argument("--time", type=str, default="24:00:00", help="SLURM time limit.")
|
||||
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec for output videos.")
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Use SLURM executor. Set 0 for local sequential debugging.",
|
||||
)
|
||||
parser.add_argument("--overwrite", action="store_true", help="Delete existing intermediate/final outputs.")
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push final doubled dataset to Hugging Face Hub after completion.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
slurm = args.slurm == 1
|
||||
|
||||
source_root = _resolve_source_root(args.repo_id, args.root)
|
||||
output_root = args.output_root if args.output_root is not None else HF_LEROBOT_HOME / args.output_repo_id
|
||||
work_dir = _get_work_dir(args.output_repo_id, args.work_dir)
|
||||
mirrored_repo_id = f"{args.output_repo_id}_mirrored"
|
||||
mirrored_root = work_dir / "mirrored_aggregate"
|
||||
|
||||
work_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mirror_executor = make_mirror_executor(
|
||||
repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
work_dir=work_dir,
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=args.job_name,
|
||||
num_shards=args.num_shards,
|
||||
workers=args.workers,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
vcodec=args.vcodec,
|
||||
overwrite=args.overwrite,
|
||||
slurm=slurm,
|
||||
)
|
||||
if slurm:
|
||||
aggregate_executor = make_single_task_executor(
|
||||
step=AggregateMirroredShardsStep(
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
work_dir=work_dir,
|
||||
num_shards=args.num_shards,
|
||||
overwrite=args.overwrite,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_aggregate_mirrored",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=mirror_executor,
|
||||
)
|
||||
build_executor = make_single_task_executor(
|
||||
step=BuildDoubledDatasetStep(
|
||||
source_repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
overwrite=args.overwrite,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_build_doubled",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=aggregate_executor,
|
||||
)
|
||||
|
||||
final_executor: SlurmPipelineExecutor | LocalPipelineExecutor = build_executor
|
||||
push_executor = None
|
||||
if args.push_to_hub:
|
||||
push_executor = make_single_task_executor(
|
||||
step=PushDoubledDatasetStep(
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_push",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=build_executor,
|
||||
)
|
||||
final_executor = push_executor
|
||||
|
||||
final_executor.run()
|
||||
logger.info(
|
||||
"Submitted SLURM chain. job_ids: mirror=%s aggregate=%s doubled=%s push=%s",
|
||||
mirror_executor.job_id,
|
||||
aggregate_executor.job_id,
|
||||
build_executor.job_id,
|
||||
push_executor.job_id if push_executor is not None else None,
|
||||
)
|
||||
return
|
||||
|
||||
mirror_executor.run()
|
||||
aggregate_mirrored_shards(
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
work_dir=work_dir,
|
||||
num_shards=args.num_shards,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
build_doubled_dataset(
|
||||
source_repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
if args.push_to_hub:
|
||||
logger.info("Pushing doubled dataset to hub: %s", args.output_repo_id)
|
||||
LeRobotDataset(args.output_repo_id, root=output_root).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user