Compare commits

..

1 Commits

Author SHA1 Message Date
Martino Russi 012bde51cb trim episodes 2026-02-25 23:09:33 +01:00
4 changed files with 427 additions and 367 deletions
+294
View File
@@ -47,6 +47,7 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
load_info,
update_chunk_file_indices,
write_info,
write_stats,
@@ -1774,3 +1775,296 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def trim_episodes_by_frames(
dataset: LeRobotDataset,
episode_frames_to_keep: dict[int, list[int]],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim multiple episodes to keep only specific frames.
This function creates a new dataset where the specified episodes contain only
the frames at the given indices. All other episodes are copied as-is.
Args:
dataset: The source LeRobotDataset.
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
Returns:
A new LeRobotDataset with the trimmed episodes.
"""
if not episode_frames_to_keep:
raise ValueError("No episodes to trim")
for ep_idx in episode_frames_to_keep:
if ep_idx >= dataset.meta.total_episodes:
raise ValueError(f"Episode {ep_idx} does not exist")
if not episode_frames_to_keep[ep_idx]:
raise ValueError(f"No frames to keep for episode {ep_idx}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_trimmed"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
# Create new metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Build set of all frames to keep (for episodes being trimmed)
# and compute new frame counts per episode
all_keep_frames: set[int] = set()
trimmed_frame_counts: dict[int, int] = {}
for ep_idx, frames in episode_frames_to_keep.items():
all_keep_frames.update(frames)
trimmed_frame_counts[ep_idx] = len(frames)
# Copy and filter data
_copy_and_reindex_data_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep, all_keep_frames
)
# Handle videos if present
if dataset.meta.video_keys:
_copy_and_reindex_videos_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep
)
# Copy episode metadata
_copy_and_reindex_episodes_metadata_for_multi_trim(
dataset, new_meta, trimmed_frame_counts
)
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
# Return the metadata instead of trying to load as LeRobotDataset
# This avoids Hub validation issues when the repo doesn't exist yet
return new_meta
# Keep old function for backward compatibility
def trim_episode_by_frames(
dataset: LeRobotDataset,
episode_index: int,
keep_frame_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
return trim_episodes_by_frames(
dataset,
episode_frames_to_keep={episode_index: keep_frame_indices},
output_dir=output_dir,
repo_id=repo_id,
)
def _copy_and_reindex_data_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
all_keep_frames: set[int],
) -> None:
"""Copy data files with frame-level filtering for multiple episodes."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Copy tasks
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
# Tasks are stored with task string as index
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
# Get all parquet files
data_dir = src_dataset.root / "data"
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
trim_episode_set = set(episode_frames_to_keep.keys())
global_index = 0
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(parquet_path)
# Filter: keep all frames from non-trimmed episodes,
# and only specified frames from trimmed episodes
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
df = df[mask].copy().reset_index(drop=True)
if len(df) == 0:
continue
# Reindex
df["index"] = range(global_index, global_index + len(df))
# Recalculate frame_index within each episode
for ep_idx in df["episode_index"].unique():
ep_mask = df["episode_index"] == ep_idx
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
# Recalculate timestamps based on frame_index and fps
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
# Determine output path (keep same structure)
rel_path = parquet_path.relative_to(src_dataset.root)
dst_path = dst_meta.root / rel_path
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta)
global_index += len(df)
def _copy_and_reindex_videos_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
) -> None:
"""Copy video files for trimmed dataset.
In v3.0 datasets, multiple episodes are concatenated into single video files.
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
For trimming, we copy the original video files as-is and update the metadata
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
"""
for video_key in src_dataset.meta.video_keys:
video_dir = src_dataset.root / "videos" / video_key
dst_video_dir = dst_meta.root / "videos" / video_key
if not video_dir.exists():
logging.warning(f"Video directory not found: {video_dir}")
continue
# Copy all video files (they contain concatenated episodes)
# The metadata timestamps will handle which portions to use
copied_files = set()
for chunk_dir in video_dir.glob("chunk-*"):
dst_chunk_dir = dst_video_dir / chunk_dir.name
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
for video_file in chunk_dir.glob("*.mp4"):
if video_file.name not in copied_files:
dst_path = dst_chunk_dir / video_file.name
if not dst_path.exists():
shutil.copy(video_file, dst_path)
copied_files.add(video_file.name)
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
def _trim_video_frames(
src_path: Path,
dst_path: Path,
keep_frame_indices: list[int],
fps: float,
episode_start_idx: int,
) -> None:
"""Trim a video to keep only specific frames using ffmpeg."""
import subprocess
# Convert global indices to local indices within the episode
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
if not local_indices:
logging.warning(f"No frames to keep for video {src_path}")
return
# Calculate start and end times
start_frame = local_indices[0]
end_frame = local_indices[-1]
start_time = start_frame / fps
duration = (end_frame - start_frame + 1) / fps
# Use ffmpeg to trim
cmd = [
"ffmpeg", "-y",
"-ss", str(start_time),
"-i", str(src_path),
"-t", str(duration),
"-c", "copy", # Fast copy without re-encoding
str(dst_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
logging.error(f"Failed to trim video: {e.stderr.decode()}")
# Fallback: copy the whole video
shutil.copy(src_path, dst_path)
def _copy_and_reindex_episodes_metadata_for_multi_trim(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
trimmed_frame_counts: dict[int, int],
) -> None:
"""Copy and update episode metadata for trimmed dataset."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Calculate new frame counts and indices
episodes_data = []
global_idx = 0
for old_ep_idx in range(src_dataset.meta.total_episodes):
src_ep = src_dataset.meta.episodes[old_ep_idx]
if old_ep_idx in trimmed_frame_counts:
ep_length = trimmed_frame_counts[old_ep_idx]
else:
ep_length = src_ep["length"]
ep_data = {
"episode_index": old_ep_idx,
"tasks": src_ep["tasks"],
"length": ep_length,
"data/chunk_index": src_ep["data/chunk_index"],
"data/file_index": src_ep["data/file_index"],
"dataset_from_index": global_idx,
"dataset_to_index": global_idx + ep_length,
}
# Copy video metadata - preserve timestamps for concatenated videos
for video_key in src_dataset.meta.video_keys:
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
# Keep original from_timestamp (start position in concatenated video)
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
# For trimmed episodes, update to_timestamp based on new length
# For non-trimmed episodes, keep original to_timestamp
if old_ep_idx in trimmed_frame_counts:
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
else:
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
ep_data["meta/episodes/chunk_index"] = 0
ep_data["meta/episodes/file_index"] = 0
episodes_data.append(ep_data)
global_idx += ep_length
# Save episodes metadata
df = pd.DataFrame(episodes_data)
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
episodes_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(episodes_path)
# Update info.json
info = load_info(src_dataset.root)
info["total_episodes"] = len(episodes_data)
info["total_frames"] = global_idx
write_info(info, dst_meta.root)
+133
View File
@@ -104,6 +104,28 @@ Convert image dataset to video format and push to hub:
--operation.type convert_image_to_video \
--push_to_hub true
Trim single episode to keep only frames within timestamp range:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_trimmed \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--operation.end_timestamp 30.0
Trim multiple episodes at once (use null for no limit):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
Trim and re-upload to same repo (overwrites original):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
@@ -204,9 +226,32 @@ class InfoConfig(OperationConfig):
show_features: bool = False
@dataclass
class TrimEpisodeConfig:
"""Trim episodes to keep only frames within timestamp ranges.
Supports multiple episodes via episode_trims dict:
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
Or single episode via legacy parameters:
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
"""
type: str = "trim_episode"
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
episode_trims: dict[str, list[float | None]] | None = None
# Legacy single-episode parameters (used if episode_trims is None)
episode_index: int | None = None
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
@@ -351,6 +396,92 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
"""Trim episodes to keep only frames within timestamp ranges."""
if not isinstance(cfg.operation, TrimEpisodeConfig):
raise ValueError("Operation config must be TrimEpisodeConfig")
# Parse episode trims - support both multi-episode dict and legacy single episode
episode_trims: dict[int, tuple[float | None, float | None]] = {}
if cfg.operation.episode_trims is not None:
# Multi-episode mode
for ep_str, ts_range in cfg.operation.episode_trims.items():
ep_idx = int(ep_str)
start_ts = ts_range[0] if len(ts_range) > 0 else None
end_ts = ts_range[1] if len(ts_range) > 1 else None
episode_trims[ep_idx] = (start_ts, end_ts)
elif cfg.operation.episode_index is not None:
# Legacy single-episode mode
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
episode_trims[cfg.operation.episode_index] = (
cfg.operation.start_timestamp,
cfg.operation.end_timestamp,
)
else:
raise ValueError("Either episode_trims or episode_index must be specified")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
# Get episode boundaries and find frames to keep for each episode
episodes_info = dataset.meta.episodes
all_frames_to_keep: dict[int, list[int]] = {}
for ep_idx, (start_ts, end_ts) in episode_trims.items():
if ep_idx >= len(episodes_info["episode_index"]):
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
from_frame = episodes_info["dataset_from_index"][ep_idx]
to_frame = episodes_info["dataset_to_index"][ep_idx]
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
# Find frames within timestamp range
frames_to_keep = []
for frame_idx in range(from_frame, to_frame):
frame = dataset.hf_dataset[frame_idx]
ts = frame["timestamp"]
in_range = True
if start_ts is not None and ts < start_ts:
in_range = False
if end_ts is not None and ts > end_ts:
in_range = False
if in_range:
frames_to_keep.append(frame_idx)
if not frames_to_keep:
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
all_frames_to_keep[ep_idx] = frames_to_keep
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
new_dataset = trim_episodes_by_frames(
dataset,
episode_frames_to_keep=all_frames_to_keep,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ModifyTasksConfig):
raise ValueError("Operation config must be ModifyTasksConfig")
@@ -515,6 +646,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "trim_episode":
handle_trim_episode(cfg)
elif operation_type == "info":
handle_info(cfg)
else:
@@ -1,366 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mirror a bimanual robot dataset by swapping left/right arms and inverting joint values.
This script creates a mirrored version of a dataset where:
1. Left and right arm observations/actions are swapped
2. Joint values are inverted according to a mirroring mask
3. Video frames are horizontally flipped
Example usage:
```shell
python -m lerobot.scripts.lerobot_mirror_dataset \
--repo_id=pepijn/openarm_bimanual \
--output_repo_id=pepijn/openarm_bimanual_mirrored
```
"""
import argparse
import logging
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_DATA_PATH,
write_info,
write_stats,
write_tasks,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
OPENARM_MIRRORING_MASK = {
"joint_1": -1, # Pan - invert
"joint_2": -1, # Lift - invert
"joint_3": -1, # Roll - invert
"joint_4": 1, # Elbow - no invert
"joint_5": -1, # W-Roll - invert
"joint_6": -1, # W-Pitch - invert
"joint_7": -1, # W-Yaw - invert
"gripper": 1, # Gripper - no invert
}
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
"""Get the mirroring mask for a given robot type."""
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
return OPENARM_MIRRORING_MASK
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
def swap_left_right_name(name: str) -> str:
"""Swap 'left' and 'right' in a feature name."""
# Use placeholder to avoid double-swap
result = name.replace("left_", "LEFT_PLACEHOLDER_")
result = result.replace("right_", "left_")
result = result.replace("LEFT_PLACEHOLDER_", "right_")
return result
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
"""Mirror feature names by swapping left/right and return the new names and index mapping."""
mirrored_names = [swap_left_right_name(n) for n in names]
old_to_new_idx = {}
for old_idx, old_name in enumerate(names):
new_name = swap_left_right_name(old_name)
new_idx = mirrored_names.index(new_name)
old_to_new_idx[old_idx] = new_idx
return mirrored_names, old_to_new_idx
def apply_mirroring_mask(
value: float,
feature_name: str,
mirroring_mask: dict[str, int],
) -> float:
"""Apply mirroring mask to a joint value."""
name_without_prefix = feature_name.split("_", 1)[1] if "_" in feature_name else feature_name
joint_name = name_without_prefix.split(".")[0]
if joint_name in mirroring_mask:
return value * mirroring_mask[joint_name]
return value
def mirror_array(
array: np.ndarray,
names: list[str],
mirroring_mask: dict[str, int],
) -> np.ndarray:
"""Mirror an array of values (action or state) by swapping left/right and applying mask."""
mirrored_names, idx_mapping = mirror_feature_names(names)
result = np.zeros_like(array)
for old_idx, new_idx in idx_mapping.items():
old_name = names[old_idx]
new_name = mirrored_names[new_idx]
value = array[old_idx]
mirrored_value = apply_mirroring_mask(value, new_name, mirroring_mask)
result[new_idx] = mirrored_value
return result
def flip_video_frames(
input_path: Path,
output_path: Path,
fps: float,
vcodec: str = "libsvtav1",
):
"""Flip video frames horizontally using FFmpeg with same settings as encode_video_frames."""
output_path.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"ffmpeg", "-y", "-i", str(input_path),
"-vf", "hflip",
"-c:v", vcodec,
"-g", "2",
"-crf", "30",
"-r", str(int(fps)),
"-pix_fmt", "yuv420p",
"-loglevel", "error",
]
if vcodec == "libsvtav1":
cmd.extend(["-preset", "12"])
cmd.append(str(output_path))
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg failed: {result.stderr}")
def mirror_dataset(
repo_id: str,
output_repo_id: str,
root: str | Path | None = None,
output_root: str | Path | None = None,
mirroring_mask: dict[str, int] | None = None,
vcodec: str = "libsvtav1",
num_workers: int | None = None,
) -> LeRobotDataset:
"""Mirror a bimanual robot dataset."""
logger.info(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=root)
if mirroring_mask is None:
robot_type = dataset.meta.robot_type or "bi_openarms_follower"
mirroring_mask = get_mirroring_mask(robot_type)
logger.info(f"Using mirroring mask for robot type: {robot_type}")
output_root = Path(output_root) if output_root else HF_LEROBOT_HOME / output_repo_id
mirrored_features = {}
for key, feat in dataset.meta.features.items():
new_key = swap_left_right_name(key)
new_feat = feat.copy()
if "names" in new_feat and new_feat["names"]:
new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
mirrored_features[new_key] = new_feat
logger.info("Creating mirrored dataset metadata...")
new_meta = LeRobotDatasetMetadata.create(
repo_id=output_repo_id,
fps=dataset.meta.fps,
features=mirrored_features,
robot_type=dataset.meta.robot_type,
root=output_root,
use_videos=len(dataset.meta.video_keys) > 0,
)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
new_meta.tasks = dataset.meta.tasks.copy()
_mirror_data(dataset, new_meta, mirroring_mask)
_mirror_videos(dataset, new_meta, vcodec, num_workers)
_copy_episodes_metadata(dataset, new_meta)
logger.info(f"Mirrored dataset saved to: {output_root}")
return LeRobotDataset(output_repo_id, root=output_root)
def _mirror_data(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
mirroring_mask: dict[str, int],
) -> None:
"""Mirror parquet data files."""
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
action_names = src_dataset.meta.features.get("action", {}).get("names", [])
state_names = src_dataset.meta.features.get("observation.state", {}).get("names", [])
for src_path in tqdm(parquet_files, desc="Mirroring data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
if "action" in df.columns and action_names:
actions = np.stack(df["action"].values)
mirrored_actions = np.array([
mirror_array(row, action_names, mirroring_mask) for row in actions
])
df["action"] = list(mirrored_actions)
if "observation.state" in df.columns and state_names:
states = np.stack(df["observation.state"].values)
mirrored_states = np.array([
mirror_array(row, state_names, mirroring_mask) for row in states
])
df["observation.state"] = list(mirrored_states)
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
def _mirror_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
vcodec: str,
num_workers: int | None = None,
) -> None:
"""Mirror video files by flipping horizontally and swapping left/right names."""
if not src_dataset.meta.video_keys:
return
video_tasks = []
for old_video_key in src_dataset.meta.video_keys:
new_video_key = swap_left_right_name(old_video_key)
for ep_idx in range(src_dataset.meta.total_episodes):
try:
src_path = src_dataset.root / src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
dst_relative = src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
dst_relative_str = str(dst_relative).replace(old_video_key, new_video_key)
dst_path = dst_meta.root / dst_relative_str
if src_path.exists():
video_tasks.append((src_path, dst_path))
except KeyError:
continue
def process_video(task, pbar):
src_path, dst_path = task
pbar.set_postfix_str(src_path.name)
flip_video_frames(src_path, dst_path, src_dataset.meta.fps, vcodec)
return src_path
if num_workers is None:
num_workers = os.cpu_count() or 16
num_workers = min(len(video_tasks), num_workers)
logger.info(f"Processing {len(video_tasks)} videos with {num_workers} workers")
with tqdm(total=len(video_tasks), desc="Mirroring videos") as pbar:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(process_video, t, pbar): t for t in video_tasks}
for future in as_completed(futures):
task = futures[future]
future.result()
pbar.set_postfix_str(f"done: {task[0].name}")
pbar.update(1)
def _copy_episodes_metadata(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
) -> None:
"""Copy episodes metadata with swapped video keys."""
episodes_dir = src_dataset.root / "meta/episodes"
dst_episodes_dir = dst_meta.root / "meta/episodes"
if episodes_dir.exists():
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
for src_parquet in episodes_dir.glob("*/*.parquet"):
df = pd.read_parquet(src_parquet)
columns_to_rename = {}
for col in df.columns:
if col.startswith("videos/"):
parts = col.split("/")
if len(parts) >= 2:
video_key = parts[1]
new_video_key = swap_left_right_name(video_key)
new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
columns_to_rename[col] = new_col
if columns_to_rename:
df = df.rename(columns=columns_to_rename)
dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_parquet, index=False)
dst_meta.info.update({
"total_episodes": src_dataset.meta.total_episodes,
"total_frames": src_dataset.meta.total_frames,
"total_tasks": src_dataset.meta.total_tasks,
"total_videos": src_dataset.meta.total_videos,
"total_chunks": src_dataset.meta.total_chunks,
})
write_info(dst_meta.info, dst_meta.root)
if src_dataset.meta.stats is not None:
mirrored_stats = _mirror_stats(src_dataset.meta.stats)
write_stats(mirrored_stats, dst_meta.root)
def _mirror_stats(stats: dict) -> dict:
"""Mirror stats by swapping left/right in feature names."""
mirrored = {}
for key, value in stats.items():
new_key = swap_left_right_name(key)
if isinstance(value, dict):
mirrored[new_key] = _mirror_stats(value)
else:
mirrored[new_key] = value
return mirrored
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="Mirror a bimanual robot dataset")
parser.add_argument("--repo_id", type=str, required=True, help="Source dataset repo_id")
parser.add_argument("--output_repo_id", type=str, required=True, help="Output dataset repo_id")
parser.add_argument("--root", type=str, default=None, help="Source dataset root directory")
parser.add_argument("--output_root", type=str, default=None, help="Output dataset root directory")
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec (libsvtav1, h264, hevc)")
parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers for video processing")
parser.add_argument("--push-to-hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
args = parser.parse_args()
dataset = mirror_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
root=args.root,
output_root=args.output_root,
vcodec=args.vcodec,
num_workers=args.num_workers,
)
if getattr(args, "push_to_hub", False):
logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
dataset.push_to_hub()
if __name__ == "__main__":
main()
-1
View File
@@ -38,7 +38,6 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.teleoperators import openarm_mini # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed