mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f658023f1 | |||
| a225127527 |
+6
-5
@@ -66,7 +66,7 @@ dependencies = [
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
|
||||
# Core dependencies
|
||||
"numpy>=2.0.0,<2.3.0", # TODO: upper bound imposed by opencv-python-headless
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
@@ -105,7 +105,7 @@ can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
@@ -130,7 +130,7 @@ reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||
|
||||
@@ -169,6 +169,7 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
@@ -176,8 +177,8 @@ metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# All
|
||||
all = [
|
||||
# Resolver hint: scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||
# multiple extras below (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||
"scipy>=1.14.0,<2.0.0",
|
||||
|
||||
@@ -25,6 +25,7 @@ This module provides utilities for:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
@@ -45,6 +46,8 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
flatten_dict,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
update_chunk_file_indices,
|
||||
@@ -141,6 +144,315 @@ def delete_episodes(
|
||||
return new_dataset
|
||||
|
||||
|
||||
def trim_episode_start(
|
||||
dataset: LeRobotDataset,
|
||||
seconds: float,
|
||||
episode_indices: list[int] | None = None,
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim the first N seconds from selected episodes and create a new dataset.
|
||||
|
||||
The operation rewrites data parquet files and updates episode metadata so that:
|
||||
- frame_index starts at 0 for each trimmed episode
|
||||
- timestamp starts at 0 for each trimmed episode
|
||||
- global index remains contiguous across the full dataset
|
||||
- dataset_from_index / dataset_to_index reflect new frame ranges
|
||||
|
||||
Video files are copied as-is and per-episode video timestamps are shifted forward
|
||||
for trimmed episodes.
|
||||
|
||||
Episodes selected for trimming that are too short (length <= trim_frames) are skipped
|
||||
from the output dataset.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
seconds: Number of seconds to remove from episode starts.
|
||||
episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
"""
|
||||
if seconds <= 0:
|
||||
raise ValueError(f"seconds must be strictly positive, got {seconds}")
|
||||
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.meta.root)
|
||||
|
||||
trim_frames = int(seconds * dataset.meta.fps)
|
||||
if trim_frames <= 0:
|
||||
raise ValueError(
|
||||
f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
|
||||
"Increase seconds so at least one frame is trimmed."
|
||||
)
|
||||
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if len(episode_indices) == 0:
|
||||
raise ValueError("No episodes specified to trim")
|
||||
|
||||
episode_indices = sorted(set(episode_indices))
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_indices) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
too_short = sorted(
|
||||
ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
|
||||
)
|
||||
trim_set = set(episode_indices)
|
||||
skipped_set = set(too_short)
|
||||
trim_set -= skipped_set
|
||||
|
||||
if too_short:
|
||||
logging.warning(
|
||||
f"Skipping {len(too_short)} episode(s) that are too short to trim "
|
||||
f"({trim_frames} frames): {too_short}"
|
||||
)
|
||||
|
||||
episodes_to_keep = [ep_idx for ep_idx in range(dataset.meta.total_episodes) if ep_idx not in skipped_set]
|
||||
if not episodes_to_keep:
|
||||
raise ValueError(
|
||||
"All episodes selected for trimming are too short and would be skipped. "
|
||||
"Try a smaller trim duration."
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
|
||||
f"episode(s) in output"
|
||||
)
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
new_meta.tasks = dataset.meta.tasks.copy()
|
||||
|
||||
subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
|
||||
dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(subtasks_path, dst_subtasks_path)
|
||||
|
||||
episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
|
||||
trim_duration_s = trim_frames / dataset.meta.fps
|
||||
|
||||
episode_lengths: dict[int, int] = {}
|
||||
episode_ranges: dict[int, tuple[int, int]] = {}
|
||||
total_frames = 0
|
||||
for old_ep_idx in episodes_to_keep:
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
|
||||
new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
|
||||
episode_lengths[new_ep_idx] = new_length
|
||||
episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
|
||||
total_frames += new_length
|
||||
|
||||
numeric_features = {
|
||||
k: v
|
||||
for k, v in dataset.meta.features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
}
|
||||
episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
|
||||
episode_file_metadata: dict[int, dict[str, int]] = {}
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Trimming data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
if skipped_set:
|
||||
keep_mask = ~df["episode_index"].isin(skipped_set)
|
||||
if not keep_mask.all():
|
||||
df = df.loc[keep_mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
if trim_set:
|
||||
trim_mask = df["episode_index"].isin(trim_set) & (df["frame_index"] < trim_frames)
|
||||
if trim_mask.any():
|
||||
df = df.loc[~trim_mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
relative_path = src_path.relative_to(dataset.root)
|
||||
chunk_idx = int(relative_path.parts[1].split("-")[1])
|
||||
file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])
|
||||
|
||||
for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
|
||||
ep_mask = df["episode_index"] == old_ep_idx
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trim_set:
|
||||
df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames
|
||||
shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
|
||||
df.loc[ep_mask, "timestamp"] = np.clip(shifted_timestamps, a_min=0.0, a_max=None)
|
||||
|
||||
df.loc[ep_mask, "episode_index"] = new_ep_idx
|
||||
|
||||
ep_start, _ = episode_ranges[new_ep_idx]
|
||||
new_indices = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)
|
||||
df.loc[ep_mask, "index"] = new_indices
|
||||
|
||||
if new_ep_idx in episode_file_metadata:
|
||||
existing = episode_file_metadata[new_ep_idx]
|
||||
if (
|
||||
existing["data/chunk_index"] != chunk_idx
|
||||
or existing["data/file_index"] != file_idx
|
||||
):
|
||||
raise ValueError(
|
||||
f"Episode {old_ep_idx} spans multiple data files. "
|
||||
"trim_episode_start currently expects one data file per episode."
|
||||
)
|
||||
else:
|
||||
episode_file_metadata[new_ep_idx] = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
}
|
||||
|
||||
if numeric_features:
|
||||
ep_df = df.loc[ep_mask]
|
||||
episode_data: dict[str, np.ndarray] = {}
|
||||
episode_feature_spec: dict[str, dict] = {}
|
||||
|
||||
for key, feature in numeric_features.items():
|
||||
if key not in ep_df.columns:
|
||||
continue
|
||||
|
||||
values = ep_df[key].to_numpy()
|
||||
if len(values) == 0:
|
||||
continue
|
||||
|
||||
first_value = values[0]
|
||||
if isinstance(first_value, np.ndarray):
|
||||
episode_data[key] = np.stack(values)
|
||||
elif isinstance(first_value, (list, tuple)):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.asarray(values)
|
||||
|
||||
episode_feature_spec[key] = feature
|
||||
|
||||
if episode_data:
|
||||
episode_stats_parts[new_ep_idx].append(
|
||||
compute_episode_stats(episode_data, episode_feature_spec)
|
||||
)
|
||||
|
||||
df["index"] = df["index"].astype(np.int64)
|
||||
if "frame_index" in df.columns:
|
||||
df["frame_index"] = df["frame_index"].astype(np.int64)
|
||||
|
||||
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_write_parquet(df, dst_path, new_meta)
|
||||
|
||||
all_episode_stats = []
|
||||
for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
|
||||
if new_ep_idx not in episode_file_metadata:
|
||||
raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")
|
||||
|
||||
from_idx, to_idx = episode_ranges[new_ep_idx]
|
||||
src_episode = dataset.meta.episodes[old_ep_idx]
|
||||
ep_data_meta = episode_file_metadata[new_ep_idx]
|
||||
|
||||
stats_parts = episode_stats_parts.get(new_ep_idx, [])
|
||||
ep_stats = aggregate_stats(stats_parts) if len(stats_parts) > 1 else (stats_parts[0] if stats_parts else {})
|
||||
if ep_stats:
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
episode_meta = {
|
||||
"data/chunk_index": ep_data_meta["data/chunk_index"],
|
||||
"data/file_index": ep_data_meta["data/file_index"],
|
||||
"dataset_from_index": from_idx,
|
||||
"dataset_to_index": to_idx,
|
||||
}
|
||||
|
||||
for video_key in dataset.meta.video_keys:
|
||||
from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
|
||||
if old_ep_idx in trim_set:
|
||||
from_ts += trim_duration_s
|
||||
episode_meta.update(
|
||||
{
|
||||
f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
|
||||
f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
|
||||
f"videos/{video_key}/from_timestamp": from_ts,
|
||||
f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
|
||||
}
|
||||
)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": new_ep_idx,
|
||||
"tasks": src_episode["tasks"],
|
||||
"length": episode_lengths[new_ep_idx],
|
||||
}
|
||||
episode_dict.update(episode_meta)
|
||||
if ep_stats:
|
||||
episode_dict.update(flatten_dict({"stats": ep_stats}))
|
||||
|
||||
new_meta._save_episode_metadata(episode_dict)
|
||||
|
||||
new_meta._close_writer()
|
||||
|
||||
if new_meta.video_keys:
|
||||
_copy_videos(dataset, new_meta)
|
||||
|
||||
new_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episodes_to_keep),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episodes_to_keep)}"},
|
||||
}
|
||||
)
|
||||
|
||||
if new_meta.video_keys and dataset.meta.video_keys:
|
||||
for key in new_meta.video_keys:
|
||||
if key in dataset.meta.features:
|
||||
new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in merged_stats:
|
||||
merged_stats[key] = value
|
||||
if merged_stats:
|
||||
write_stats(merged_stats, new_meta.root)
|
||||
|
||||
return LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
|
||||
@@ -117,6 +117,13 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Trim first 3 seconds from all episodes:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trim3s \
|
||||
--operation.type trim_episode_start \
|
||||
--operation.seconds 3.0
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -170,6 +177,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
@@ -215,6 +223,13 @@ class ModifyTasksConfig(OperationConfig):
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("trim_episode_start")
|
||||
@dataclass
|
||||
class TrimEpisodeStartConfig(OperationConfig):
|
||||
seconds: float | None = None
|
||||
episode_indices: list[int] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("convert_image_to_video")
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
@@ -464,6 +479,41 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode_start(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, TrimEpisodeStartConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeStartConfig")
|
||||
|
||||
if cfg.operation.seconds is None:
|
||||
raise ValueError("seconds must be specified for trim_episode_start operation")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(
|
||||
f"Trimming first {cfg.operation.seconds}s from episodes "
|
||||
f"{cfg.operation.episode_indices if cfg.operation.episode_indices else 'ALL'} in {cfg.repo_id}"
|
||||
)
|
||||
new_dataset = trim_episode_start(
|
||||
dataset=dataset,
|
||||
seconds=cfg.operation.seconds,
|
||||
episode_indices=cfg.operation.episode_indices,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
@@ -594,6 +644,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "trim_episode_start":
|
||||
handle_trim_episode_start(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
|
||||
)
|
||||
|
||||
|
||||
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
|
||||
"""Test trimming episode starts updates frame/timestamp/index metadata consistently."""
|
||||
output_dir = tmp_path / "trimmed"
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
trim_frames = int(trim_seconds * sample_dataset.meta.fps)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
expected_length = 10 - trim_frames
|
||||
assert new_dataset.meta.total_episodes == sample_dataset.meta.total_episodes
|
||||
assert new_dataset.meta.total_frames == sample_dataset.meta.total_episodes * expected_length
|
||||
|
||||
indices = [int(i.item()) for i in new_dataset.hf_dataset["index"]]
|
||||
assert indices == list(range(new_dataset.meta.total_frames))
|
||||
|
||||
episode_indices = [int(i.item()) for i in new_dataset.hf_dataset["episode_index"]]
|
||||
frame_indices = [int(i.item()) for i in new_dataset.hf_dataset["frame_index"]]
|
||||
timestamps = [float(i.item()) for i in new_dataset.hf_dataset["timestamp"]]
|
||||
|
||||
for ep_idx in range(sample_dataset.meta.total_episodes):
|
||||
ep_frame_indices = [f for e, f in zip(episode_indices, frame_indices, strict=False) if e == ep_idx]
|
||||
ep_timestamps = [t for e, t in zip(episode_indices, timestamps, strict=False) if e == ep_idx]
|
||||
|
||||
assert len(ep_frame_indices) == expected_length
|
||||
assert ep_frame_indices == list(range(expected_length))
|
||||
assert ep_timestamps[0] == pytest.approx(0.0)
|
||||
assert ep_timestamps[-1] == pytest.approx((expected_length - 1) / sample_dataset.meta.fps)
|
||||
|
||||
ep_meta = new_dataset.meta.episodes[ep_idx]
|
||||
assert int(ep_meta["length"]) == expected_length
|
||||
assert int(ep_meta["dataset_from_index"]) == ep_idx * expected_length
|
||||
assert int(ep_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
|
||||
|
||||
|
||||
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test too-short episodes are skipped and remaining episodes are reindexed."""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)
|
||||
|
||||
for ep_len in [10, 2, 10]:
|
||||
for _ in range(ep_len):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"observation.state": np.random.randn(2).astype(np.float32),
|
||||
"observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
|
||||
"task": "task",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "trimmed")
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
# Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
|
||||
assert new_dataset.meta.total_episodes == 2
|
||||
assert new_dataset.meta.total_frames == 14
|
||||
assert sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) == [0, 1]
|
||||
assert [int(ep["length"]) for ep in new_dataset.meta.episodes] == [7, 7]
|
||||
|
||||
|
||||
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
|
||||
"""Test trimming fails when all selected episodes are too short and would be skipped."""
|
||||
with pytest.raises(ValueError, match="All episodes selected for trimming are too short"):
|
||||
trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=1.0, # 30 frames > 10-frame episodes
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
|
||||
def test_split_by_episodes(sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by specific episode indices."""
|
||||
splits = {
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
_validate_config,
|
||||
TrimEpisodeStartConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,6 +48,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
@@ -77,6 +79,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user