trim episodes

This commit is contained in:
Martino Russi
2026-02-25 23:09:33 +01:00
parent 975dcad918
commit 012bde51cb
2 changed files with 427 additions and 0 deletions
+294
View File
@@ -47,6 +47,7 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
load_info,
update_chunk_file_indices,
write_info,
write_stats,
@@ -1774,3 +1775,296 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def trim_episodes_by_frames(
dataset: LeRobotDataset,
episode_frames_to_keep: dict[int, list[int]],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim multiple episodes to keep only specific frames.
This function creates a new dataset where the specified episodes contain only
the frames at the given indices. All other episodes are copied as-is.
Args:
dataset: The source LeRobotDataset.
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
Returns:
A new LeRobotDataset with the trimmed episodes.
"""
if not episode_frames_to_keep:
raise ValueError("No episodes to trim")
for ep_idx in episode_frames_to_keep:
if ep_idx >= dataset.meta.total_episodes:
raise ValueError(f"Episode {ep_idx} does not exist")
if not episode_frames_to_keep[ep_idx]:
raise ValueError(f"No frames to keep for episode {ep_idx}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_trimmed"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
# Create new metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Build set of all frames to keep (for episodes being trimmed)
# and compute new frame counts per episode
all_keep_frames: set[int] = set()
trimmed_frame_counts: dict[int, int] = {}
for ep_idx, frames in episode_frames_to_keep.items():
all_keep_frames.update(frames)
trimmed_frame_counts[ep_idx] = len(frames)
# Copy and filter data
_copy_and_reindex_data_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep, all_keep_frames
)
# Handle videos if present
if dataset.meta.video_keys:
_copy_and_reindex_videos_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep
)
# Copy episode metadata
_copy_and_reindex_episodes_metadata_for_multi_trim(
dataset, new_meta, trimmed_frame_counts
)
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
# Return the metadata instead of trying to load as LeRobotDataset
# This avoids Hub validation issues when the repo doesn't exist yet
return new_meta
# Keep old function for backward compatibility
def trim_episode_by_frames(
dataset: LeRobotDataset,
episode_index: int,
keep_frame_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
return trim_episodes_by_frames(
dataset,
episode_frames_to_keep={episode_index: keep_frame_indices},
output_dir=output_dir,
repo_id=repo_id,
)
def _copy_and_reindex_data_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
all_keep_frames: set[int],
) -> None:
"""Copy data files with frame-level filtering for multiple episodes."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Copy tasks
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
# Tasks are stored with task string as index
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
# Get all parquet files
data_dir = src_dataset.root / "data"
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
trim_episode_set = set(episode_frames_to_keep.keys())
global_index = 0
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(parquet_path)
# Filter: keep all frames from non-trimmed episodes,
# and only specified frames from trimmed episodes
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
df = df[mask].copy().reset_index(drop=True)
if len(df) == 0:
continue
# Reindex
df["index"] = range(global_index, global_index + len(df))
# Recalculate frame_index within each episode
for ep_idx in df["episode_index"].unique():
ep_mask = df["episode_index"] == ep_idx
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
# Recalculate timestamps based on frame_index and fps
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
# Determine output path (keep same structure)
rel_path = parquet_path.relative_to(src_dataset.root)
dst_path = dst_meta.root / rel_path
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta)
global_index += len(df)
def _copy_and_reindex_videos_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
) -> None:
"""Copy video files for trimmed dataset.
In v3.0 datasets, multiple episodes are concatenated into single video files.
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
For trimming, we copy the original video files as-is and update the metadata
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
"""
for video_key in src_dataset.meta.video_keys:
video_dir = src_dataset.root / "videos" / video_key
dst_video_dir = dst_meta.root / "videos" / video_key
if not video_dir.exists():
logging.warning(f"Video directory not found: {video_dir}")
continue
# Copy all video files (they contain concatenated episodes)
# The metadata timestamps will handle which portions to use
copied_files = set()
for chunk_dir in video_dir.glob("chunk-*"):
dst_chunk_dir = dst_video_dir / chunk_dir.name
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
for video_file in chunk_dir.glob("*.mp4"):
if video_file.name not in copied_files:
dst_path = dst_chunk_dir / video_file.name
if not dst_path.exists():
shutil.copy(video_file, dst_path)
copied_files.add(video_file.name)
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
def _trim_video_frames(
src_path: Path,
dst_path: Path,
keep_frame_indices: list[int],
fps: float,
episode_start_idx: int,
) -> None:
"""Trim a video to keep only specific frames using ffmpeg."""
import subprocess
# Convert global indices to local indices within the episode
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
if not local_indices:
logging.warning(f"No frames to keep for video {src_path}")
return
# Calculate start and end times
start_frame = local_indices[0]
end_frame = local_indices[-1]
start_time = start_frame / fps
duration = (end_frame - start_frame + 1) / fps
# Use ffmpeg to trim
cmd = [
"ffmpeg", "-y",
"-ss", str(start_time),
"-i", str(src_path),
"-t", str(duration),
"-c", "copy", # Fast copy without re-encoding
str(dst_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
logging.error(f"Failed to trim video: {e.stderr.decode()}")
# Fallback: copy the whole video
shutil.copy(src_path, dst_path)
def _copy_and_reindex_episodes_metadata_for_multi_trim(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
trimmed_frame_counts: dict[int, int],
) -> None:
"""Copy and update episode metadata for trimmed dataset."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Calculate new frame counts and indices
episodes_data = []
global_idx = 0
for old_ep_idx in range(src_dataset.meta.total_episodes):
src_ep = src_dataset.meta.episodes[old_ep_idx]
if old_ep_idx in trimmed_frame_counts:
ep_length = trimmed_frame_counts[old_ep_idx]
else:
ep_length = src_ep["length"]
ep_data = {
"episode_index": old_ep_idx,
"tasks": src_ep["tasks"],
"length": ep_length,
"data/chunk_index": src_ep["data/chunk_index"],
"data/file_index": src_ep["data/file_index"],
"dataset_from_index": global_idx,
"dataset_to_index": global_idx + ep_length,
}
# Copy video metadata - preserve timestamps for concatenated videos
for video_key in src_dataset.meta.video_keys:
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
# Keep original from_timestamp (start position in concatenated video)
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
# For trimmed episodes, update to_timestamp based on new length
# For non-trimmed episodes, keep original to_timestamp
if old_ep_idx in trimmed_frame_counts:
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
else:
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
ep_data["meta/episodes/chunk_index"] = 0
ep_data["meta/episodes/file_index"] = 0
episodes_data.append(ep_data)
global_idx += ep_length
# Save episodes metadata
df = pd.DataFrame(episodes_data)
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
episodes_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(episodes_path)
# Update info.json
info = load_info(src_dataset.root)
info["total_episodes"] = len(episodes_data)
info["total_frames"] = global_idx
write_info(info, dst_meta.root)
+133
View File
@@ -104,6 +104,28 @@ Convert image dataset to video format and push to hub:
--operation.type convert_image_to_video \
--push_to_hub true
Trim single episode to keep only frames within timestamp range:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_trimmed \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--operation.end_timestamp 30.0
Trim multiple episodes at once (use null for no limit):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
Trim and re-upload to same repo (overwrites original):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
@@ -204,9 +226,32 @@ class InfoConfig(OperationConfig):
show_features: bool = False
@dataclass
class TrimEpisodeConfig:
"""Trim episodes to keep only frames within timestamp ranges.
Supports multiple episodes via episode_trims dict:
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
Or single episode via legacy parameters:
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
"""
type: str = "trim_episode"
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
episode_trims: dict[str, list[float | None]] | None = None
# Legacy single-episode parameters (used if episode_trims is None)
episode_index: int | None = None
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
@@ -351,6 +396,92 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
"""Trim episodes to keep only frames within timestamp ranges."""
if not isinstance(cfg.operation, TrimEpisodeConfig):
raise ValueError("Operation config must be TrimEpisodeConfig")
# Parse episode trims - support both multi-episode dict and legacy single episode
episode_trims: dict[int, tuple[float | None, float | None]] = {}
if cfg.operation.episode_trims is not None:
# Multi-episode mode
for ep_str, ts_range in cfg.operation.episode_trims.items():
ep_idx = int(ep_str)
start_ts = ts_range[0] if len(ts_range) > 0 else None
end_ts = ts_range[1] if len(ts_range) > 1 else None
episode_trims[ep_idx] = (start_ts, end_ts)
elif cfg.operation.episode_index is not None:
# Legacy single-episode mode
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
episode_trims[cfg.operation.episode_index] = (
cfg.operation.start_timestamp,
cfg.operation.end_timestamp,
)
else:
raise ValueError("Either episode_trims or episode_index must be specified")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
# Get episode boundaries and find frames to keep for each episode
episodes_info = dataset.meta.episodes
all_frames_to_keep: dict[int, list[int]] = {}
for ep_idx, (start_ts, end_ts) in episode_trims.items():
if ep_idx >= len(episodes_info["episode_index"]):
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
from_frame = episodes_info["dataset_from_index"][ep_idx]
to_frame = episodes_info["dataset_to_index"][ep_idx]
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
# Find frames within timestamp range
frames_to_keep = []
for frame_idx in range(from_frame, to_frame):
frame = dataset.hf_dataset[frame_idx]
ts = frame["timestamp"]
in_range = True
if start_ts is not None and ts < start_ts:
in_range = False
if end_ts is not None and ts > end_ts:
in_range = False
if in_range:
frames_to_keep.append(frame_idx)
if not frames_to_keep:
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
all_frames_to_keep[ep_idx] = frames_to_keep
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
new_dataset = trim_episodes_by_frames(
dataset,
episode_frames_to_keep=all_frames_to_keep,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ModifyTasksConfig):
raise ValueError("Operation config must be ModifyTasksConfig")
@@ -515,6 +646,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "trim_episode":
handle_trim_episode(cfg)
elif operation_type == "info":
handle_info(cfg)
else: