This commit is contained in:
Pepijn
2026-02-05 16:03:45 +01:00
parent 39e14c086c
commit 76a4529d29
+107 -91
View File
@@ -41,13 +41,9 @@ python examples/port_datasets/slurm_mirror_dataset.py \
""" """
import argparse import argparse
import json
import logging import logging
import subprocess
from pathlib import Path from pathlib import Path
import numpy as np
import pandas as pd
from datatrove.executor import LocalPipelineExecutor from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.base import PipelineStep
@@ -66,49 +62,44 @@ OPENARM_MIRRORING_MASK = {
} }
def get_mirroring_mask(robot_type: str) -> dict[str, int]: class MirrorVideos(PipelineStep):
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]: """Pipeline step that mirrors video files for assigned episodes."""
return OPENARM_MIRRORING_MASK
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
def __init__(
self,
repo_id: str,
output_repo_id: str,
root: str | None = None,
output_root: str | None = None,
vcodec: str = "libsvtav1",
):
super().__init__()
self.repo_id = repo_id
self.output_repo_id = output_repo_id
self.root = root
self.output_root = output_root
self.vcodec = vcodec
def swap_left_right_name(name: str) -> str: def run(self, data=None, rank: int = 0, world_size: int = 1):
import subprocess
from pathlib import Path
from datasets.utils.tqdm import disable_progress_bars
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
def swap_left_right_name(name: str) -> str:
result = name.replace("left_", "LEFT_PLACEHOLDER_") result = name.replace("left_", "LEFT_PLACEHOLDER_")
result = result.replace("right_", "left_") result = result.replace("right_", "left_")
result = result.replace("LEFT_PLACEHOLDER_", "right_") result = result.replace("LEFT_PLACEHOLDER_", "right_")
return result return result
def flip_video_frames(input_path: Path, output_path: Path, fps: float, vcodec: str):
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
    """Swap left/right in every feature name and map old positions to new ones.

    Args:
        names: Feature names in their original order.

    Returns:
        A pair ``(mirrored_names, old_to_new_idx)``. ``mirrored_names[i]`` is
        ``swap_left_right_name(names[i])``; ``old_to_new_idx[i]`` is the first
        position in ``mirrored_names`` that holds that same mirrored name.
    """
    mirrored_names = [swap_left_right_name(n) for n in names]

    # Record the first position of each mirrored name once. This yields the
    # same answer as calling list.index() per entry without rescanning the
    # list every time (linear instead of quadratic).
    first_position: dict[str, int] = {}
    for idx, mirrored in enumerate(mirrored_names):
        first_position.setdefault(mirrored, idx)

    # NOTE(review): when all mirrored names are distinct this mapping is the
    # identity, i.e. values stay in place and only their labels change side.
    old_to_new_idx = {
        idx: first_position[mirrored] for idx, mirrored in enumerate(mirrored_names)
    }
    return mirrored_names, old_to_new_idx
def apply_mirroring_mask(value: float, feature_name: str, mirroring_mask: dict[str, int]) -> float:
    """Multiply ``value`` by the mask entry that matches ``feature_name``.

    The joint name is looked up in two ways:

    1. With everything up to the first ``_`` dropped (e.g. a side prefix such
       as ``left_``) and everything after the first ``.`` dropped, so
       ``left_joint_1.pos`` resolves to ``joint_1``.
    2. If that misses, with only the part after the first ``.`` dropped, so an
       un-prefixed name such as ``joint_1.pos`` also resolves to ``joint_1``
       instead of the meaningless ``1``.

    Args:
        value: The value to mirror.
        feature_name: Name of the feature the value belongs to.
        mirroring_mask: Maps joint names to the factor applied when mirroring.

    Returns:
        ``value`` times the matching factor, or ``value`` unchanged when no
        mask entry matches.
    """
    _, separator, remainder = feature_name.partition("_")
    prefixed_candidate = (remainder if separator else feature_name).split(".")[0]
    if prefixed_candidate in mirroring_mask:
        return value * mirroring_mask[prefixed_candidate]

    # Fallback for names without a side prefix whose joint name itself
    # contains an underscore (e.g. "joint_1.pos").
    plain_candidate = feature_name.split(".")[0]
    if plain_candidate in mirroring_mask:
        return value * mirroring_mask[plain_candidate]
    return value
def mirror_array(array: np.ndarray, names: list[str], mirroring_mask: dict[str, int]) -> np.ndarray:
    """Build the mirrored counterpart of ``array``.

    Each entry ``array[i]`` is written to the position given by
    ``mirror_feature_names(names)`` and passed through
    ``apply_mirroring_mask`` using the mirrored name at that position.

    Args:
        array: Values indexed along the first axis by ``names``.
        names: Feature name for each index of ``array``.
        mirroring_mask: Maps joint names to the factor applied when mirroring.

    Returns:
        A new array created with ``np.zeros_like(array)`` and filled with the
        mirrored values; ``array`` itself is not modified.
    """
    target_names, position_map = mirror_feature_names(names)
    mirrored = np.zeros_like(array)
    for source_idx, target_idx in position_map.items():
        mirrored[target_idx] = apply_mirroring_mask(
            array[source_idx], target_names[target_idx], mirroring_mask
        )
    return mirrored
def flip_video_frames(input_path: Path, output_path: Path, fps: float, vcodec: str = "libsvtav1"):
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
cmd = [ cmd = [
"ffmpeg", "-y", "-i", str(input_path), "ffmpeg", "-y", "-i", str(input_path),
@@ -127,9 +118,7 @@ def flip_video_frames(input_path: Path, output_path: Path, fps: float, vcodec: s
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"FFmpeg failed: {result.stderr}") raise RuntimeError(f"FFmpeg failed: {result.stderr}")
def video_is_valid(path: Path) -> bool:
def video_is_valid(path: Path) -> bool:
"""Check if a video file exists and is valid."""
if not path.exists(): if not path.exists():
return False return False
try: try:
@@ -142,36 +131,11 @@ def video_is_valid(path: Path) -> bool:
except Exception: except Exception:
return False return False
root = Path(self.root) if self.root else None
output_root = Path(self.output_root) if self.output_root else None
class MirrorVideos(PipelineStep): dataset = LeRobotDataset(self.repo_id, root=root)
"""Pipeline step that mirrors video files for assigned episodes.""" output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)
def __init__(
self,
repo_id: str,
output_repo_id: str,
root: str | None = None,
output_root: str | None = None,
vcodec: str = "libsvtav1",
):
super().__init__()
self.repo_id = repo_id
self.output_repo_id = output_repo_id
self.root = Path(root) if root else None
self.output_root = Path(output_root) if output_root else None
self.vcodec = vcodec
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
dataset = LeRobotDataset(self.repo_id, root=self.root)
output_root = self.output_root or (HF_LEROBOT_HOME / self.output_repo_id)
if not dataset.meta.video_keys: if not dataset.meta.video_keys:
logger.info(f"Rank {rank}: No videos to process") logger.info(f"Rank {rank}: No videos to process")
@@ -191,7 +155,6 @@ class MirrorVideos(PipelineStep):
except KeyError: except KeyError:
continue continue
# Distribute tasks across workers
my_tasks = [t for i, t in enumerate(video_tasks) if i % world_size == rank] my_tasks = [t for i, t in enumerate(video_tasks) if i % world_size == rank]
logger.info(f"Rank {rank}/{world_size}: Processing {len(my_tasks)}/{len(video_tasks)} videos") logger.info(f"Rank {rank}/{world_size}: Processing {len(my_tasks)}/{len(video_tasks)} videos")
@@ -216,14 +179,19 @@ class MirrorDataAndMetadata(PipelineStep):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self.output_repo_id = output_repo_id self.output_repo_id = output_repo_id
self.root = Path(root) if root else None self.root = root
self.output_root = Path(output_root) if output_root else None self.output_root = output_root
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):
if rank != 0: if rank != 0:
return return
from pathlib import Path
import numpy as np
import pandas as pd
from datasets.utils.tqdm import disable_progress_bars from datasets.utils.tqdm import disable_progress_bars
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import DATA_DIR, DEFAULT_DATA_PATH, write_info, write_stats, write_tasks from lerobot.datasets.utils import DATA_DIR, DEFAULT_DATA_PATH, write_info, write_stats, write_tasks
from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -232,8 +200,63 @@ class MirrorDataAndMetadata(PipelineStep):
init_logging() init_logging()
disable_progress_bars() disable_progress_bars()
dataset = LeRobotDataset(self.repo_id, root=self.root) MIRRORING_MASK = {
output_root = self.output_root or (HF_LEROBOT_HOME / self.output_repo_id) "joint_1": -1, "joint_2": -1, "joint_3": -1, "joint_4": 1,
"joint_5": -1, "joint_6": -1, "joint_7": -1, "gripper": 1,
}
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
    """Return the mirroring mask for ``robot_type``.

    Args:
        robot_type: Robot type identifier of the dataset.

    Returns:
        ``MIRRORING_MASK`` for the supported OpenArm follower variants.

    Raises:
        ValueError: If ``robot_type`` is not one of the supported variants.
    """
    supported_types = (
        "bi_openarm_follower",
        "openarm_follower",
        "bi_openarms_follower",
        "openarms_follower",
    )
    if robot_type not in supported_types:
        raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
    return MIRRORING_MASK
def swap_left_right_name(name: str) -> str:
    """Exchange every ``left_`` and ``right_`` occurrence in ``name``.

    A temporary placeholder keeps the freshly renamed ``left_`` parts from
    being turned back by the second replacement.

    Args:
        name: Feature or key name to mirror.

    Returns:
        ``name`` with ``left_`` and ``right_`` swapped.
    """
    placeholder = "LEFT_PLACEHOLDER_"
    return (
        name.replace("left_", placeholder)
        .replace("right_", "left_")
        .replace(placeholder, "right_")
    )
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
    """Swap left/right in every feature name and map old positions to new ones.

    Args:
        names: Feature names in their original order.

    Returns:
        A pair ``(mirrored_names, old_to_new_idx)``. ``mirrored_names[i]`` is
        ``swap_left_right_name(names[i])``; ``old_to_new_idx[i]`` is the first
        position in ``mirrored_names`` that holds that same mirrored name.
    """
    mirrored_names = [swap_left_right_name(n) for n in names]

    # Record the first position of each mirrored name once. This yields the
    # same answer as calling list.index() per entry without rescanning the
    # list every time (linear instead of quadratic).
    first_position: dict[str, int] = {}
    for idx, mirrored in enumerate(mirrored_names):
        first_position.setdefault(mirrored, idx)

    # NOTE(review): when all mirrored names are distinct this mapping is the
    # identity, i.e. values stay in place and only their labels change side.
    old_to_new_idx = {
        idx: first_position[mirrored] for idx, mirrored in enumerate(mirrored_names)
    }
    return mirrored_names, old_to_new_idx
def apply_mirroring_mask(value: float, feature_name: str, mirroring_mask: dict[str, int]) -> float:
    """Multiply ``value`` by the mask entry that matches ``feature_name``.

    The joint name is looked up in two ways:

    1. With everything up to the first ``_`` dropped (e.g. a side prefix such
       as ``left_``) and everything after the first ``.`` dropped, so
       ``left_joint_1.pos`` resolves to ``joint_1``.
    2. If that misses, with only the part after the first ``.`` dropped, so an
       un-prefixed name such as ``joint_1.pos`` also resolves to ``joint_1``
       instead of the meaningless ``1``.

    Args:
        value: The value to mirror.
        feature_name: Name of the feature the value belongs to.
        mirroring_mask: Maps joint names to the factor applied when mirroring.

    Returns:
        ``value`` times the matching factor, or ``value`` unchanged when no
        mask entry matches.
    """
    _, separator, remainder = feature_name.partition("_")
    prefixed_candidate = (remainder if separator else feature_name).split(".")[0]
    if prefixed_candidate in mirroring_mask:
        return value * mirroring_mask[prefixed_candidate]

    # Fallback for names without a side prefix whose joint name itself
    # contains an underscore (e.g. "joint_1.pos").
    plain_candidate = feature_name.split(".")[0]
    if plain_candidate in mirroring_mask:
        return value * mirroring_mask[plain_candidate]
    return value
def mirror_array(array: np.ndarray, names: list[str], mirroring_mask: dict[str, int]) -> np.ndarray:
    """Build the mirrored counterpart of ``array``.

    Each entry ``array[i]`` is written to the position given by
    ``mirror_feature_names(names)`` and passed through
    ``apply_mirroring_mask`` using the mirrored name at that position.

    Args:
        array: Values indexed along the first axis by ``names``.
        names: Feature name for each index of ``array``.
        mirroring_mask: Maps joint names to the factor applied when mirroring.

    Returns:
        A new array created with ``np.zeros_like(array)`` and filled with the
        mirrored values; ``array`` itself is not modified.
    """
    target_names, position_map = mirror_feature_names(names)
    mirrored = np.zeros_like(array)
    for source_idx, target_idx in position_map.items():
        mirrored[target_idx] = apply_mirroring_mask(
            array[source_idx], target_names[target_idx], mirroring_mask
        )
    return mirrored
def mirror_stats(stats: dict) -> dict:
    """Return a copy of ``stats`` with left/right swapped in every key.

    Nested dictionaries are processed recursively. Non-dict values are carried
    over as-is: only the keys are renamed, the values are not altered.

    Args:
        stats: Possibly nested statistics dictionary.

    Returns:
        A new dictionary with mirrored key names.
    """
    return {
        swap_left_right_name(key): mirror_stats(entry) if isinstance(entry, dict) else entry
        for key, entry in stats.items()
    }
root = Path(self.root) if self.root else None
output_root = Path(self.output_root) if self.output_root else None
dataset = LeRobotDataset(self.repo_id, root=root)
output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)
done_marker = output_root / ".data_mirrored" done_marker = output_root / ".data_mirrored"
if done_marker.exists(): if done_marker.exists():
@@ -263,7 +286,6 @@ class MirrorDataAndMetadata(PipelineStep):
if dataset.meta.tasks is not None: if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root) write_tasks(dataset.meta.tasks, new_meta.root)
# Mirror parquet data
data_dir = dataset.root / DATA_DIR data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet")) parquet_files = sorted(data_dir.glob("*/*.parquet"))
action_names = dataset.meta.features.get("action", {}).get("names", []) action_names = dataset.meta.features.get("action", {}).get("names", [])
@@ -291,7 +313,6 @@ class MirrorDataAndMetadata(PipelineStep):
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False) df.to_parquet(dst_path, index=False)
# Copy episodes metadata
episodes_dir = dataset.root / "meta/episodes" episodes_dir = dataset.root / "meta/episodes"
dst_episodes_dir = new_meta.root / "meta/episodes" dst_episodes_dir = new_meta.root / "meta/episodes"
if episodes_dir.exists(): if episodes_dir.exists():
@@ -330,21 +351,17 @@ class MirrorDataAndMetadata(PipelineStep):
logger.info(f"Data and metadata mirrored to {output_root}") logger.info(f"Data and metadata mirrored to {output_root}")
def mirror_stats(stats: dict) -> dict: def swap_left_right_name(name: str) -> str:
mirrored = {} result = name.replace("left_", "LEFT_PLACEHOLDER_")
for key, value in stats.items(): result = result.replace("right_", "left_")
new_key = swap_left_right_name(key) result = result.replace("LEFT_PLACEHOLDER_", "right_")
if isinstance(value, dict): return result
mirrored[new_key] = mirror_stats(value)
else:
mirrored[new_key] = value
return mirrored
def get_num_video_tasks(repo_id: str, root: str | None = None) -> int: def get_num_video_tasks(repo_id: str, root: str | None = None) -> int:
"""Count total video files to process."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset(repo_id, root=root) root_path = Path(root) if root else None
dataset = LeRobotDataset(repo_id, root=root_path)
count = 0 count = 0
for video_key in dataset.meta.video_keys: for video_key in dataset.meta.video_keys:
for ep_idx in range(dataset.meta.total_episodes): for ep_idx in range(dataset.meta.total_episodes):
@@ -451,4 +468,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()