mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(dataset-tools): add dataset utilities and example script
- Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets. - Added an example script demonstrating the usage of these utilities. - Implemented comprehensive tests for all new functionalities to ensure reliability and correctness.
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Example script demonstrating dataset tools utilities.
|
||||
|
||||
This script shows how to:
|
||||
1. Delete episodes from a dataset
|
||||
2. Split a dataset into train/val sets
|
||||
3. Add/remove features
|
||||
4. Merge datasets
|
||||
|
||||
Usage:
|
||||
python examples/use_dataset_tools.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_feature,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
# Load an existing dataset (replace with your dataset)
|
||||
dataset = LeRobotDataset("lerobot/pusht")
|
||||
|
||||
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
|
||||
print(f"Features: {list(dataset.meta.features.keys())}")
|
||||
|
||||
# Example 1: Delete episodes
|
||||
print("\n1. Deleting episodes 0 and 2...")
|
||||
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="pusht_filtered")
|
||||
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
|
||||
|
||||
# Example 2: Split dataset
|
||||
print("\n2. Splitting dataset into train/val...")
|
||||
splits = split_dataset(
|
||||
dataset,
|
||||
splits={"train": 0.8, "val": 0.2},
|
||||
)
|
||||
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
|
||||
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
|
||||
|
||||
# Example 3: Add a feature
|
||||
print("\n3. Adding a reward feature...")
|
||||
|
||||
# Method 1: Pre-computed values
|
||||
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
|
||||
dataset_with_reward = add_feature(
|
||||
dataset,
|
||||
feature_name="reward",
|
||||
feature_values=reward_values,
|
||||
feature_info={
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
repo_id="pusht_with_reward",
|
||||
)
|
||||
|
||||
# Method 2: Using a callable
|
||||
def compute_success(frame_dict, episode_idx, frame_idx):
|
||||
# Example: mark last 10 frames of each episode as successful
|
||||
episode_length = 10 # You'd get this from episode metadata
|
||||
return float(frame_idx >= episode_length - 10)
|
||||
|
||||
dataset_with_success = add_feature(
|
||||
dataset_with_reward,
|
||||
feature_name="success",
|
||||
feature_values=compute_success,
|
||||
feature_info={
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
repo_id="pusht_with_reward_and_success",
|
||||
)
|
||||
|
||||
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
|
||||
|
||||
# Example 4: Remove features
|
||||
print("\n4. Removing the success feature...")
|
||||
dataset_cleaned = remove_feature(dataset_with_success, feature_names="success", repo_id="pusht_cleaned")
|
||||
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
|
||||
|
||||
# Example 5: Merge datasets
|
||||
print("\n5. Merging train and val splits back together...")
|
||||
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="pusht_merged")
|
||||
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
|
||||
|
||||
# Example 6: Complex workflow
|
||||
print("\n6. Complex workflow example...")
|
||||
|
||||
# Remove a camera if dataset has multiple
|
||||
if len(dataset.meta.camera_keys) > 1:
|
||||
camera_to_remove = dataset.meta.camera_keys[0]
|
||||
print(f"Removing camera: {camera_to_remove}")
|
||||
dataset_no_cam = remove_feature(
|
||||
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
|
||||
)
|
||||
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
|
||||
|
||||
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,761 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataset tools utilities for LeRobotDataset.
|
||||
|
||||
This module provides utilities for:
|
||||
- Deleting episodes from datasets
|
||||
- Splitting datasets into multiple smaller datasets
|
||||
- Adding/removing features from datasets
|
||||
- Merging datasets (wrapper around aggregate functionality)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
|
||||
|
||||
def delete_episodes(
|
||||
dataset: LeRobotDataset,
|
||||
episode_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_filtered" to original.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: New dataset with episodes removed.
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
|
||||
# Validate episode indices
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_indices) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
logging.info(f"Deleting {len(episode_indices)} episodes from dataset")
|
||||
|
||||
# Create new dataset metadata
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_filtered"
|
||||
if output_dir is None:
|
||||
output_dir = HF_LEROBOT_HOME / repo_id
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
# Get episodes to keep
|
||||
episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices]
|
||||
if not episodes_to_keep:
|
||||
raise ValueError("Cannot delete all episodes from dataset")
|
||||
|
||||
# Create new dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Process episodes
|
||||
episode_mapping = {} # old_idx -> new_idx
|
||||
new_episode_idx = 0
|
||||
|
||||
for old_idx in tqdm(episodes_to_keep, desc="Processing episodes"):
|
||||
episode_mapping[old_idx] = new_episode_idx
|
||||
new_episode_idx += 1
|
||||
|
||||
# Copy data files and update indices
|
||||
_copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
# Copy video files if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
# Create new dataset instance
|
||||
new_dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes")
|
||||
return new_dataset
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, list[int]] | dict[str, float],
|
||||
output_dir: str | Path | None = None,
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Base directory for output datasets. If None, uses default location.
|
||||
|
||||
Returns:
|
||||
dict[str, LeRobotDataset]: Dictionary mapping split names to new datasets.
|
||||
|
||||
Examples:
|
||||
# Split by specific episodes
|
||||
splits = {"train": [0, 1, 2], "val": [3, 4]}
|
||||
datasets = split_dataset(dataset, splits)
|
||||
|
||||
# Split by fractions
|
||||
splits = {"train": 0.8, "val": 0.2}
|
||||
datasets = split_dataset(dataset, splits)
|
||||
"""
|
||||
if not splits:
|
||||
raise ValueError("No splits provided")
|
||||
|
||||
# Convert fractions to episode indices if needed
|
||||
if all(isinstance(v, float) for v in splits.values()):
|
||||
splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits)
|
||||
|
||||
# Validate episodes
|
||||
all_episodes = set()
|
||||
for split_name, episodes in splits.items():
|
||||
if not episodes:
|
||||
raise ValueError(f"Split '{split_name}' has no episodes")
|
||||
episode_set = set(episodes)
|
||||
if episode_set & all_episodes:
|
||||
raise ValueError("Episodes cannot appear in multiple splits")
|
||||
all_episodes.update(episode_set)
|
||||
|
||||
# Validate all episodes are valid
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = all_episodes - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = HF_LEROBOT_HOME
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
result_datasets = {}
|
||||
|
||||
for split_name, episodes in splits.items():
|
||||
logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes")
|
||||
|
||||
# Create repo_id for split
|
||||
split_repo_id = f"{dataset.repo_id}_{split_name}"
|
||||
split_output_dir = output_dir / split_repo_id
|
||||
|
||||
# Create episode mapping
|
||||
episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))}
|
||||
|
||||
# Create new dataset metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=split_repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=split_output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Copy data and videos
|
||||
_copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
# Create new dataset instance
|
||||
new_dataset = LeRobotDataset(
|
||||
repo_id=split_repo_id,
|
||||
root=split_output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
result_datasets[split_name] = new_dataset
|
||||
|
||||
return result_datasets
|
||||
|
||||
|
||||
def merge_datasets(
|
||||
datasets: list[LeRobotDataset],
|
||||
output_repo_id: str,
|
||||
output_dir: str | Path | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||
|
||||
This is a wrapper around the aggregate_datasets functionality with a cleaner API.
|
||||
|
||||
Args:
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Repository ID for the merged dataset.
|
||||
output_dir: Directory to save the merged dataset. If None, uses default location.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: The merged dataset.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = HF_LEROBOT_HOME / output_repo_id
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
# Extract repo_ids and roots
|
||||
repo_ids = [ds.repo_id for ds in datasets]
|
||||
roots = [ds.root for ds in datasets]
|
||||
|
||||
# Call aggregate_datasets
|
||||
aggregate_datasets(
|
||||
repo_ids=repo_ids,
|
||||
aggr_repo_id=output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=output_dir,
|
||||
)
|
||||
|
||||
# Create and return the merged dataset
|
||||
merged_dataset = LeRobotDataset(
|
||||
repo_id=output_repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=datasets[0].image_transforms,
|
||||
delta_timestamps=datasets[0].delta_timestamps,
|
||||
tolerance_s=datasets[0].tolerance_s,
|
||||
)
|
||||
|
||||
return merged_dataset
|
||||
|
||||
|
||||
def add_feature(
|
||||
dataset: LeRobotDataset,
|
||||
feature_name: str,
|
||||
feature_values: np.ndarray | torch.Tensor | Callable,
|
||||
feature_info: dict,
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Add a new feature to a LeRobotDataset.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
feature_name: Name of the new feature.
|
||||
feature_values: Either:
|
||||
- Array/tensor of shape (num_frames, ...) with values for each frame
|
||||
- Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
|
||||
feature_info: Dictionary with feature metadata (dtype, shape, names).
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: New dataset with the added feature.
|
||||
"""
|
||||
if feature_name in dataset.meta.features:
|
||||
raise ValueError(f"Feature '{feature_name}' already exists in dataset")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_modified"
|
||||
if output_dir is None:
|
||||
output_dir = HF_LEROBOT_HOME / repo_id
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
# Validate feature_info
|
||||
required_keys = {"dtype", "shape"}
|
||||
if not required_keys.issubset(feature_info.keys()):
|
||||
raise ValueError(f"feature_info must contain keys: {required_keys}")
|
||||
|
||||
# Create new features dict
|
||||
new_features = dataset.meta.features.copy()
|
||||
new_features[feature_name] = feature_info
|
||||
|
||||
# Create new dataset metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Process data with new feature
|
||||
_copy_data_with_feature_changes(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
add_features={feature_name: (feature_values, feature_info)},
|
||||
)
|
||||
|
||||
# Copy videos if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_videos(dataset, new_meta)
|
||||
|
||||
# Create new dataset instance
|
||||
new_dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
def remove_feature(
|
||||
dataset: LeRobotDataset,
|
||||
feature_names: str | list[str],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Remove features from a LeRobotDataset.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
feature_names: Name(s) of features to remove. Can be a single string or list.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: New dataset with features removed.
|
||||
"""
|
||||
if isinstance(feature_names, str):
|
||||
feature_names = [feature_names]
|
||||
|
||||
# Validate features exist
|
||||
for name in feature_names:
|
||||
if name not in dataset.meta.features:
|
||||
raise ValueError(f"Feature '{name}' not found in dataset")
|
||||
|
||||
# Check if trying to remove required features
|
||||
required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
|
||||
if any(name in required_features for name in feature_names):
|
||||
raise ValueError(f"Cannot remove required features: {required_features}")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_modified"
|
||||
if output_dir is None:
|
||||
output_dir = HF_LEROBOT_HOME / repo_id
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
# Create new features dict
|
||||
new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}
|
||||
|
||||
# Check if removing video features
|
||||
video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]
|
||||
|
||||
# Check if videos will remain after removal
|
||||
remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
|
||||
|
||||
# Create new dataset metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(remaining_video_keys) > 0,
|
||||
)
|
||||
|
||||
# Process data with removed features
|
||||
_copy_data_with_feature_changes(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
remove_features=feature_names,
|
||||
)
|
||||
|
||||
# Copy videos (excluding removed ones)
|
||||
if new_meta.video_keys:
|
||||
_copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)
|
||||
|
||||
# Create new dataset instance
|
||||
new_dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
# Helper functions
|
||||
|
||||
|
||||
def _fractions_to_episode_indices(
|
||||
total_episodes: int,
|
||||
splits: dict[str, float],
|
||||
) -> dict[str, list[int]]:
|
||||
"""Convert split fractions to episode indices."""
|
||||
if sum(splits.values()) > 1.0:
|
||||
raise ValueError("Split fractions must sum to <= 1.0")
|
||||
|
||||
indices = list(range(total_episodes))
|
||||
result = {}
|
||||
start_idx = 0
|
||||
|
||||
for split_name, fraction in splits.items():
|
||||
num_episodes = int(total_episodes * fraction)
|
||||
end_idx = start_idx + num_episodes
|
||||
if split_name == list(splits.keys())[-1]: # Last split gets remaining episodes
|
||||
end_idx = total_episodes
|
||||
result[split_name] = indices[start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _copy_and_reindex_data(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy data files and reindex episodes."""
|
||||
# Get unique data files from episodes to keep
|
||||
file_paths = set()
|
||||
for old_idx in episode_mapping:
|
||||
file_paths.add(src_dataset.meta.get_data_file_path(old_idx))
|
||||
|
||||
# Track global index
|
||||
global_index = 0
|
||||
chunk_idx, file_idx = 0, 0
|
||||
|
||||
# Process each data file
|
||||
for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
|
||||
df = pd.read_parquet(src_dataset.root / src_path)
|
||||
|
||||
# Filter to keep only mapped episodes
|
||||
mask = df["episode_index"].isin(episode_mapping.keys())
|
||||
df = df[mask].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Update episode indices
|
||||
df["episode_index"] = df["episode_index"].map(episode_mapping)
|
||||
|
||||
# Update global index to be continuous
|
||||
df["index"] = range(global_index, global_index + len(df))
|
||||
global_index += len(df)
|
||||
|
||||
# Update task indices if needed
|
||||
if dst_meta.tasks is None:
|
||||
# Get unique tasks from filtered data
|
||||
task_indices = df["task_index"].unique()
|
||||
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in task_indices]
|
||||
dst_meta.save_episode_tasks(list(set(tasks)))
|
||||
|
||||
# Remap task indices
|
||||
task_mapping = {}
|
||||
for old_task_idx in df["task_index"].unique():
|
||||
task_name = src_dataset.meta.tasks.iloc[old_task_idx].name
|
||||
new_task_idx = dst_meta.get_task_index(task_name)
|
||||
task_mapping[old_task_idx] = new_task_idx
|
||||
df["task_index"] = df["task_index"].map(task_mapping)
|
||||
|
||||
# Save processed data
|
||||
chunk_idx, file_idx = _save_data_chunk(df, dst_meta, chunk_idx, file_idx)
|
||||
|
||||
# Process episodes metadata
|
||||
_copy_and_reindex_episodes_metadata(src_dataset, dst_meta, episode_mapping)
|
||||
|
||||
|
||||
def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy video files and update metadata."""
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
video_files = set()
|
||||
for old_idx in episode_mapping:
|
||||
video_files.add(src_dataset.meta.get_video_file_path(old_idx, video_key))
|
||||
|
||||
chunk_idx, file_idx = 0, 0
|
||||
|
||||
for src_path in tqdm(sorted(video_files), desc=f"Processing {video_key} videos"):
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=video_key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# For simplicity, copy entire video files
|
||||
# In production, you might want to extract only relevant segments
|
||||
shutil.copy(src_dataset.root / src_path, dst_path)
|
||||
|
||||
# Update indices for next file
|
||||
file_size = get_video_size_in_mb(dst_path)
|
||||
if file_size >= DEFAULT_VIDEO_FILE_SIZE_IN_MB * 0.9: # 90% threshold
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
|
||||
def _copy_and_reindex_episodes_metadata(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy and reindex episodes metadata."""
|
||||
all_stats = []
|
||||
frame_offset = 0
|
||||
|
||||
for old_idx, new_idx in tqdm(
|
||||
sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata"
|
||||
):
|
||||
# Get episode from source
|
||||
src_episode = src_dataset.meta.episodes[old_idx]
|
||||
|
||||
# Create episode dict
|
||||
episode_dict = {
|
||||
"episode_index": new_idx,
|
||||
"tasks": src_episode["tasks"], # Already a list of task names
|
||||
"length": src_episode["length"],
|
||||
}
|
||||
|
||||
# Copy other metadata
|
||||
episode_metadata = {
|
||||
"data/chunk_index": 0, # Will be recalculated when saving
|
||||
"data/file_index": 0, # Will be recalculated when saving
|
||||
"dataset_from_index": frame_offset,
|
||||
"dataset_to_index": frame_offset + src_episode["length"],
|
||||
}
|
||||
|
||||
# Update frame offset for next episode
|
||||
frame_offset += src_episode["length"]
|
||||
|
||||
# Copy stats metadata
|
||||
for key in src_episode.keys():
|
||||
if key.startswith("stats/"):
|
||||
episode_dict[key] = src_episode[key]
|
||||
|
||||
# Add episode metadata
|
||||
stats_dict = {
|
||||
key.replace("stats/", ""): value
|
||||
for key, value in episode_dict.items()
|
||||
if key.startswith("stats/")
|
||||
}
|
||||
all_stats.append(stats_dict)
|
||||
|
||||
# Calculate stats from dict
|
||||
episode_stats = {}
|
||||
for key in dst_meta.features:
|
||||
if key in stats_dict:
|
||||
episode_stats[key] = stats_dict[key]
|
||||
|
||||
dst_meta.save_episode(
|
||||
new_idx, episode_dict["length"], episode_dict["tasks"], episode_stats, episode_metadata
|
||||
)
|
||||
|
||||
# Aggregate all stats
|
||||
if all_stats:
|
||||
aggregated_stats = aggregate_stats(all_stats)
|
||||
write_stats(aggregated_stats, dst_meta.root)
|
||||
|
||||
|
||||
def _save_data_chunk(
|
||||
df: pd.DataFrame,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
chunk_idx: int = 0,
|
||||
file_idx: int = 0,
|
||||
) -> tuple[int, int]:
|
||||
"""Save a data chunk and return updated indices."""
|
||||
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(meta.image_keys) > 0:
|
||||
to_parquet_with_hf_images(df, path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
|
||||
# Check if we need to rotate files
|
||||
file_size = get_parquet_file_size_in_mb(path)
|
||||
if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: # 90% threshold
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
def _copy_data_with_feature_changes(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
add_features: dict[str, tuple] | None = None,
|
||||
remove_features: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Copy data while adding or removing features."""
|
||||
# Get all unique data files
|
||||
file_paths = set()
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
file_paths.add(dataset.meta.get_data_file_path(ep_idx))
|
||||
|
||||
frame_idx = 0
|
||||
|
||||
# Process each data file
|
||||
for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
|
||||
df = pd.read_parquet(dataset.root / src_path)
|
||||
|
||||
# Remove features
|
||||
if remove_features:
|
||||
df = df.drop(columns=remove_features, errors="ignore")
|
||||
|
||||
# Add features
|
||||
if add_features:
|
||||
for feature_name, (values, _) in add_features.items():
|
||||
if callable(values):
|
||||
# Compute values for each frame
|
||||
feature_values = []
|
||||
for _, row in df.iterrows():
|
||||
ep_idx = row["episode_index"]
|
||||
frame_in_ep = row["frame_index"]
|
||||
value = values(row.to_dict(), ep_idx, frame_in_ep)
|
||||
# Convert numpy arrays to scalars for single-element arrays
|
||||
if isinstance(value, np.ndarray) and value.size == 1:
|
||||
value = value.item()
|
||||
feature_values.append(value)
|
||||
df[feature_name] = feature_values
|
||||
else:
|
||||
# Use provided values
|
||||
end_idx = frame_idx + len(df)
|
||||
# Convert to list to ensure proper shape handling
|
||||
feature_slice = values[frame_idx:end_idx]
|
||||
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
|
||||
# Flatten single-element arrays to scalars for pandas
|
||||
df[feature_name] = feature_slice.flatten()
|
||||
else:
|
||||
df[feature_name] = feature_slice
|
||||
frame_idx = end_idx
|
||||
|
||||
# Save chunk
|
||||
_save_data_chunk(df, new_meta)
|
||||
|
||||
# Copy episodes metadata and update stats
|
||||
_copy_episodes_metadata_and_stats(dataset, new_meta)
|
||||
|
||||
|
||||
def _copy_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
exclude_keys: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Copy video files, optionally excluding certain keys."""
|
||||
if exclude_keys is None:
|
||||
exclude_keys = []
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
if video_key in exclude_keys:
|
||||
continue
|
||||
|
||||
# Get all video files for this key
|
||||
video_files = set()
|
||||
for ep_idx in range(src_dataset.meta.total_episodes):
|
||||
video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key))
|
||||
|
||||
# Copy video files
|
||||
for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"):
|
||||
# Maintain same structure
|
||||
rel_path = src_path.relative_to(src_dataset.root)
|
||||
dst_path = dst_meta.root / rel_path
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(src_dataset.root / src_path, dst_path)
|
||||
|
||||
|
||||
def _copy_episodes_metadata_and_stats(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
) -> None:
|
||||
"""Copy episodes metadata and recalculate stats."""
|
||||
# Copy tasks
|
||||
if src_dataset.meta.tasks is not None:
|
||||
write_tasks(src_dataset.meta.tasks, dst_meta.root)
|
||||
dst_meta.tasks = src_dataset.meta.tasks.copy()
|
||||
|
||||
# Copy episodes metadata files
|
||||
episodes_dir = src_dataset.root / "meta/episodes"
|
||||
dst_episodes_dir = dst_meta.root / "meta/episodes"
|
||||
if episodes_dir.exists():
|
||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||
|
||||
# Update info
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
||||
}
|
||||
)
|
||||
|
||||
# Update video info if needed
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
||||
"info", {}
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
# Recalculate stats if features changed
|
||||
if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()):
|
||||
# Need to recalculate stats
|
||||
logging.info("Recalculating dataset statistics...")
|
||||
# This is a simplified version - in production you'd want to properly recalculate
|
||||
if src_dataset.meta.stats:
|
||||
new_stats = {}
|
||||
for key in dst_meta.features:
|
||||
if key in src_dataset.meta.stats:
|
||||
new_stats[key] = src_dataset.meta.stats[key]
|
||||
write_stats(new_stats, dst_meta.root)
|
||||
else:
|
||||
# Copy existing stats
|
||||
if src_dataset.meta.stats:
|
||||
write_stats(src_dataset.meta.stats, dst_meta.root)
|
||||
@@ -0,0 +1,584 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for dataset tools utilities."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_feature,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Create a sample dataset for testing."""
|
||||
# Create an empty dataset and add data manually
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
|
||||
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test_dataset",
|
||||
features=features,
|
||||
)
|
||||
|
||||
# Add episodes manually
|
||||
for ep_idx in range(5):
|
||||
for _ in range(10):
|
||||
frame = {
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"observation.state": np.random.randn(4).astype(np.float32),
|
||||
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
|
||||
"task": f"task_{ep_idx % 2}",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class TestDeleteEpisodes:
|
||||
def test_delete_single_episode(self, sample_dataset, tmp_path):
|
||||
"""Test deleting a single episode."""
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
# Delete episode 2
|
||||
# Mock the revision check and snapshot_download to prevent Hub calls
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
new_dataset = delete_episodes(
|
||||
sample_dataset,
|
||||
episode_indices=[2],
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
# Check results
|
||||
assert new_dataset.meta.total_episodes == 4
|
||||
assert new_dataset.meta.total_frames == 40
|
||||
|
||||
# Check episode indices are renumbered
|
||||
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
|
||||
assert episode_indices == {0, 1, 2, 3}
|
||||
|
||||
# Check data integrity
|
||||
assert len(new_dataset) == 40
|
||||
|
||||
def test_delete_multiple_episodes(self, sample_dataset, tmp_path):
|
||||
"""Test deleting multiple episodes."""
|
||||
output_dir = tmp_path / "filtered"
|
||||
|
||||
# Delete episodes 1 and 3
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
new_dataset = delete_episodes(
|
||||
sample_dataset,
|
||||
episode_indices=[1, 3],
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
# Check results
|
||||
assert new_dataset.meta.total_episodes == 3
|
||||
assert new_dataset.meta.total_frames == 30
|
||||
|
||||
# Check episode indices
|
||||
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
|
||||
assert episode_indices == {0, 1, 2}
|
||||
|
||||
def test_delete_invalid_episodes(self, sample_dataset, tmp_path):
|
||||
"""Test error handling for invalid episode indices."""
|
||||
with pytest.raises(ValueError, match="Invalid episode indices"):
|
||||
delete_episodes(
|
||||
sample_dataset,
|
||||
episode_indices=[10, 20], # Out of range
|
||||
output_dir=tmp_path / "filtered",
|
||||
)
|
||||
|
||||
def test_delete_all_episodes(self, sample_dataset, tmp_path):
|
||||
"""Test error when trying to delete all episodes."""
|
||||
with pytest.raises(ValueError, match="Cannot delete all episodes"):
|
||||
delete_episodes(
|
||||
sample_dataset,
|
||||
episode_indices=list(range(5)), # All episodes
|
||||
output_dir=tmp_path / "filtered",
|
||||
)
|
||||
|
||||
def test_delete_empty_list(self, sample_dataset, tmp_path):
|
||||
"""Test error when no episodes specified."""
|
||||
with pytest.raises(ValueError, match="No episodes to delete"):
|
||||
delete_episodes(
|
||||
sample_dataset,
|
||||
episode_indices=[],
|
||||
output_dir=tmp_path / "filtered",
|
||||
)
|
||||
|
||||
|
||||
class TestSplitDataset:
|
||||
def test_split_by_episodes(self, sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by specific episode indices."""
|
||||
splits = {
|
||||
"train": [0, 1, 2],
|
||||
"val": [3, 4],
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
# Mock snapshot_download to return the appropriate directory for each split
|
||||
def mock_snapshot(repo_id, **kwargs):
|
||||
if "train" in repo_id:
|
||||
return str(tmp_path / f"{sample_dataset.repo_id}_train")
|
||||
elif "val" in repo_id:
|
||||
return str(tmp_path / f"{sample_dataset.repo_id}_val")
|
||||
return str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
mock_snapshot_download.side_effect = mock_snapshot
|
||||
|
||||
result = split_dataset(
|
||||
sample_dataset,
|
||||
splits=splits,
|
||||
output_dir=tmp_path,
|
||||
)
|
||||
|
||||
# Check we got both splits
|
||||
assert set(result.keys()) == {"train", "val"}
|
||||
|
||||
# Check train split
|
||||
assert result["train"].meta.total_episodes == 3
|
||||
assert result["train"].meta.total_frames == 30
|
||||
|
||||
# Check val split
|
||||
assert result["val"].meta.total_episodes == 2
|
||||
assert result["val"].meta.total_frames == 20
|
||||
|
||||
# Check episode renumbering
|
||||
train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
|
||||
assert train_episodes == {0, 1, 2}
|
||||
|
||||
val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
|
||||
assert val_episodes == {0, 1}
|
||||
|
||||
def test_split_by_fractions(self, sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by fractions."""
|
||||
splits = {
|
||||
"train": 0.6, # 3 episodes
|
||||
"val": 0.4, # 2 episodes
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
def mock_snapshot(repo_id, **kwargs):
|
||||
for split_name in splits:
|
||||
if split_name in repo_id:
|
||||
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
|
||||
return str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
mock_snapshot_download.side_effect = mock_snapshot
|
||||
|
||||
result = split_dataset(
|
||||
sample_dataset,
|
||||
splits=splits,
|
||||
output_dir=tmp_path,
|
||||
)
|
||||
|
||||
# Check splits
|
||||
assert result["train"].meta.total_episodes == 3
|
||||
assert result["val"].meta.total_episodes == 2
|
||||
|
||||
def test_split_overlapping_episodes(self, sample_dataset, tmp_path):
|
||||
"""Test error when episodes appear in multiple splits."""
|
||||
splits = {
|
||||
"train": [0, 1, 2],
|
||||
"val": [2, 3, 4], # Episode 2 appears in both
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
|
||||
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
|
||||
|
||||
def test_split_invalid_fractions(self, sample_dataset, tmp_path):
|
||||
"""Test error when fractions sum to more than 1."""
|
||||
splits = {
|
||||
"train": 0.7,
|
||||
"val": 0.5, # Sum = 1.2
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
|
||||
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
|
||||
|
||||
def test_split_empty(self, sample_dataset, tmp_path):
|
||||
"""Test error with empty splits."""
|
||||
with pytest.raises(ValueError, match="No splits provided"):
|
||||
split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
|
||||
|
||||
|
||||
class TestMergeDatasets:
|
||||
def test_merge_two_datasets(self, sample_dataset, tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test merging two datasets."""
|
||||
# Create a second dataset manually
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
|
||||
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
|
||||
}
|
||||
|
||||
dataset2 = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test_dataset2",
|
||||
features=features,
|
||||
)
|
||||
|
||||
# Add 3 episodes
|
||||
for ep_idx in range(3):
|
||||
for _ in range(10):
|
||||
frame = {
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"observation.state": np.random.randn(4).astype(np.float32),
|
||||
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
|
||||
"task": f"task_{ep_idx % 2}",
|
||||
}
|
||||
dataset2.add_frame(frame)
|
||||
dataset2.save_episode()
|
||||
|
||||
# Merge datasets
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
|
||||
|
||||
merged = merge_datasets(
|
||||
[sample_dataset, dataset2],
|
||||
output_repo_id="merged_dataset",
|
||||
output_dir=tmp_path / "merged_dataset",
|
||||
)
|
||||
|
||||
# Check results
|
||||
assert merged.meta.total_episodes == 8 # 5 + 3
|
||||
assert merged.meta.total_frames == 80 # 50 + 30
|
||||
|
||||
# Check episode indices are sequential
|
||||
episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
|
||||
assert episode_indices == list(range(8))
|
||||
|
||||
def test_merge_empty_list(self, tmp_path):
|
||||
"""Test error when merging empty list."""
|
||||
with pytest.raises(ValueError, match="No datasets to merge"):
|
||||
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
|
||||
|
||||
|
||||
class TestAddFeature:
|
||||
def test_add_feature_with_values(self, sample_dataset, tmp_path):
|
||||
"""Test adding a feature with pre-computed values."""
|
||||
# Create reward values for all frames
|
||||
num_frames = sample_dataset.meta.total_frames
|
||||
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
|
||||
|
||||
feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
|
||||
new_dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=reward_values,
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
# Check feature was added
|
||||
assert "reward" in new_dataset.meta.features
|
||||
assert new_dataset.meta.features["reward"] == feature_info
|
||||
|
||||
# Check values
|
||||
assert len(new_dataset) == num_frames
|
||||
sample_item = new_dataset[0]
|
||||
assert "reward" in sample_item
|
||||
# Scalar features don't have shape, just check it's a tensor
|
||||
assert isinstance(sample_item["reward"], torch.Tensor)
|
||||
|
||||
def test_add_feature_with_callable(self, sample_dataset, tmp_path):
|
||||
"""Test adding a feature with a callable."""
|
||||
|
||||
def compute_reward(frame_dict, episode_idx, frame_idx):
|
||||
# Simple reward based on episode and frame indices
|
||||
return float(episode_idx * 10 + frame_idx)
|
||||
|
||||
feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
|
||||
new_dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=compute_reward,
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
# Check feature was added
|
||||
assert "reward" in new_dataset.meta.features
|
||||
|
||||
# Check computed values
|
||||
# Episode 0, frame 0 should have reward 0
|
||||
items = [new_dataset[i] for i in range(10)]
|
||||
first_episode_items = [item for item in items if item["episode_index"] == 0]
|
||||
assert len(first_episode_items) == 10
|
||||
|
||||
# Check first frame of first episode
|
||||
first_frame = first_episode_items[0]
|
||||
assert first_frame["frame_index"] == 0
|
||||
assert float(first_frame["reward"]) == 0.0
|
||||
|
||||
def test_add_existing_feature(self, sample_dataset, tmp_path):
|
||||
"""Test error when adding an existing feature."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,)}
|
||||
|
||||
with pytest.raises(ValueError, match="Feature 'action' already exists"):
|
||||
add_feature(
|
||||
sample_dataset,
|
||||
feature_name="action", # Already exists
|
||||
feature_values=np.zeros(50),
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
def test_add_feature_invalid_info(self, sample_dataset, tmp_path):
|
||||
"""Test error with invalid feature info."""
|
||||
with pytest.raises(ValueError, match="feature_info must contain keys"):
|
||||
add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.zeros(50),
|
||||
feature_info={"dtype": "float32"}, # Missing 'shape'
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
|
||||
class TestRemoveFeature:
|
||||
def test_remove_single_feature(self, sample_dataset, tmp_path):
|
||||
"""Test removing a single feature."""
|
||||
# First add a feature to remove
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
|
||||
kwargs.get("local_dir", tmp_path)
|
||||
)
|
||||
|
||||
dataset_with_reward = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
# Now remove it
|
||||
dataset_without_reward = remove_feature(
|
||||
dataset_with_reward,
|
||||
feature_names="reward",
|
||||
output_dir=tmp_path / "without_reward",
|
||||
)
|
||||
|
||||
# Check feature was removed
|
||||
assert "reward" not in dataset_without_reward.meta.features
|
||||
|
||||
# Check data
|
||||
sample_item = dataset_without_reward[0]
|
||||
assert "reward" not in sample_item
|
||||
|
||||
def test_remove_multiple_features(self, sample_dataset, tmp_path):
|
||||
"""Test removing multiple features at once."""
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
|
||||
kwargs.get("local_dir", tmp_path)
|
||||
)
|
||||
|
||||
# Add two features
|
||||
dataset = sample_dataset
|
||||
for feature_name in ["reward", "success"]:
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
dataset = add_feature(
|
||||
dataset,
|
||||
feature_name=feature_name,
|
||||
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / f"with_{feature_name}",
|
||||
)
|
||||
|
||||
# Remove both
|
||||
dataset_clean = remove_feature(
|
||||
dataset,
|
||||
feature_names=["reward", "success"],
|
||||
output_dir=tmp_path / "clean",
|
||||
)
|
||||
|
||||
# Check both were removed
|
||||
assert "reward" not in dataset_clean.meta.features
|
||||
assert "success" not in dataset_clean.meta.features
|
||||
|
||||
def test_remove_nonexistent_feature(self, sample_dataset, tmp_path):
|
||||
"""Test error when removing non-existent feature."""
|
||||
with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
|
||||
remove_feature(
|
||||
sample_dataset,
|
||||
feature_names="nonexistent",
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
def test_remove_required_feature(self, sample_dataset, tmp_path):
|
||||
"""Test error when trying to remove required features."""
|
||||
with pytest.raises(ValueError, match="Cannot remove required features"):
|
||||
remove_feature(
|
||||
sample_dataset,
|
||||
feature_names="timestamp", # Required feature
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
def test_remove_camera_feature(self, sample_dataset, tmp_path):
|
||||
"""Test removing a camera feature."""
|
||||
camera_keys = sample_dataset.meta.camera_keys
|
||||
if not camera_keys:
|
||||
pytest.skip("No camera keys in dataset")
|
||||
|
||||
# Remove first camera
|
||||
camera_to_remove = camera_keys[0]
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "without_camera")
|
||||
|
||||
dataset_without_camera = remove_feature(
|
||||
sample_dataset,
|
||||
feature_names=camera_to_remove,
|
||||
output_dir=tmp_path / "without_camera",
|
||||
)
|
||||
|
||||
# Check camera was removed
|
||||
assert camera_to_remove not in dataset_without_camera.meta.features
|
||||
assert camera_to_remove not in dataset_without_camera.meta.camera_keys
|
||||
|
||||
# Check data
|
||||
sample_item = dataset_without_camera[0]
|
||||
assert camera_to_remove not in sample_item
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
def test_complex_workflow(self, sample_dataset, tmp_path):
|
||||
"""Test a complex workflow combining multiple operations."""
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(
|
||||
kwargs.get("local_dir", tmp_path)
|
||||
)
|
||||
|
||||
# 1. Add a reward feature
|
||||
dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||
feature_info={"dtype": "float32", "shape": (1,), "names": None},
|
||||
output_dir=tmp_path / "step1",
|
||||
)
|
||||
|
||||
# 2. Delete an episode
|
||||
dataset = delete_episodes(
|
||||
dataset,
|
||||
episode_indices=[2],
|
||||
output_dir=tmp_path / "step2",
|
||||
)
|
||||
|
||||
# 3. Split into train/val
|
||||
splits = split_dataset(
|
||||
dataset,
|
||||
splits={"train": 0.75, "val": 0.25},
|
||||
output_dir=tmp_path / "step3",
|
||||
)
|
||||
|
||||
# 4. Merge them back
|
||||
merged = merge_datasets(
|
||||
list(splits.values()),
|
||||
output_repo_id="final_dataset",
|
||||
output_dir=tmp_path / "step4",
|
||||
)
|
||||
|
||||
# Check final dataset
|
||||
assert merged.meta.total_episodes == 4 # Started with 5, deleted 1
|
||||
assert merged.meta.total_frames == 40
|
||||
assert "reward" in merged.meta.features # Feature preserved
|
||||
|
||||
# Check data integrity
|
||||
assert len(merged) == 40
|
||||
sample_item = merged[0]
|
||||
assert "reward" in sample_item
|
||||
Reference in New Issue
Block a user