feat(dataset-tools): add dataset utilities and example script

- Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets.
- Added an example script demonstrating the usage of these utilities.
- Implemented comprehensive tests for all new functionalities to ensure reliability and correctness.
This commit is contained in:
Michel Aractingi
2025-08-13 01:45:49 +02:00
parent 267a753eda
commit 867174c8bc
3 changed files with 1456 additions and 0 deletions
+111
View File
@@ -0,0 +1,111 @@
#!/usr/bin/env python
"""
Example script demonstrating dataset tools utilities.
This script shows how to:
1. Delete episodes from a dataset
2. Split a dataset into train/val sets
3. Add/remove features
4. Merge datasets
Usage:
python examples/use_dataset_tools.py
"""
import numpy as np
from lerobot.datasets.dataset_tools import (
add_feature,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
# Load an existing dataset (replace with your dataset)
dataset = LeRobotDataset("lerobot/pusht")
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
print(f"Features: {list(dataset.meta.features.keys())}")
# Example 1: Delete episodes
print("\n1. Deleting episodes 0 and 2...")
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="pusht_filtered")
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
# Example 2: Split dataset
print("\n2. Splitting dataset into train/val...")
splits = split_dataset(
dataset,
splits={"train": 0.8, "val": 0.2},
)
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
# Example 3: Add a feature
print("\n3. Adding a reward feature...")
# Method 1: Pre-computed values
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
dataset_with_reward = add_feature(
dataset,
feature_name="reward",
feature_values=reward_values,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="pusht_with_reward",
)
# Method 2: Using a callable
def compute_success(frame_dict, episode_idx, frame_idx):
# Example: mark last 10 frames of each episode as successful
episode_length = 10 # You'd get this from episode metadata
return float(frame_idx >= episode_length - 10)
dataset_with_success = add_feature(
dataset_with_reward,
feature_name="success",
feature_values=compute_success,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="pusht_with_reward_and_success",
)
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
# Example 4: Remove features
print("\n4. Removing the success feature...")
dataset_cleaned = remove_feature(dataset_with_success, feature_names="success", repo_id="pusht_cleaned")
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
# Example 5: Merge datasets
print("\n5. Merging train and val splits back together...")
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="pusht_merged")
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
# Example 6: Complex workflow
print("\n6. Complex workflow example...")
# Remove a camera if dataset has multiple
if len(dataset.meta.camera_keys) > 1:
camera_to_remove = dataset.meta.camera_keys[0]
print(f"Removing camera: {camera_to_remove}")
dataset_no_cam = remove_feature(
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
)
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
if __name__ == "__main__":
main()
+761
View File
@@ -0,0 +1,761 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset tools utilities for LeRobotDataset.
This module provides utilities for:
- Deleting episodes from datasets
- Splitting datasets into multiple smaller datasets
- Adding/removing features from datasets
- Merging datasets (wrapper around aggregate functionality)
"""
import logging
import shutil
from collections.abc import Callable
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_parquet_file_size_in_mb,
get_video_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
def delete_episodes(
dataset: LeRobotDataset,
episode_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Delete episodes from a LeRobotDataset and create a new dataset.
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_filtered" to original.
Returns:
LeRobotDataset: New dataset with episodes removed.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
# Validate episode indices
valid_indices = set(range(dataset.meta.total_episodes))
invalid = set(episode_indices) - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
logging.info(f"Deleting {len(episode_indices)} episodes from dataset")
# Create new dataset metadata
if repo_id is None:
repo_id = f"{dataset.repo_id}_filtered"
if output_dir is None:
output_dir = HF_LEROBOT_HOME / repo_id
else:
output_dir = Path(output_dir)
# Get episodes to keep
episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices]
if not episodes_to_keep:
raise ValueError("Cannot delete all episodes from dataset")
# Create new dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Process episodes
episode_mapping = {} # old_idx -> new_idx
new_episode_idx = 0
for old_idx in tqdm(episodes_to_keep, desc="Processing episodes"):
episode_mapping[old_idx] = new_episode_idx
new_episode_idx += 1
# Copy data files and update indices
_copy_and_reindex_data(dataset, new_meta, episode_mapping)
# Copy video files if present
if dataset.meta.video_keys:
_copy_and_reindex_videos(dataset, new_meta, episode_mapping)
# Create new dataset instance
new_dataset = LeRobotDataset(
repo_id=repo_id,
root=output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)
logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes")
return new_dataset
def split_dataset(
dataset: LeRobotDataset,
splits: dict[str, list[int]] | dict[str, float],
output_dir: str | Path | None = None,
) -> dict[str, LeRobotDataset]:
"""Split a LeRobotDataset into multiple smaller datasets.
Args:
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Base directory for output datasets. If None, uses default location.
Returns:
dict[str, LeRobotDataset]: Dictionary mapping split names to new datasets.
Examples:
# Split by specific episodes
splits = {"train": [0, 1, 2], "val": [3, 4]}
datasets = split_dataset(dataset, splits)
# Split by fractions
splits = {"train": 0.8, "val": 0.2}
datasets = split_dataset(dataset, splits)
"""
if not splits:
raise ValueError("No splits provided")
# Convert fractions to episode indices if needed
if all(isinstance(v, float) for v in splits.values()):
splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits)
# Validate episodes
all_episodes = set()
for split_name, episodes in splits.items():
if not episodes:
raise ValueError(f"Split '{split_name}' has no episodes")
episode_set = set(episodes)
if episode_set & all_episodes:
raise ValueError("Episodes cannot appear in multiple splits")
all_episodes.update(episode_set)
# Validate all episodes are valid
valid_indices = set(range(dataset.meta.total_episodes))
invalid = all_episodes - valid_indices
if invalid:
raise ValueError(f"Invalid episode indices: {invalid}")
if output_dir is None:
output_dir = HF_LEROBOT_HOME
else:
output_dir = Path(output_dir)
result_datasets = {}
for split_name, episodes in splits.items():
logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes")
# Create repo_id for split
split_repo_id = f"{dataset.repo_id}_{split_name}"
split_output_dir = output_dir / split_repo_id
# Create episode mapping
episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))}
# Create new dataset metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=split_repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=split_output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Copy data and videos
_copy_and_reindex_data(dataset, new_meta, episode_mapping)
if dataset.meta.video_keys:
_copy_and_reindex_videos(dataset, new_meta, episode_mapping)
# Create new dataset instance
new_dataset = LeRobotDataset(
repo_id=split_repo_id,
root=split_output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)
result_datasets[split_name] = new_dataset
return result_datasets
def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
This is a wrapper around the aggregate_datasets functionality with a cleaner API.
Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
Returns:
LeRobotDataset: The merged dataset.
"""
if not datasets:
raise ValueError("No datasets to merge")
if output_dir is None:
output_dir = HF_LEROBOT_HOME / output_repo_id
else:
output_dir = Path(output_dir)
# Extract repo_ids and roots
repo_ids = [ds.repo_id for ds in datasets]
roots = [ds.root for ds in datasets]
# Call aggregate_datasets
aggregate_datasets(
repo_ids=repo_ids,
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
)
# Create and return the merged dataset
merged_dataset = LeRobotDataset(
repo_id=output_repo_id,
root=output_dir,
image_transforms=datasets[0].image_transforms,
delta_timestamps=datasets[0].delta_timestamps,
tolerance_s=datasets[0].tolerance_s,
)
return merged_dataset
def add_feature(
dataset: LeRobotDataset,
feature_name: str,
feature_values: np.ndarray | torch.Tensor | Callable,
feature_info: dict,
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Add a new feature to a LeRobotDataset.
Args:
dataset: The source LeRobotDataset.
feature_name: Name of the new feature.
feature_values: Either:
- Array/tensor of shape (num_frames, ...) with values for each frame
- Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
feature_info: Dictionary with feature metadata (dtype, shape, names).
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
LeRobotDataset: New dataset with the added feature.
"""
if feature_name in dataset.meta.features:
raise ValueError(f"Feature '{feature_name}' already exists in dataset")
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
if output_dir is None:
output_dir = HF_LEROBOT_HOME / repo_id
else:
output_dir = Path(output_dir)
# Validate feature_info
required_keys = {"dtype", "shape"}
if not required_keys.issubset(feature_info.keys()):
raise ValueError(f"feature_info must contain keys: {required_keys}")
# Create new features dict
new_features = dataset.meta.features.copy()
new_features[feature_name] = feature_info
# Create new dataset metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Process data with new feature
_copy_data_with_feature_changes(
dataset=dataset,
new_meta=new_meta,
add_features={feature_name: (feature_values, feature_info)},
)
# Copy videos if present
if dataset.meta.video_keys:
_copy_videos(dataset, new_meta)
# Create new dataset instance
new_dataset = LeRobotDataset(
repo_id=repo_id,
root=output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)
return new_dataset
def remove_feature(
dataset: LeRobotDataset,
feature_names: str | list[str],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Remove features from a LeRobotDataset.
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
LeRobotDataset: New dataset with features removed.
"""
if isinstance(feature_names, str):
feature_names = [feature_names]
# Validate features exist
for name in feature_names:
if name not in dataset.meta.features:
raise ValueError(f"Feature '{name}' not found in dataset")
# Check if trying to remove required features
required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
if any(name in required_features for name in feature_names):
raise ValueError(f"Cannot remove required features: {required_features}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
if output_dir is None:
output_dir = HF_LEROBOT_HOME / repo_id
else:
output_dir = Path(output_dir)
# Create new features dict
new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}
# Check if removing video features
video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]
# Check if videos will remain after removal
remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
# Create new dataset metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(remaining_video_keys) > 0,
)
# Process data with removed features
_copy_data_with_feature_changes(
dataset=dataset,
new_meta=new_meta,
remove_features=feature_names,
)
# Copy videos (excluding removed ones)
if new_meta.video_keys:
_copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)
# Create new dataset instance
new_dataset = LeRobotDataset(
repo_id=repo_id,
root=output_dir,
image_transforms=dataset.image_transforms,
delta_timestamps=dataset.delta_timestamps,
tolerance_s=dataset.tolerance_s,
)
return new_dataset
# Helper functions
def _fractions_to_episode_indices(
total_episodes: int,
splits: dict[str, float],
) -> dict[str, list[int]]:
"""Convert split fractions to episode indices."""
if sum(splits.values()) > 1.0:
raise ValueError("Split fractions must sum to <= 1.0")
indices = list(range(total_episodes))
result = {}
start_idx = 0
for split_name, fraction in splits.items():
num_episodes = int(total_episodes * fraction)
end_idx = start_idx + num_episodes
if split_name == list(splits.keys())[-1]: # Last split gets remaining episodes
end_idx = total_episodes
result[split_name] = indices[start_idx:end_idx]
start_idx = end_idx
return result
def _copy_and_reindex_data(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
) -> None:
"""Copy data files and reindex episodes."""
# Get unique data files from episodes to keep
file_paths = set()
for old_idx in episode_mapping:
file_paths.add(src_dataset.meta.get_data_file_path(old_idx))
# Track global index
global_index = 0
chunk_idx, file_idx = 0, 0
# Process each data file
for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
df = pd.read_parquet(src_dataset.root / src_path)
# Filter to keep only mapped episodes
mask = df["episode_index"].isin(episode_mapping.keys())
df = df[mask].copy()
if len(df) == 0:
continue
# Update episode indices
df["episode_index"] = df["episode_index"].map(episode_mapping)
# Update global index to be continuous
df["index"] = range(global_index, global_index + len(df))
global_index += len(df)
# Update task indices if needed
if dst_meta.tasks is None:
# Get unique tasks from filtered data
task_indices = df["task_index"].unique()
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in task_indices]
dst_meta.save_episode_tasks(list(set(tasks)))
# Remap task indices
task_mapping = {}
for old_task_idx in df["task_index"].unique():
task_name = src_dataset.meta.tasks.iloc[old_task_idx].name
new_task_idx = dst_meta.get_task_index(task_name)
task_mapping[old_task_idx] = new_task_idx
df["task_index"] = df["task_index"].map(task_mapping)
# Save processed data
chunk_idx, file_idx = _save_data_chunk(df, dst_meta, chunk_idx, file_idx)
# Process episodes metadata
_copy_and_reindex_episodes_metadata(src_dataset, dst_meta, episode_mapping)
def _copy_and_reindex_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
) -> None:
"""Copy video files and update metadata."""
for video_key in src_dataset.meta.video_keys:
video_files = set()
for old_idx in episode_mapping:
video_files.add(src_dataset.meta.get_video_file_path(old_idx, video_key))
chunk_idx, file_idx = 0, 0
for src_path in tqdm(sorted(video_files), desc=f"Processing {video_key} videos"):
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=video_key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
# For simplicity, copy entire video files
# In production, you might want to extract only relevant segments
shutil.copy(src_dataset.root / src_path, dst_path)
# Update indices for next file
file_size = get_video_size_in_mb(dst_path)
if file_size >= DEFAULT_VIDEO_FILE_SIZE_IN_MB * 0.9: # 90% threshold
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
def _copy_and_reindex_episodes_metadata(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
) -> None:
"""Copy and reindex episodes metadata."""
all_stats = []
frame_offset = 0
for old_idx, new_idx in tqdm(
sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata"
):
# Get episode from source
src_episode = src_dataset.meta.episodes[old_idx]
# Create episode dict
episode_dict = {
"episode_index": new_idx,
"tasks": src_episode["tasks"], # Already a list of task names
"length": src_episode["length"],
}
# Copy other metadata
episode_metadata = {
"data/chunk_index": 0, # Will be recalculated when saving
"data/file_index": 0, # Will be recalculated when saving
"dataset_from_index": frame_offset,
"dataset_to_index": frame_offset + src_episode["length"],
}
# Update frame offset for next episode
frame_offset += src_episode["length"]
# Copy stats metadata
for key in src_episode.keys():
if key.startswith("stats/"):
episode_dict[key] = src_episode[key]
# Add episode metadata
stats_dict = {
key.replace("stats/", ""): value
for key, value in episode_dict.items()
if key.startswith("stats/")
}
all_stats.append(stats_dict)
# Calculate stats from dict
episode_stats = {}
for key in dst_meta.features:
if key in stats_dict:
episode_stats[key] = stats_dict[key]
dst_meta.save_episode(
new_idx, episode_dict["length"], episode_dict["tasks"], episode_stats, episode_metadata
)
# Aggregate all stats
if all_stats:
aggregated_stats = aggregate_stats(all_stats)
write_stats(aggregated_stats, dst_meta.root)
def _save_data_chunk(
df: pd.DataFrame,
meta: LeRobotDatasetMetadata,
chunk_idx: int = 0,
file_idx: int = 0,
) -> tuple[int, int]:
"""Save a data chunk and return updated indices."""
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(meta.image_keys) > 0:
to_parquet_with_hf_images(df, path)
else:
df.to_parquet(path)
# Check if we need to rotate files
file_size = get_parquet_file_size_in_mb(path)
if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: # 90% threshold
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
return chunk_idx, file_idx
def _copy_data_with_feature_changes(
dataset: LeRobotDataset,
new_meta: LeRobotDatasetMetadata,
add_features: dict[str, tuple] | None = None,
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
# Get all unique data files
file_paths = set()
for ep_idx in range(dataset.meta.total_episodes):
file_paths.add(dataset.meta.get_data_file_path(ep_idx))
frame_idx = 0
# Process each data file
for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path)
# Remove features
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
# Add features
if add_features:
for feature_name, (values, _) in add_features.items():
if callable(values):
# Compute values for each frame
feature_values = []
for _, row in df.iterrows():
ep_idx = row["episode_index"]
frame_in_ep = row["frame_index"]
value = values(row.to_dict(), ep_idx, frame_in_ep)
# Convert numpy arrays to scalars for single-element arrays
if isinstance(value, np.ndarray) and value.size == 1:
value = value.item()
feature_values.append(value)
df[feature_name] = feature_values
else:
# Use provided values
end_idx = frame_idx + len(df)
# Convert to list to ensure proper shape handling
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
# Flatten single-element arrays to scalars for pandas
df[feature_name] = feature_slice.flatten()
else:
df[feature_name] = feature_slice
frame_idx = end_idx
# Save chunk
_save_data_chunk(df, new_meta)
# Copy episodes metadata and update stats
_copy_episodes_metadata_and_stats(dataset, new_meta)
def _copy_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
exclude_keys: list[str] | None = None,
) -> None:
"""Copy video files, optionally excluding certain keys."""
if exclude_keys is None:
exclude_keys = []
for video_key in src_dataset.meta.video_keys:
if video_key in exclude_keys:
continue
# Get all video files for this key
video_files = set()
for ep_idx in range(src_dataset.meta.total_episodes):
video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key))
# Copy video files
for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"):
# Maintain same structure
rel_path = src_path.relative_to(src_dataset.root)
dst_path = dst_meta.root / rel_path
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src_dataset.root / src_path, dst_path)
def _copy_episodes_metadata_and_stats(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
) -> None:
"""Copy episodes metadata and recalculate stats."""
# Copy tasks
if src_dataset.meta.tasks is not None:
write_tasks(src_dataset.meta.tasks, dst_meta.root)
dst_meta.tasks = src_dataset.meta.tasks.copy()
# Copy episodes metadata files
episodes_dir = src_dataset.root / "meta/episodes"
dst_episodes_dir = dst_meta.root / "meta/episodes"
if episodes_dir.exists():
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
# Update info
dst_meta.info.update(
{
"total_episodes": src_dataset.meta.total_episodes,
"total_frames": src_dataset.meta.total_frames,
"total_tasks": src_dataset.meta.total_tasks,
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
}
)
# Update video info if needed
if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys:
if key in src_dataset.meta.features:
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
"info", {}
)
write_info(dst_meta.info, dst_meta.root)
# Recalculate stats if features changed
if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()):
# Need to recalculate stats
logging.info("Recalculating dataset statistics...")
# This is a simplified version - in production you'd want to properly recalculate
if src_dataset.meta.stats:
new_stats = {}
for key in dst_meta.features:
if key in src_dataset.meta.stats:
new_stats[key] = src_dataset.meta.stats[key]
write_stats(new_stats, dst_meta.root)
else:
# Copy existing stats
if src_dataset.meta.stats:
write_stats(src_dataset.meta.stats, dst_meta.root)
+584
View File
@@ -0,0 +1,584 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset tools utilities."""
from unittest.mock import patch
import numpy as np
import pytest
import torch
from lerobot.datasets.dataset_tools import (
add_feature,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
"""Create a sample dataset for testing."""
# Create an empty dataset and add data manually
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset",
features=features,
)
# Add episodes manually
for ep_idx in range(5):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset.add_frame(frame)
dataset.save_episode()
return dataset
class TestDeleteEpisodes:
def test_delete_single_episode(self, sample_dataset, tmp_path):
"""Test deleting a single episode."""
output_dir = tmp_path / "filtered"
# Delete episode 2
# Mock the revision check and snapshot_download to prevent Hub calls
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[2],
output_dir=output_dir,
)
# Check results
assert new_dataset.meta.total_episodes == 4
assert new_dataset.meta.total_frames == 40
# Check episode indices are renumbered
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2, 3}
# Check data integrity
assert len(new_dataset) == 40
def test_delete_multiple_episodes(self, sample_dataset, tmp_path):
"""Test deleting multiple episodes."""
output_dir = tmp_path / "filtered"
# Delete episodes 1 and 3
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[1, 3],
output_dir=output_dir,
)
# Check results
assert new_dataset.meta.total_episodes == 3
assert new_dataset.meta.total_frames == 30
# Check episode indices
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2}
def test_delete_invalid_episodes(self, sample_dataset, tmp_path):
"""Test error handling for invalid episode indices."""
with pytest.raises(ValueError, match="Invalid episode indices"):
delete_episodes(
sample_dataset,
episode_indices=[10, 20], # Out of range
output_dir=tmp_path / "filtered",
)
def test_delete_all_episodes(self, sample_dataset, tmp_path):
"""Test error when trying to delete all episodes."""
with pytest.raises(ValueError, match="Cannot delete all episodes"):
delete_episodes(
sample_dataset,
episode_indices=list(range(5)), # All episodes
output_dir=tmp_path / "filtered",
)
def test_delete_empty_list(self, sample_dataset, tmp_path):
"""Test error when no episodes specified."""
with pytest.raises(ValueError, match="No episodes to delete"):
delete_episodes(
sample_dataset,
episode_indices=[],
output_dir=tmp_path / "filtered",
)
class TestSplitDataset:
def test_split_by_episodes(self, sample_dataset, tmp_path):
"""Test splitting dataset by specific episode indices."""
splits = {
"train": [0, 1, 2],
"val": [3, 4],
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
# Mock snapshot_download to return the appropriate directory for each split
def mock_snapshot(repo_id, **kwargs):
if "train" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_train")
elif "val" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_val")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
# Check we got both splits
assert set(result.keys()) == {"train", "val"}
# Check train split
assert result["train"].meta.total_episodes == 3
assert result["train"].meta.total_frames == 30
# Check val split
assert result["val"].meta.total_episodes == 2
assert result["val"].meta.total_frames == 20
# Check episode renumbering
train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
assert train_episodes == {0, 1, 2}
val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
assert val_episodes == {0, 1}
def test_split_by_fractions(self, sample_dataset, tmp_path):
"""Test splitting dataset by fractions."""
splits = {
"train": 0.6, # 3 episodes
"val": 0.4, # 2 episodes
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
# Check splits
assert result["train"].meta.total_episodes == 3
assert result["val"].meta.total_episodes == 2
def test_split_overlapping_episodes(self, sample_dataset, tmp_path):
"""Test error when episodes appear in multiple splits."""
splits = {
"train": [0, 1, 2],
"val": [2, 3, 4], # Episode 2 appears in both
}
with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_invalid_fractions(self, sample_dataset, tmp_path):
"""Test error when fractions sum to more than 1."""
splits = {
"train": 0.7,
"val": 0.5, # Sum = 1.2
}
with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_empty(self, sample_dataset, tmp_path):
"""Test error with empty splits."""
with pytest.raises(ValueError, match="No splits provided"):
split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
class TestMergeDatasets:
def test_merge_two_datasets(self, sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test merging two datasets."""
# Create a second dataset manually
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset2 = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset2",
features=features,
)
# Add 3 episodes
for ep_idx in range(3):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset2.add_frame(frame)
dataset2.save_episode()
# Merge datasets
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
[sample_dataset, dataset2],
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
# Check results
assert merged.meta.total_episodes == 8 # 5 + 3
assert merged.meta.total_frames == 80 # 50 + 30
# Check episode indices are sequential
episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
assert episode_indices == list(range(8))
def test_merge_empty_list(self, tmp_path):
"""Test error when merging empty list."""
with pytest.raises(ValueError, match="No datasets to merge"):
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
class TestAddFeature:
def test_add_feature_with_values(self, sample_dataset, tmp_path):
"""Test adding a feature with pre-computed values."""
# Create reward values for all frames
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
# Check feature was added
assert "reward" in new_dataset.meta.features
assert new_dataset.meta.features["reward"] == feature_info
# Check values
assert len(new_dataset) == num_frames
sample_item = new_dataset[0]
assert "reward" in sample_item
# Scalar features don't have shape, just check it's a tensor
assert isinstance(sample_item["reward"], torch.Tensor)
def test_add_feature_with_callable(self, sample_dataset, tmp_path):
"""Test adding a feature with a callable."""
def compute_reward(frame_dict, episode_idx, frame_idx):
# Simple reward based on episode and frame indices
return float(episode_idx * 10 + frame_idx)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=compute_reward,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
# Check feature was added
assert "reward" in new_dataset.meta.features
# Check computed values
# Episode 0, frame 0 should have reward 0
items = [new_dataset[i] for i in range(10)]
first_episode_items = [item for item in items if item["episode_index"] == 0]
assert len(first_episode_items) == 10
# Check first frame of first episode
first_frame = first_episode_items[0]
assert first_frame["frame_index"] == 0
assert float(first_frame["reward"]) == 0.0
def test_add_existing_feature(self, sample_dataset, tmp_path):
"""Test error when adding an existing feature."""
feature_info = {"dtype": "float32", "shape": (1,)}
with pytest.raises(ValueError, match="Feature 'action' already exists"):
add_feature(
sample_dataset,
feature_name="action", # Already exists
feature_values=np.zeros(50),
feature_info=feature_info,
output_dir=tmp_path / "modified",
)
def test_add_feature_invalid_info(self, sample_dataset, tmp_path):
"""Test error with invalid feature info."""
with pytest.raises(ValueError, match="feature_info must contain keys"):
add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.zeros(50),
feature_info={"dtype": "float32"}, # Missing 'shape'
output_dir=tmp_path / "modified",
)
class TestRemoveFeature:
def test_remove_single_feature(self, sample_dataset, tmp_path):
"""Test removing a single feature."""
# First add a feature to remove
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
kwargs.get("local_dir", tmp_path)
)
dataset_with_reward = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
# Now remove it
dataset_without_reward = remove_feature(
dataset_with_reward,
feature_names="reward",
output_dir=tmp_path / "without_reward",
)
# Check feature was removed
assert "reward" not in dataset_without_reward.meta.features
# Check data
sample_item = dataset_without_reward[0]
assert "reward" not in sample_item
def test_remove_multiple_features(self, sample_dataset, tmp_path):
"""Test removing multiple features at once."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
kwargs.get("local_dir", tmp_path)
)
# Add two features
dataset = sample_dataset
for feature_name in ["reward", "success"]:
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
dataset = add_feature(
dataset,
feature_name=feature_name,
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / f"with_{feature_name}",
)
# Remove both
dataset_clean = remove_feature(
dataset,
feature_names=["reward", "success"],
output_dir=tmp_path / "clean",
)
# Check both were removed
assert "reward" not in dataset_clean.meta.features
assert "success" not in dataset_clean.meta.features
def test_remove_nonexistent_feature(self, sample_dataset, tmp_path):
"""Test error when removing non-existent feature."""
with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
remove_feature(
sample_dataset,
feature_names="nonexistent",
output_dir=tmp_path / "modified",
)
def test_remove_required_feature(self, sample_dataset, tmp_path):
"""Test error when trying to remove required features."""
with pytest.raises(ValueError, match="Cannot remove required features"):
remove_feature(
sample_dataset,
feature_names="timestamp", # Required feature
output_dir=tmp_path / "modified",
)
def test_remove_camera_feature(self, sample_dataset, tmp_path):
"""Test removing a camera feature."""
camera_keys = sample_dataset.meta.camera_keys
if not camera_keys:
pytest.skip("No camera keys in dataset")
# Remove first camera
camera_to_remove = camera_keys[0]
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "without_camera")
dataset_without_camera = remove_feature(
sample_dataset,
feature_names=camera_to_remove,
output_dir=tmp_path / "without_camera",
)
# Check camera was removed
assert camera_to_remove not in dataset_without_camera.meta.features
assert camera_to_remove not in dataset_without_camera.meta.camera_keys
# Check data
sample_item = dataset_without_camera[0]
assert camera_to_remove not in sample_item
class TestIntegration:
def test_complex_workflow(self, sample_dataset, tmp_path):
"""Test a complex workflow combining multiple operations."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
kwargs.get("local_dir", tmp_path)
)
# 1. Add a reward feature
dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info={"dtype": "float32", "shape": (1,), "names": None},
output_dir=tmp_path / "step1",
)
# 2. Delete an episode
dataset = delete_episodes(
dataset,
episode_indices=[2],
output_dir=tmp_path / "step2",
)
# 3. Split into train/val
splits = split_dataset(
dataset,
splits={"train": 0.75, "val": 0.25},
output_dir=tmp_path / "step3",
)
# 4. Merge them back
merged = merge_datasets(
list(splits.values()),
output_repo_id="final_dataset",
output_dir=tmp_path / "step4",
)
# Check final dataset
assert merged.meta.total_episodes == 4 # Started with 5, deleted 1
assert merged.meta.total_frames == 40
assert "reward" in merged.meta.features # Feature preserved
# Check data integrity
assert len(merged) == 40
sample_item = merged[0]
assert "reward" in sample_item