Add chained SLURM mirror-and-double dataset script.

Provide a standalone DataTrove workflow that mirrors bimanual shards, aggregates mirrored output, builds a doubled dataset, and optionally pushes the final dataset to the Hub.

Made-with: Cursor
This commit is contained in:
pepijn
2026-02-27 10:59:33 +00:00
parent 5865170d36
commit b2d3186011
@@ -0,0 +1,726 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mirror a bimanual dataset in parallel with DataTrove + SLURM, then double it.
Workflow:
1) Split source episodes across `num_shards` ranks and mirror each shard in parallel.
2) Aggregate mirrored shards into one mirrored dataset.
3) Aggregate [original, mirrored] into a final doubled dataset.
Example:
python examples/port_datasets/slurm_mirror_dataset.py \
--repo-id=pepijn/openarm_bimanual \
--output-repo-id=pepijn/openarm_bimanual_doubled \
--partition=hopper-cpu \
--num-shards=256 \
--workers=64 \
--cpus-per-task=8 \
--mem-per-cpu=4G
"""
import argparse
import copy
import logging
import shutil
from pathlib import Path
from typing import Any
import numpy as np
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Per-joint multipliers used when mirroring a joint-space vector: a value whose
# joint maps to -1 is negated, a value whose joint maps to 1 is kept unchanged.
# NOTE(review): presumably -1 marks joints whose direction reverses under a
# left/right reflection of the arm — confirm against the OpenArm kinematics.
OPENARM_MIRRORING_MASK = {
    "joint_1": -1,
    "joint_2": -1,
    "joint_3": -1,
    "joint_4": 1,
    "joint_5": -1,
    "joint_6": -1,
    "joint_7": -1,
    "gripper": 1,
}
def get_mirroring_mask(robot_type: str | None) -> dict[str, int]:
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
return OPENARM_MIRRORING_MASK
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
def swap_left_right_name(name: str) -> str:
    """Exchange every ``left_`` and ``right_`` substring in ``name``.

    The swap goes through a temporary marker so that text already converted
    from ``right_`` to ``left_`` is not converted back in the same pass.
    """
    marker = "LEFT_PLACEHOLDER_"
    return name.replace("left_", marker).replace("right_", "left_").replace(marker, "right_")
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
    """Swap left/right in every axis name and map old positions to new ones.

    Args:
        names: Axis names of a vector feature, in storage order.

    Returns:
        A tuple ``(mirrored_names, old_to_new_idx)``. ``mirrored_names[i]`` is
        ``names[i]`` with ``left_``/``right_`` exchanged. ``old_to_new_idx``
        maps each position in ``names`` to the first position in
        ``mirrored_names`` holding that position's mirrored name.

    NOTE(review): because ``mirrored_names`` is built element-wise from
    ``names``, the mapping is the identity whenever the names are unique, so
    values are relabelled rather than moved. Confirm that this is the intended
    mirroring semantics.
    """
    mirrored_names = [swap_left_right_name(name) for name in names]
    # First position of each mirrored name; gives the same answer as
    # ``list.index`` without rescanning the list for every element.
    first_position: dict[str, int] = {}
    for idx, mirrored_name in enumerate(mirrored_names):
        first_position.setdefault(mirrored_name, idx)
    old_to_new_idx = {
        old_idx: first_position[mirrored_name] for old_idx, mirrored_name in enumerate(mirrored_names)
    }
    return mirrored_names, old_to_new_idx
def _get_axis_names(feature: dict[str, Any]) -> list[str] | None:
names = feature.get("names")
if isinstance(names, list):
return names
if isinstance(names, dict):
axes = names.get("axes")
if isinstance(axes, list):
return axes
return None
def _to_numpy(value: Any) -> Any:
    """Convert tensor-like ``value`` to a NumPy array when it knows how.

    NumPy arrays are returned as-is. Objects exposing ``detach`` are converted
    through ``detach().cpu().numpy()``, objects exposing ``cpu`` and ``numpy``
    through ``cpu().numpy()``, and objects exposing only ``numpy`` through
    ``numpy()``. Anything else is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value
    if hasattr(value, "detach"):
        return value.detach().cpu().numpy()
    if not hasattr(value, "numpy"):
        return value
    converter = value.cpu() if hasattr(value, "cpu") else value
    return converter.numpy()
def apply_mirroring_mask(value: float, axis_name: str, mirroring_mask: dict[str, int]) -> float:
    """Multiply ``value`` by the mask entry of the joint named in ``axis_name``.

    A leading ``left_``/``right_`` prefix and anything from the first ``.``
    onwards are stripped to obtain the joint name. Joints absent from
    ``mirroring_mask`` use a multiplier of 1.
    """
    if axis_name.startswith(("left_", "right_")):
        _, _, axis_name = axis_name.partition("_")
    joint_name, _, _ = axis_name.partition(".")
    sign = mirroring_mask.get(joint_name, 1)
    return value * sign
def mirror_vector_feature(
    value: Any,
    feature: dict[str, Any],
    mirroring_mask: dict[str, int],
) -> Any:
    """Mirror a 1-D named vector such as an action or a robot state.

    Args:
        value: Vector to mirror; converted with ``_to_numpy`` first.
        feature: Feature description supplying the axis names.
        mirroring_mask: Per-joint multipliers applied to each element.

    Returns:
        A new array of the same shape and dtype with every element written at
        its mapped position and multiplied by its joint's mask entry. The
        converted input is returned untouched when it is not a 1-D NumPy
        array, or when the feature has no axis names matching its length.
    """
    vector = _to_numpy(value)
    is_flat_array = isinstance(vector, np.ndarray) and vector.ndim == 1
    if not is_flat_array:
        return vector

    axis_names = _get_axis_names(feature)
    if axis_names is None or len(axis_names) != len(vector):
        return vector

    target_names, src_to_dst = mirror_feature_names(axis_names)
    result = np.zeros_like(vector)
    for src, dst in src_to_dst.items():
        result[dst] = apply_mirroring_mask(vector[src], target_names[dst], mirroring_mask)
    return result
def flip_horizontal(value: Any, expected_shape: list[int] | tuple[int, ...]) -> Any:
    """Flip a 3-D image left/right along its width axis.

    The width axis is located by comparing the array's shape with the declared
    feature shape, which may be channel-last (H, W, C) or channel-first
    (C, H, W) and may differ in layout from the array itself.

    The previous implementation unpacked the declared shape as ``(c, h, w)``
    and compared the array against that same tuple, so the branch could never
    run. Frames whose declared and actual layout were both channel-first were
    flipped vertically (axis 1).

    Args:
        value: Image as a NumPy array or tensor-like object.
        expected_shape: Shape declared for this feature in the dataset
            metadata.

    Returns:
        A flipped copy of the image as a NumPy array. The converted input is
        returned unchanged when it is not a 3-D NumPy array.
    """
    array = _to_numpy(value)
    if not isinstance(array, np.ndarray) or array.ndim != 3:
        return array

    declared = tuple(expected_shape)
    # Conservative fallback for unexpected layouts: flip the last axis.
    width_axis = -1
    if len(declared) == 3:
        d0, d1, d2 = declared
        channel_counts = (1, 3, 4)
        if array.shape == declared:
            # Same layout as declared. Treat it as channel-first only when the
            # first axis looks like a channel count and the last one does not;
            # otherwise keep the historical channel-last interpretation.
            channel_first = d0 in channel_counts and d2 not in channel_counts
            width_axis = 2 if channel_first else 1
        elif array.shape == (d2, d0, d1):
            # Declared (H, W, C), array delivered as (C, H, W).
            width_axis = 2
        elif array.shape == (d1, d2, d0):
            # Declared (C, H, W), array delivered as (H, W, C).
            width_axis = 1
    return np.flip(array, axis=width_axis).copy()
def build_mirrored_features(features: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Return a deep copy of ``features`` with left/right swapped in all names.

    Both the feature keys and the axis names are rewritten. Axis names may be
    stored as a list under ``"names"`` or as a list under ``names["axes"]``.
    The input mapping is not modified.
    """
    mirrored: dict[str, dict[str, Any]] = {}
    for key, feature in features.items():
        mirrored_key = swap_left_right_name(key)
        feature_copy = copy.deepcopy(feature)
        names = feature_copy.get("names")
        if isinstance(names, list):
            feature_copy["names"] = list(map(swap_left_right_name, names))
        elif isinstance(names, dict) and isinstance(names.get("axes"), list):
            # ``names`` is the copied dict, so this edits the copy in place.
            names["axes"] = list(map(swap_left_right_name, names["axes"]))
        mirrored[mirrored_key] = feature_copy
    return mirrored
def build_mirrored_frame(
    item: dict[str, Any],
    source_features: dict[str, dict[str, Any]],
    mirroring_mask: dict[str, int],
) -> dict[str, Any]:
    """Build the mirrored counterpart of one source frame.

    Args:
        item: Frame as returned by indexing the source dataset.
        source_features: Feature descriptions of the source dataset.
        mirroring_mask: Per-joint multipliers for the vector features.

    Returns:
        A frame dict keyed by the left/right-swapped feature names. The
        ``action`` and ``observation.state`` values are mirrored as vectors,
        image and video features are flipped horizontally, and everything else
        is only converted with ``_to_numpy``. Keys in ``DEFAULT_FEATURES`` are
        skipped; ``task`` is copied and ``timestamp`` is copied as a float
        when present.
    """
    vector_keys = {"action", "observation.state"}
    visual_dtypes = {"video", "image"}

    frame: dict[str, Any] = {}
    for key, feature in source_features.items():
        if key in DEFAULT_FEATURES:
            continue
        raw = item[key]
        if key in vector_keys:
            converted = mirror_vector_feature(raw, feature, mirroring_mask)
        elif feature["dtype"] in visual_dtypes:
            converted = flip_horizontal(raw, feature["shape"])
        else:
            converted = _to_numpy(raw)
        frame[swap_left_right_name(key)] = converted

    frame["task"] = item["task"]
    if "timestamp" in item:
        ts = _to_numpy(item["timestamp"])
        # Arrays and NumPy scalars expose ``item()``; plain numbers do not.
        scalar = ts.item() if hasattr(ts, "item") else ts
        frame["timestamp"] = float(scalar)
    return frame
def _resolve_source_root(repo_id: str, root: Path | None) -> Path:
    """Return the ``root`` reported by the source dataset's metadata object."""
    return LeRobotDatasetMetadata(repo_id=repo_id, root=root).root
def _get_work_dir(output_repo_id: str, work_dir: Path | None) -> Path:
if work_dir is not None:
return work_dir
safe_name = output_repo_id.replace("/", "__")
return HF_LEROBOT_HOME / "_mirror_work" / safe_name
def _get_shard_root(work_dir: Path, world_size: int, rank: int) -> Path:
    """Return the directory holding the mirrored shard of ``rank``."""
    shard_name = f"world_{world_size}_rank_{rank}"
    return work_dir / "mirrored_shards" / shard_name
def _is_valid_dataset_root(root: Path) -> bool:
    """Return whether ``root`` contains a ``meta/info.json`` file."""
    info_file = root.joinpath("meta", "info.json")
    return info_file.exists()
def mirror_shard(
    repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    shard_root: Path,
    rank: int,
    world_size: int,
    vcodec: str,
    overwrite: bool,
) -> None:
    """Mirror the episodes assigned to ``rank`` into a dataset at ``shard_root``.

    Episodes are assigned round-robin: ``rank`` handles episodes ``rank``,
    ``rank + world_size``, ``rank + 2 * world_size`` and so on.

    Args:
        repo_id: Source dataset repo id.
        source_root: Local root of the source dataset.
        mirrored_repo_id: Base repo id used to name the shard dataset.
        shard_root: Directory in which the shard dataset is created.
        rank: Index of this task.
        world_size: Total number of tasks.
        vcodec: Video codec passed to the created dataset.
        overwrite: Delete an existing ``shard_root`` instead of reusing it.

    Raises:
        RuntimeError: If ``shard_root`` exists without being a valid dataset
            and ``overwrite`` is false.
    """
    source_dataset = LeRobotDataset(repo_id=repo_id, root=source_root)
    source_meta = source_dataset.meta

    selected_episodes = list(range(rank, source_meta.total_episodes, world_size))
    if not selected_episodes:
        logger.info("Rank %s has no episodes assigned. Skipping.", rank)
        return

    if shard_root.exists():
        if not overwrite:
            if not _is_valid_dataset_root(shard_root):
                raise RuntimeError(
                    f"Shard root {shard_root} exists but is not a valid dataset. Use --overwrite to recreate."
                )
            logger.info("Rank %s shard already exists at %s. Skipping.", rank, shard_root)
            return
        shutil.rmtree(shard_root)

    mirroring_mask = get_mirroring_mask(source_meta.robot_type)
    mirrored_features = build_mirrored_features(source_meta.features)
    mirrored_dataset = LeRobotDataset.create(
        repo_id=f"{mirrored_repo_id}_world_{world_size}_rank_{rank}",
        root=shard_root,
        fps=source_meta.fps,
        features=mirrored_features,
        robot_type=source_meta.robot_type,
        use_videos=len(source_meta.video_keys) > 0,
        vcodec=vcodec,
    )
    # Reuse the source chunk and file-size settings for the shard.
    mirrored_dataset.meta.update_chunk_settings(
        chunks_size=source_meta.chunks_size,
        data_files_size_in_mb=source_meta.data_files_size_in_mb,
        video_files_size_in_mb=source_meta.video_files_size_in_mb,
    )

    logger.info(
        "Rank %s processing %s episodes into shard %s",
        rank,
        len(selected_episodes),
        shard_root,
    )
    for source_ep_idx in selected_episodes:
        episode = source_dataset.meta.episodes[source_ep_idx]
        first_frame = int(episode["dataset_from_index"])
        stop_frame = int(episode["dataset_to_index"])
        for frame_idx in range(first_frame, stop_frame):
            mirrored_dataset.add_frame(
                build_mirrored_frame(
                    item=source_dataset[frame_idx],
                    source_features=source_dataset.meta.features,
                    mirroring_mask=mirroring_mask,
                )
            )
        # One saved episode per source episode.
        mirrored_dataset.save_episode()

    mirrored_dataset.finalize()
class MirrorDatasetShards(PipelineStep):
    """DataTrove step that mirrors the episodes assigned to one rank."""

    def __init__(
        self,
        repo_id: str,
        source_root: Path,
        mirrored_repo_id: str,
        work_dir: Path,
        vcodec: str,
        overwrite: bool,
    ):
        """Store the settings later forwarded to ``mirror_shard``."""
        super().__init__()
        self.repo_id, self.source_root = repo_id, source_root
        self.mirrored_repo_id, self.work_dir = mirrored_repo_id, work_dir
        self.vcodec, self.overwrite = vcodec, overwrite

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Mirror this rank's shard into its directory under ``work_dir``.

        ``data`` is not used.
        """
        init_logging()
        mirror_shard(
            repo_id=self.repo_id,
            source_root=self.source_root,
            mirrored_repo_id=self.mirrored_repo_id,
            shard_root=_get_shard_root(self.work_dir, world_size, rank),
            rank=rank,
            world_size=world_size,
            vcodec=self.vcodec,
            overwrite=self.overwrite,
        )
def make_mirror_executor(
    repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    work_dir: Path,
    logs_dir: Path,
    job_name: str,
    num_shards: int,
    workers: int,
    partition: str | None,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    vcodec: str,
    overwrite: bool,
    slurm: bool,
):
    """Build the executor that mirrors all shards, one task per shard.

    ``partition`` is annotated ``str | None`` to match
    ``make_single_task_executor``: ``None`` is accepted and rejected below
    only when SLURM is requested.

    Args:
        repo_id: Source dataset repo id.
        source_root: Local root of the source dataset.
        mirrored_repo_id: Base repo id used to name the shard datasets.
        work_dir: Directory under which the shards are written.
        logs_dir: Parent directory of the executor's logging directory.
        job_name: Job name; also the logging sub-directory.
        num_shards: Number of tasks.
        workers: Worker count passed to the SLURM executor (the local
            executor always uses one worker).
        partition: SLURM partition; required when ``slurm`` is true.
        cpus_per_task: CPUs requested per SLURM task.
        mem_per_cpu: Value of the ``mem-per-cpu`` sbatch argument.
        time_limit: SLURM time limit.
        vcodec: Video codec for the mirrored videos.
        overwrite: Whether existing shards are deleted and rebuilt.
        slurm: Build a ``SlurmPipelineExecutor`` when true, otherwise a
            ``LocalPipelineExecutor``.

    Returns:
        The configured executor; it has not been run.

    Raises:
        ValueError: If ``slurm`` is true and ``partition`` is ``None``.
    """
    step = MirrorDatasetShards(
        repo_id=repo_id,
        source_root=source_root,
        mirrored_repo_id=mirrored_repo_id,
        work_dir=work_dir,
        vcodec=vcodec,
        overwrite=overwrite,
    )
    executor_kwargs = {
        "pipeline": [step],
        "logging_dir": str(logs_dir / job_name),
    }

    if not slurm:
        executor_kwargs.update({"tasks": num_shards, "workers": 1})
        return LocalPipelineExecutor(**executor_kwargs)

    if partition is None:
        raise ValueError("`--partition` is required when `--slurm 1`.")
    executor_kwargs.update(
        {
            "job_name": job_name,
            "tasks": num_shards,
            "workers": workers,
            "time": time_limit,
            "partition": partition,
            "cpus_per_task": cpus_per_task,
            "sbatch_args": {"mem-per-cpu": mem_per_cpu},
        }
    )
    return SlurmPipelineExecutor(**executor_kwargs)
class AggregateMirroredShardsStep(PipelineStep):
    """DataTrove step that merges the mirrored shards into one dataset."""

    def __init__(
        self,
        mirrored_repo_id: str,
        mirrored_root: Path,
        work_dir: Path,
        num_shards: int,
        overwrite: bool,
    ):
        """Store the settings later forwarded to ``aggregate_mirrored_shards``."""
        super().__init__()
        self.mirrored_repo_id, self.mirrored_root = mirrored_repo_id, mirrored_root
        self.work_dir, self.num_shards = work_dir, num_shards
        self.overwrite = overwrite

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Aggregate the shards on rank 0; every other rank only logs and returns.

        ``data`` and ``world_size`` are not used.
        """
        init_logging()
        if rank == 0:
            aggregate_mirrored_shards(
                mirrored_repo_id=self.mirrored_repo_id,
                mirrored_root=self.mirrored_root,
                work_dir=self.work_dir,
                num_shards=self.num_shards,
                overwrite=self.overwrite,
            )
        else:
            logger.info("Skipping rank %s for aggregate mirrored step", rank)
class BuildDoubledDatasetStep(PipelineStep):
    """DataTrove step that merges the source and mirrored datasets."""

    def __init__(
        self,
        source_repo_id: str,
        source_root: Path,
        mirrored_repo_id: str,
        mirrored_root: Path,
        output_repo_id: str,
        output_root: Path,
        overwrite: bool,
    ):
        """Store the settings later forwarded to ``build_doubled_dataset``."""
        super().__init__()
        self.source_repo_id, self.source_root = source_repo_id, source_root
        self.mirrored_repo_id, self.mirrored_root = mirrored_repo_id, mirrored_root
        self.output_repo_id, self.output_root = output_repo_id, output_root
        self.overwrite = overwrite

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Build the doubled dataset on rank 0; every other rank only logs and returns.

        ``data`` and ``world_size`` are not used.
        """
        init_logging()
        if rank == 0:
            build_doubled_dataset(
                source_repo_id=self.source_repo_id,
                source_root=self.source_root,
                mirrored_repo_id=self.mirrored_repo_id,
                mirrored_root=self.mirrored_root,
                output_repo_id=self.output_repo_id,
                output_root=self.output_root,
                overwrite=self.overwrite,
            )
        else:
            logger.info("Skipping rank %s for build doubled step", rank)
class PushDoubledDatasetStep(PipelineStep):
    """DataTrove step that pushes the doubled dataset to the Hub."""

    def __init__(
        self,
        output_repo_id: str,
        output_root: Path,
    ):
        """Store the repo id and local root of the dataset to push."""
        super().__init__()
        self.output_repo_id, self.output_root = output_repo_id, output_root

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Push the dataset from rank 0; every other rank only logs and returns.

        ``data`` and ``world_size`` are not used.
        """
        init_logging()
        if rank != 0:
            logger.info("Skipping rank %s for push step", rank)
            return
        logger.info("Pushing doubled dataset to hub: %s", self.output_repo_id)
        doubled_dataset = LeRobotDataset(self.output_repo_id, root=self.output_root)
        doubled_dataset.push_to_hub()
def make_single_task_executor(
    step: PipelineStep,
    logs_dir: Path,
    job_name: str,
    partition: str | None,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    slurm: bool,
    depends: SlurmPipelineExecutor | None = None,
):
    """Build an executor that runs ``step`` as one task with one worker.

    Args:
        step: The only pipeline step of the executor.
        logs_dir: Parent directory of the executor's logging directory.
        job_name: Job name; also the logging sub-directory.
        partition: SLURM partition; required when ``slurm`` is true.
        cpus_per_task: CPUs requested for the SLURM task.
        mem_per_cpu: Value of the ``mem-per-cpu`` sbatch argument.
        time_limit: SLURM time limit.
        slurm: Build a ``SlurmPipelineExecutor`` when true, otherwise a
            ``LocalPipelineExecutor``.
        depends: Executor passed as ``depends`` to the SLURM executor; not
            used for the local executor.

    Returns:
        The configured executor; it has not been run.

    Raises:
        ValueError: If ``slurm`` is true and ``partition`` is ``None``.
    """
    executor_kwargs = {"pipeline": [step], "logging_dir": str(logs_dir / job_name)}

    if not slurm:
        executor_kwargs.update({"tasks": 1, "workers": 1})
        return LocalPipelineExecutor(**executor_kwargs)

    if partition is None:
        raise ValueError("`--partition` is required when `--slurm 1`.")
    executor_kwargs.update(
        {
            "job_name": job_name,
            "tasks": 1,
            "workers": 1,
            "time": time_limit,
            "partition": partition,
            "cpus_per_task": cpus_per_task,
            "sbatch_args": {"mem-per-cpu": mem_per_cpu},
            "depends": depends,
        }
    )
    return SlurmPipelineExecutor(**executor_kwargs)
def aggregate_mirrored_shards(
    mirrored_repo_id: str,
    mirrored_root: Path,
    work_dir: Path,
    num_shards: int,
    overwrite: bool,
):
    """Merge every valid mirrored shard into one dataset at ``mirrored_root``.

    Shard directories without a ``meta/info.json`` are left out of the merge.
    An existing valid ``mirrored_root`` is reused unless ``overwrite`` is set,
    in which case it is deleted first.

    Raises:
        RuntimeError: If ``mirrored_root`` exists without being a valid
            dataset and ``overwrite`` is false, or if no valid shard exists.
    """
    if mirrored_root.exists():
        if not overwrite:
            if not _is_valid_dataset_root(mirrored_root):
                raise RuntimeError(
                    f"Mirrored root {mirrored_root} exists but is not a valid dataset. Use --overwrite to recreate."
                )
            logger.info("Mirrored dataset already exists at %s. Skipping aggregation.", mirrored_root)
            return
        shutil.rmtree(mirrored_root)

    candidates = ((rank, _get_shard_root(work_dir, num_shards, rank)) for rank in range(num_shards))
    valid_shards = [(rank, root) for rank, root in candidates if _is_valid_dataset_root(root)]
    if not valid_shards:
        raise RuntimeError("No mirrored shards were produced. Nothing to aggregate.")

    shard_repo_ids = [f"{mirrored_repo_id}_world_{num_shards}_rank_{rank}" for rank, _ in valid_shards]
    shard_roots = [root for _, root in valid_shards]

    logger.info("Aggregating %s mirrored shards into %s", len(shard_repo_ids), mirrored_root)
    aggregate_datasets(
        repo_ids=shard_repo_ids,
        roots=shard_roots,
        aggr_repo_id=mirrored_repo_id,
        aggr_root=mirrored_root,
    )
def build_doubled_dataset(
    source_repo_id: str,
    source_root: Path,
    mirrored_repo_id: str,
    mirrored_root: Path,
    output_repo_id: str,
    output_root: Path,
    overwrite: bool,
):
    """Aggregate the source and mirrored datasets into ``output_root``.

    An existing valid ``output_root`` is reused unless ``overwrite`` is set,
    in which case it is deleted first.

    Raises:
        RuntimeError: If ``output_root`` exists without being a valid dataset
            and ``overwrite`` is false.
    """
    if output_root.exists():
        if not overwrite:
            if not _is_valid_dataset_root(output_root):
                raise RuntimeError(
                    f"Output root {output_root} exists but is not a valid dataset. Use --overwrite to recreate."
                )
            logger.info("Doubled dataset already exists at %s. Skipping final aggregation.", output_root)
            return
        shutil.rmtree(output_root)

    logger.info("Aggregating source + mirrored into doubled dataset at %s", output_root)
    # The source dataset is listed first, the mirrored one second.
    datasets = [(source_repo_id, source_root), (mirrored_repo_id, mirrored_root)]
    aggregate_datasets(
        repo_ids=[dataset_id for dataset_id, _ in datasets],
        roots=[dataset_root for _, dataset_root in datasets],
        aggr_repo_id=output_repo_id,
        aggr_root=output_root,
    )
def main():
    """Parse the CLI arguments and run the mirror, aggregate and double stages.

    With ``--slurm 1`` (the default) the stages are built as a chain of
    executors, each depending on the previous one, and only the last executor
    of the chain is run. Otherwise the mirror executor is run locally and the
    aggregation, doubling and optional push happen in this process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo id.")
    parser.add_argument("--output-repo-id", type=str, required=True, help="Final doubled dataset repo id.")
    parser.add_argument("--root", type=Path, default=None, help="Root path of source dataset.")
    parser.add_argument(
        "--output-root", type=Path, default=None, help="Root path where final doubled dataset is written."
    )
    parser.add_argument(
        "--work-dir",
        type=Path,
        default=None,
        help="Intermediate directory for mirrored shards and mirrored aggregate dataset.",
    )
    parser.add_argument("--logs-dir", type=Path, required=True, help="DataTrove logs path.")
    parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name.")
    parser.add_argument("--num-shards", type=int, default=256, help="Number of DataTrove tasks/ranks.")
    parser.add_argument("--workers", type=int, default=64, help="Max concurrent DataTrove workers on SLURM.")
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition (e.g. hopper-cpu).")
    parser.add_argument("--cpus-per-task", type=int, default=8, help="CPU count per SLURM task.")
    parser.add_argument("--mem-per-cpu", type=str, default="4G", help="Memory per CPU for SLURM task.")
    parser.add_argument("--time", type=str, default="24:00:00", help="SLURM time limit.")
    parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec for output videos.")
    parser.add_argument(
        "--slurm", type=int, default=1, help="Use SLURM executor. Set 0 for local sequential debugging."
    )
    parser.add_argument("--overwrite", action="store_true", help="Delete existing intermediate/final outputs.")
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push final doubled dataset to Hugging Face Hub after completion.",
    )
    args = parser.parse_args()

    init_logging()
    slurm = args.slurm == 1

    # Resolve every path and repo id shared by the stages.
    source_root = _resolve_source_root(args.repo_id, args.root)
    output_root = args.output_root
    if output_root is None:
        output_root = HF_LEROBOT_HOME / args.output_repo_id
    work_dir = _get_work_dir(args.output_repo_id, args.work_dir)
    mirrored_repo_id = f"{args.output_repo_id}_mirrored"
    mirrored_root = work_dir / "mirrored_aggregate"

    work_dir.mkdir(parents=True, exist_ok=True)
    args.logs_dir.mkdir(parents=True, exist_ok=True)

    mirror_executor = make_mirror_executor(
        repo_id=args.repo_id,
        source_root=source_root,
        mirrored_repo_id=mirrored_repo_id,
        work_dir=work_dir,
        logs_dir=args.logs_dir,
        job_name=args.job_name,
        num_shards=args.num_shards,
        workers=args.workers,
        partition=args.partition,
        cpus_per_task=args.cpus_per_task,
        mem_per_cpu=args.mem_per_cpu,
        time_limit=args.time,
        vcodec=args.vcodec,
        overwrite=args.overwrite,
        slurm=slurm,
    )

    if not slurm:
        # Local mode: run every stage in order inside this process.
        mirror_executor.run()
        aggregate_mirrored_shards(
            mirrored_repo_id=mirrored_repo_id,
            mirrored_root=mirrored_root,
            work_dir=work_dir,
            num_shards=args.num_shards,
            overwrite=args.overwrite,
        )
        build_doubled_dataset(
            source_repo_id=args.repo_id,
            source_root=source_root,
            mirrored_repo_id=mirrored_repo_id,
            mirrored_root=mirrored_root,
            output_repo_id=args.output_repo_id,
            output_root=output_root,
            overwrite=args.overwrite,
        )
        if args.push_to_hub:
            logger.info("Pushing doubled dataset to hub: %s", args.output_repo_id)
            LeRobotDataset(args.output_repo_id, root=output_root).push_to_hub()
        return

    # SLURM mode: the follow-up stages share the same resource settings.
    single_task_settings = {
        "logs_dir": args.logs_dir,
        "partition": args.partition,
        "cpus_per_task": args.cpus_per_task,
        "mem_per_cpu": args.mem_per_cpu,
        "time_limit": args.time,
        "slurm": True,
    }
    aggregate_executor = make_single_task_executor(
        step=AggregateMirroredShardsStep(
            mirrored_repo_id=mirrored_repo_id,
            mirrored_root=mirrored_root,
            work_dir=work_dir,
            num_shards=args.num_shards,
            overwrite=args.overwrite,
        ),
        job_name=f"{args.job_name}_aggregate_mirrored",
        depends=mirror_executor,
        **single_task_settings,
    )
    build_executor = make_single_task_executor(
        step=BuildDoubledDatasetStep(
            source_repo_id=args.repo_id,
            source_root=source_root,
            mirrored_repo_id=mirrored_repo_id,
            mirrored_root=mirrored_root,
            output_repo_id=args.output_repo_id,
            output_root=output_root,
            overwrite=args.overwrite,
        ),
        job_name=f"{args.job_name}_build_doubled",
        depends=aggregate_executor,
        **single_task_settings,
    )
    push_executor = None
    if args.push_to_hub:
        push_executor = make_single_task_executor(
            step=PushDoubledDatasetStep(
                output_repo_id=args.output_repo_id,
                output_root=output_root,
            ),
            job_name=f"{args.job_name}_push",
            depends=build_executor,
            **single_task_settings,
        )

    # Only the last executor of the chain is run here.
    final_executor = build_executor if push_executor is None else push_executor
    final_executor.run()
    logger.info(
        "Submitted SLURM chain. job_ids: mirror=%s aggregate=%s doubled=%s push=%s",
        mirror_executor.job_id,
        aggregate_executor.job_id,
        build_executor.job_id,
        push_executor.job_id if push_executor is not None else None,
    )
# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()