mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f658023f1 |
@@ -25,6 +25,7 @@ This module provides utilities for:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -45,6 +46,8 @@ from lerobot.datasets.utils import (
|
|||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
|
DEFAULT_SUBTASKS_PATH,
|
||||||
|
flatten_dict,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
load_episodes,
|
load_episodes,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
@@ -141,6 +144,315 @@ def delete_episodes(
|
|||||||
return new_dataset
|
return new_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def trim_episode_start(
    dataset: LeRobotDataset,
    seconds: float,
    episode_indices: list[int] | None = None,
    output_dir: str | Path | None = None,
    repo_id: str | None = None,
) -> LeRobotDataset:
    """Trim the first N seconds from selected episodes and create a new dataset.

    The operation rewrites data parquet files and updates episode metadata so that:
    - frame_index starts at 0 for each trimmed episode
    - timestamp starts at 0 for each trimmed episode
    - global index remains contiguous across the full dataset
    - dataset_from_index / dataset_to_index reflect new frame ranges

    Video files are copied as-is and per-episode video timestamps are shifted forward
    for trimmed episodes.

    Episodes selected for trimming that are too short (length <= trim_frames) are skipped
    from the output dataset. Episodes that were not selected are kept unchanged (apart from
    re-indexing).

    Args:
        dataset: The source LeRobotDataset.
        seconds: Number of seconds to remove from episode starts.
        episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
        output_dir: Directory to save the new dataset. If None, uses default location.
        repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.

    Returns:
        A LeRobotDataset loaded from the newly written output directory.

    Raises:
        ValueError: If ``seconds`` is not strictly positive or corresponds to zero frames,
            if ``episode_indices`` is empty or contains indices outside the dataset,
            if every selected episode is too short to be trimmed, if the source has no
            data parquet files, or if an episode spans more than one data file.
    """
    if seconds <= 0:
        raise ValueError(f"seconds must be strictly positive, got {seconds}")

    if dataset.meta.episodes is None:
        dataset.meta.episodes = load_episodes(dataset.meta.root)

    fps = dataset.meta.fps
    total_source_episodes = dataset.meta.total_episodes

    # A bare int() floors float noise: 0.29 * 100 evaluates to 28.999999999999996 and would
    # silently trim 28 frames instead of 29. The tiny epsilon absorbs that representation error.
    trim_frames = int(seconds * fps + 1e-6)
    if trim_frames <= 0:
        raise ValueError(
            f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
            "Increase seconds so at least one frame is trimmed."
        )

    if episode_indices is None:
        episode_indices = list(range(total_source_episodes))

    if not episode_indices:
        raise ValueError("No episodes specified to trim")

    episode_indices = sorted(set(episode_indices))
    invalid = set(episode_indices) - set(range(total_source_episodes))
    if invalid:
        raise ValueError(f"Invalid episode indices: {invalid}")

    # Selected episodes with no frame left after trimming are dropped from the output entirely.
    too_short = sorted(
        ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
    )
    skipped_set = set(too_short)
    trim_set = set(episode_indices) - skipped_set

    if too_short:
        logging.warning(
            f"Skipping {len(too_short)} episode(s) that are too short to trim "
            f"({trim_frames} frames): {too_short}"
        )

    # Reject as soon as nothing is left to trim. Checking the selection (rather than the whole
    # dataset) prevents a partial selection of too-short episodes from silently deleting them
    # while trimming nothing.
    if not trim_set:
        raise ValueError(
            "All episodes selected for trimming are too short and would be skipped. "
            "Try a smaller trim duration."
        )

    episodes_to_keep = [ep_idx for ep_idx in range(total_source_episodes) if ep_idx not in skipped_set]

    logging.info(
        f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
        f"episode(s) in output"
    )

    if repo_id is None:
        repo_id = f"{dataset.repo_id}_trimmed"
    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id

    new_meta = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
        fps=dataset.meta.fps,
        features=dataset.meta.features,
        robot_type=dataset.meta.robot_type,
        root=output_dir,
        use_videos=len(dataset.meta.video_keys) > 0,
        chunks_size=dataset.meta.chunks_size,
        data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
        video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
    )

    if dataset.meta.tasks is not None:
        write_tasks(dataset.meta.tasks, new_meta.root)
        new_meta.tasks = dataset.meta.tasks.copy()

    subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
    if subtasks_path.exists():
        dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
        dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(subtasks_path, dst_subtasks_path)

    # Kept episodes are renumbered contiguously from 0 in their original order.
    episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
    # Duration actually removed: a whole number of frames, which may be slightly less than `seconds`.
    trim_duration_s = trim_frames / fps

    # Pre-compute each output episode's length and its [from, to) range in the global index.
    episode_lengths: dict[int, int] = {}
    episode_ranges: dict[int, tuple[int, int]] = {}
    total_frames = 0
    for old_ep_idx in episodes_to_keep:
        new_ep_idx = episode_mapping[old_ep_idx]
        src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
        new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
        episode_lengths[new_ep_idx] = new_length
        episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
        total_frames += new_length

    # Only these features get their per-episode stats recomputed from the trimmed rows.
    numeric_features = {
        k: v for k, v in dataset.meta.features.items() if v["dtype"] not in ["image", "video", "string"]
    }
    episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
    episode_file_metadata: dict[int, dict[str, int]] = {}

    data_dir = dataset.root / DATA_DIR
    parquet_files = sorted(data_dir.glob("*/*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {data_dir}")

    for src_path in tqdm(parquet_files, desc="Trimming data files"):
        df = pd.read_parquet(src_path).reset_index(drop=True)
        if len(df) == 0:
            continue

        # Drop every row of a skipped episode plus the leading `trim_frames` rows of trimmed ones.
        source_episode_col = df["episode_index"]
        drop_mask = source_episode_col.isin(skipped_set) | (
            source_episode_col.isin(trim_set) & (df["frame_index"] < trim_frames)
        )
        if drop_mask.any():
            df = df.loc[~drop_mask].copy().reset_index(drop=True)
        if len(df) == 0:
            # Nothing survives from this file, so no output file is written for it.
            continue

        # Output files reuse the source chunk/file numbering, parsed from
        # "<data dir>/chunk-XXX/file-YYY.parquet".
        relative_path = src_path.relative_to(dataset.root)
        chunk_idx = int(relative_path.parts[1].split("-")[1])
        file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])

        # Ascending order matters: a new index is never larger than its old index, so rows that
        # were already renumbered can never be matched again by a later (larger) old index.
        for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
            ep_mask = df["episode_index"] == old_ep_idx
            new_ep_idx = episode_mapping[old_ep_idx]

            if old_ep_idx in trim_set:
                df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames

                # Shift in float64 for precision, clamp rounding residue below zero, then cast
                # back to the column dtype so the assignment never needs a lossy upcast.
                timestamp_dtype = df["timestamp"].dtype
                shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
                shifted_timestamps = np.clip(shifted_timestamps, a_min=0.0, a_max=None)
                if isinstance(timestamp_dtype, np.dtype) and np.issubdtype(timestamp_dtype, np.floating):
                    shifted_timestamps = shifted_timestamps.astype(timestamp_dtype)
                df.loc[ep_mask, "timestamp"] = shifted_timestamps

            df.loc[ep_mask, "episode_index"] = new_ep_idx

            # Global index = start of the episode's output range + (already shifted) frame index.
            ep_start, _ = episode_ranges[new_ep_idx]
            df.loc[ep_mask, "index"] = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)

            if new_ep_idx in episode_file_metadata:
                existing = episode_file_metadata[new_ep_idx]
                if existing["data/chunk_index"] != chunk_idx or existing["data/file_index"] != file_idx:
                    raise ValueError(
                        f"Episode {old_ep_idx} spans multiple data files. "
                        "trim_episode_start currently expects one data file per episode."
                    )
            else:
                episode_file_metadata[new_ep_idx] = {
                    "data/chunk_index": chunk_idx,
                    "data/file_index": file_idx,
                }

            if numeric_features:
                episode_data, episode_feature_spec = _trim_collect_numeric_arrays(
                    df.loc[ep_mask], numeric_features
                )
                if episode_data:
                    episode_stats_parts[new_ep_idx].append(
                        compute_episode_stats(episode_data, episode_feature_spec)
                    )

        df["index"] = df["index"].astype(np.int64)
        if "frame_index" in df.columns:
            df["frame_index"] = df["frame_index"].astype(np.int64)

        dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        _write_parquet(df, dst_path, new_meta)

    all_episode_stats = []
    for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
        new_ep_idx = episode_mapping[old_ep_idx]

        if new_ep_idx not in episode_file_metadata:
            raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")

        from_idx, to_idx = episode_ranges[new_ep_idx]
        src_episode = dataset.meta.episodes[old_ep_idx]
        ep_data_meta = episode_file_metadata[new_ep_idx]

        stats_parts = episode_stats_parts.get(new_ep_idx, [])
        if len(stats_parts) > 1:
            ep_stats = aggregate_stats(stats_parts)
        elif stats_parts:
            ep_stats = stats_parts[0]
        else:
            ep_stats = {}
        if ep_stats:
            all_episode_stats.append(ep_stats)

        episode_dict = {
            "episode_index": new_ep_idx,
            "tasks": src_episode["tasks"],
            "length": episode_lengths[new_ep_idx],
            "data/chunk_index": ep_data_meta["data/chunk_index"],
            "data/file_index": ep_data_meta["data/file_index"],
            "dataset_from_index": from_idx,
            "dataset_to_index": to_idx,
        }

        for video_key in dataset.meta.video_keys:
            # The video file itself is untouched; a trimmed episode simply starts later in it.
            from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
            if old_ep_idx in trim_set:
                from_ts += trim_duration_s
            episode_dict.update(
                {
                    f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
                    f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
                    f"videos/{video_key}/from_timestamp": from_ts,
                    f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
                }
            )

        if ep_stats:
            episode_dict.update(flatten_dict({"stats": ep_stats}))

        new_meta._save_episode_metadata(episode_dict)

    new_meta._close_writer()

    if new_meta.video_keys:
        _copy_videos(dataset, new_meta)

    new_meta.info.update(
        {
            "total_episodes": len(episodes_to_keep),
            "total_frames": total_frames,
            "total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
            "splits": {"train": f"0:{len(episodes_to_keep)}"},
        }
    )

    if new_meta.video_keys and dataset.meta.video_keys:
        for key in new_meta.video_keys:
            if key in dataset.meta.features:
                new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})

    write_info(new_meta.info, new_meta.root)

    merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
    if dataset.meta.stats:
        # Features whose stats were not recomputed above keep the source dataset's stats.
        for key, value in dataset.meta.stats.items():
            if key not in merged_stats:
                merged_stats[key] = value
    if merged_stats:
        write_stats(merged_stats, new_meta.root)

    return LeRobotDataset(
        repo_id=repo_id,
        root=output_dir,
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
    )


def _trim_collect_numeric_arrays(
    ep_df: pd.DataFrame, numeric_features: dict[str, dict]
) -> tuple[dict[str, np.ndarray], dict[str, dict]]:
    """Collect one episode's numeric feature columns as arrays for stats computation.

    Args:
        ep_df: Rows of a single episode.
        numeric_features: Feature name -> feature spec for the features to collect.

    Returns:
        A pair ``(episode_data, episode_feature_spec)`` holding, for every requested feature
        that is present in ``ep_df`` and non-empty, its values as one array and its spec.
    """
    episode_data: dict[str, np.ndarray] = {}
    episode_feature_spec: dict[str, dict] = {}

    for key, feature in numeric_features.items():
        if key not in ep_df.columns:
            continue

        values = ep_df[key].to_numpy()
        if len(values) == 0:
            continue

        # Cells holding a sequence (array/list/tuple) are stacked along a new leading axis;
        # scalar cells are converted directly.
        if isinstance(values[0], (np.ndarray, list, tuple)):
            episode_data[key] = np.stack(values)
        else:
            episode_data[key] = np.asarray(values)
        episode_feature_spec[key] = feature

    return episode_data, episode_feature_spec
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
splits: dict[str, float | list[int]],
|
splits: dict[str, float | list[int]],
|
||||||
|
|||||||
@@ -117,6 +117,13 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
|||||||
--operation.new_task "Default task" \
|
--operation.new_task "Default task" \
|
||||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||||
|
|
||||||
|
Trim first 3 seconds from all episodes:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--new_repo_id lerobot/pusht_trim3s \
|
||||||
|
--operation.type trim_episode_start \
|
||||||
|
--operation.seconds 3.0
|
||||||
|
|
||||||
Convert image dataset to video format and save locally:
|
Convert image dataset to video format and save locally:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht_image \
|
--repo_id lerobot/pusht_image \
|
||||||
@@ -170,6 +177,7 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
modify_tasks,
|
modify_tasks,
|
||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
|
trim_episode_start,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
@@ -215,6 +223,13 @@ class ModifyTasksConfig(OperationConfig):
|
|||||||
episode_tasks: dict[str, str] | None = None
|
episode_tasks: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("trim_episode_start")
@dataclass
class TrimEpisodeStartConfig(OperationConfig):
    """Options for the ``trim_episode_start`` edit operation."""

    # Duration to remove from the start of each selected episode.
    # Defaults to None; handle_trim_episode_start raises if it is left unset.
    seconds: float | None = None
    # Episodes to trim; None is forwarded as-is, which trim_episode_start treats as "all episodes".
    episode_indices: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
@OperationConfig.register_subclass("convert_image_to_video")
|
@OperationConfig.register_subclass("convert_image_to_video")
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConvertImageToVideoConfig(OperationConfig):
|
class ConvertImageToVideoConfig(OperationConfig):
|
||||||
@@ -464,6 +479,41 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
|||||||
modified_dataset.push_to_hub()
|
modified_dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_trim_episode_start(cfg: EditDatasetConfig) -> None:
    """Run the ``trim_episode_start`` operation described by ``cfg``.

    Loads the source dataset, trims the configured number of seconds from the start of the
    selected episodes, logs a summary of the result and optionally pushes it to the hub.

    Args:
        cfg: Edit configuration whose ``operation`` must be a ``TrimEpisodeStartConfig``.

    Raises:
        ValueError: If the operation config has the wrong type or ``seconds`` is unset.
    """
    operation = cfg.operation
    if not isinstance(operation, TrimEpisodeStartConfig):
        raise ValueError("Operation config must be TrimEpisodeStartConfig")

    if operation.seconds is None:
        raise ValueError("seconds must be specified for trim_episode_start operation")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
    output_repo_id, output_dir = get_output_path(
        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
    )

    if cfg.new_repo_id is None:
        # In-place edit: read the source from "<root>_old".
        # NOTE(review): presumably get_output_path moved the original data there - confirm.
        dataset.root = Path(str(dataset.root) + "_old")

    # Only None means "all episodes"; an empty list is rejected by trim_episode_start, so it
    # must not be reported as ALL.
    selection = operation.episode_indices if operation.episode_indices is not None else "ALL"
    logging.info(f"Trimming first {operation.seconds}s from episodes {selection} in {cfg.repo_id}")

    new_dataset = trim_episode_start(
        dataset=dataset,
        seconds=operation.seconds,
        episode_indices=operation.episode_indices,
        output_dir=output_dir,
        repo_id=output_repo_id,
    )

    logging.info(f"Dataset saved to {output_dir}")
    logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {output_repo_id}")
        # new_dataset was just loaded from output_dir under output_repo_id, so push it directly
        # instead of instantiating the same dataset a second time.
        new_dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||||
# instead of checking isinstance()
|
# instead of checking isinstance()
|
||||||
@@ -594,6 +644,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
|||||||
handle_remove_feature(cfg)
|
handle_remove_feature(cfg)
|
||||||
elif operation_type == "modify_tasks":
|
elif operation_type == "modify_tasks":
|
||||||
handle_modify_tasks(cfg)
|
handle_modify_tasks(cfg)
|
||||||
|
elif operation_type == "trim_episode_start":
|
||||||
|
handle_trim_episode_start(cfg)
|
||||||
elif operation_type == "convert_image_to_video":
|
elif operation_type == "convert_image_to_video":
|
||||||
handle_convert_image_to_video(cfg)
|
handle_convert_image_to_video(cfg)
|
||||||
elif operation_type == "info":
|
elif operation_type == "info":
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
modify_tasks,
|
modify_tasks,
|
||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
|
trim_episode_start,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||||
|
|
||||||
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
    """Trimming every episode must leave frame_index, timestamp, index and metadata consistent."""
    fps = sample_dataset.meta.fps
    num_episodes = sample_dataset.meta.total_episodes
    trimmed_root = tmp_path / "trimmed"
    seconds_to_trim = 0.1  # 3 frames at 30 FPS
    frames_to_trim = int(seconds_to_trim * fps)

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(trimmed_root)

        new_dataset = trim_episode_start(sample_dataset, seconds=seconds_to_trim, output_dir=trimmed_root)

    # The fixture's episodes are expected to be 10 frames long.
    expected_length = 10 - frames_to_trim
    assert new_dataset.meta.total_episodes == num_episodes
    assert new_dataset.meta.total_frames == num_episodes * expected_length

    # The global index must be contiguous over the whole trimmed dataset.
    global_indices = [int(value.item()) for value in new_dataset.hf_dataset["index"]]
    assert global_indices == list(range(new_dataset.meta.total_frames))

    episode_column = [int(value.item()) for value in new_dataset.hf_dataset["episode_index"]]
    frame_column = [int(value.item()) for value in new_dataset.hf_dataset["frame_index"]]
    timestamp_column = [float(value.item()) for value in new_dataset.hf_dataset["timestamp"]]
    rows = list(zip(episode_column, frame_column, timestamp_column, strict=False))

    for ep_idx in range(num_episodes):
        frames_of_episode = [frame for episode, frame, _ in rows if episode == ep_idx]
        timestamps_of_episode = [stamp for episode, _, stamp in rows if episode == ep_idx]

        # Frame indices and timestamps restart at zero after trimming.
        assert len(frames_of_episode) == expected_length
        assert frames_of_episode == list(range(expected_length))
        assert timestamps_of_episode[0] == pytest.approx(0.0)
        assert timestamps_of_episode[-1] == pytest.approx((expected_length - 1) / fps)

        # Episode metadata must describe the shortened, contiguous frame ranges.
        episode_meta = new_dataset.meta.episodes[ep_idx]
        assert int(episode_meta["length"]) == expected_length
        assert int(episode_meta["dataset_from_index"]) == ep_idx * expected_length
        assert int(episode_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Episodes shorter than the trim are dropped and the remaining ones are renumbered."""
    trimmed_root = tmp_path / "trimmed"
    features = {
        "action": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (2,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)

    def make_frame():
        # Random payload; only the episode lengths matter for this test.
        return {
            "action": np.random.randn(2).astype(np.float32),
            "observation.state": np.random.randn(2).astype(np.float32),
            "observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
            "task": "task",
        }

    # The middle episode (2 frames) is shorter than the 3-frame trim.
    for episode_length in [10, 2, 10]:
        for _ in range(episode_length):
            dataset.add_frame(make_frame())
        dataset.save_episode()
    dataset.finalize()

    trim_seconds = 0.1  # 3 frames at 30 FPS
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(trimmed_root)

        new_dataset = trim_episode_start(dataset, seconds=trim_seconds, output_dir=trimmed_root)

    # Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
    assert new_dataset.meta.total_episodes == 2
    assert new_dataset.meta.total_frames == 14

    remaining_episode_ids = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
    assert sorted(remaining_episode_ids) == [0, 1]

    remaining_lengths = [int(ep["length"]) for ep in new_dataset.meta.episodes]
    assert remaining_lengths == [7, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
    """A trim longer than every selected episode must be rejected with a ValueError."""
    oversized_trim_s = 1.0  # 30 frames > 10-frame episodes
    expected_message = "All episodes selected for trimming are too short"

    with pytest.raises(ValueError, match=expected_message):
        trim_episode_start(sample_dataset, seconds=oversized_trim_s, output_dir=tmp_path / "trimmed")
|
||||||
|
|
||||||
|
|
||||||
def test_split_by_episodes(sample_dataset, tmp_path):
|
def test_split_by_episodes(sample_dataset, tmp_path):
|
||||||
"""Test splitting dataset by specific episode indices."""
|
"""Test splitting dataset by specific episode indices."""
|
||||||
splits = {
|
splits = {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
|||||||
RemoveFeatureConfig,
|
RemoveFeatureConfig,
|
||||||
SplitConfig,
|
SplitConfig,
|
||||||
_validate_config,
|
_validate_config,
|
||||||
|
TrimEpisodeStartConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -47,6 +48,7 @@ class TestOperationTypeParsing:
|
|||||||
("merge", MergeConfig),
|
("merge", MergeConfig),
|
||||||
("remove_feature", RemoveFeatureConfig),
|
("remove_feature", RemoveFeatureConfig),
|
||||||
("modify_tasks", ModifyTasksConfig),
|
("modify_tasks", ModifyTasksConfig),
|
||||||
|
("trim_episode_start", TrimEpisodeStartConfig),
|
||||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||||
("info", InfoConfig),
|
("info", InfoConfig),
|
||||||
],
|
],
|
||||||
@@ -77,6 +79,7 @@ class TestOperationTypeParsing:
|
|||||||
("merge", MergeConfig),
|
("merge", MergeConfig),
|
||||||
("remove_feature", RemoveFeatureConfig),
|
("remove_feature", RemoveFeatureConfig),
|
||||||
("modify_tasks", ModifyTasksConfig),
|
("modify_tasks", ModifyTasksConfig),
|
||||||
|
("trim_episode_start", TrimEpisodeStartConfig),
|
||||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||||
("info", InfoConfig),
|
("info", InfoConfig),
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user