From 26d732c8c886078869c6085e25cfccb4e0fbbe16 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 23:07:52 -0700 Subject: [PATCH] refactor(dataset): modular files --- src/lerobot/datasets/backtracking.py | 175 ++++ src/lerobot/datasets/dataset_metadata.py | 516 +++++++++++ src/lerobot/datasets/feature_utils.py | 552 ++++++++++++ src/lerobot/datasets/io_utils.py | 342 +++++++ src/lerobot/datasets/lerobot_dataset.py | 670 +------------- src/lerobot/datasets/multi_dataset.py | 210 +++++ src/lerobot/datasets/utils.py | 1040 ++-------------------- tests/datasets/test_aggregate.py | 24 +- tests/datasets/test_dataset_tools.py | 112 +-- tests/fixtures/dataset_factories.py | 4 +- tests/test_control_robot.py | 8 +- 11 files changed, 1925 insertions(+), 1728 deletions(-) create mode 100644 src/lerobot/datasets/backtracking.py create mode 100644 src/lerobot/datasets/dataset_metadata.py create mode 100644 src/lerobot/datasets/feature_utils.py create mode 100644 src/lerobot/datasets/io_utils.py create mode 100644 src/lerobot/datasets/multi_dataset.py diff --git a/src/lerobot/datasets/backtracking.py b/src/lerobot/datasets/backtracking.py new file mode 100644 index 000000000..16363be86 --- /dev/null +++ b/src/lerobot/datasets/backtracking.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from collections.abc import Iterable, Iterator + + +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable[T]: + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. 
+ + Example: + ------- + ```python + ds = load_dataset("c4", "en", streaming=True, split="train") + rev = Backtrackable(ds, history=3, lookahead=2) + + x0 = next(rev) # forward + x1 = next(rev) + x2 = next(rev) + + # Look ahead + x3_peek = rev.peek_ahead(1) # next item without moving cursor + x4_peek = rev.peek_ahead(2) # two items ahead + + # Look back + x1_again = rev.peek_back(1) # previous item without moving cursor + x0_again = rev.peek_back(2) # two items back + + # Move backward + x1_back = rev.prev() # back one step + next(rev) # returns x2, continues forward from where we were + ``` + """ + + __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") + + def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): + if history < 1: + raise ValueError("history must be >= 1") + if lookahead <= 0: + raise ValueError("lookahead must be > 0") + + self._source: Iterator[T] = iter(iterable) + self._back_buf: deque[T] = deque(maxlen=history) + self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._cursor: int = 0 + self._history = history + self._lookahead = lookahead + + def __iter__(self) -> "Backtrackable[T]": + return self + + def __next__(self) -> T: + # If we've stepped back, consume from back buffer first + if self._cursor < 0: # -1 means "last item", etc. + self._cursor += 1 + return self._back_buf[self._cursor] + + # If we have items in the ahead buffer, use them first + item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) + + # Add current item to back buffer and reset cursor + self._back_buf.append(item) + self._cursor = 0 + return item + + def prev(self) -> T: + """ + Step one item back in history and return it. + Raises IndexError if already at the oldest buffered item. + """ + if len(self._back_buf) + self._cursor <= 1: + raise LookBackError("At start of history") + + self._cursor -= 1 + return self._back_buf[self._cursor] + + def peek_back(self, n: int = 1) -> T: + """ + Look `n` items back (n=1 == previous item) without moving the cursor. + """ + if n < 0 or n + 1 > len(self._back_buf) + self._cursor: + raise LookBackError("peek_back distance out of range") + + return self._back_buf[self._cursor - (n + 1)] + + def peek_ahead(self, n: int = 1) -> T: + """ + Look `n` items ahead (n=1 == next item) without moving the cursor. + Fills the ahead buffer if necessary. + """ + if n < 1: + raise LookAheadError("peek_ahead distance must be 1 or more") + elif n > self._lookahead: + raise LookAheadError("peek_ahead distance exceeds lookahead limit") + + # Fill ahead buffer if we don't have enough items + while len(self._ahead_buf) < n: + try: + item = next(self._source) + self._ahead_buf.append(item) + + except StopIteration as err: + raise LookAheadError("peek_ahead: not enough items in source") from err + + return self._ahead_buf[n - 1] + + def history(self) -> list[T]: + """ + Return a copy of the buffered history (most recent last). + The list length ≤ `history` argument passed at construction. + """ + if self._cursor == 0: + return list(self._back_buf) + + # When cursor<0, slice so the order remains chronological + return list(self._back_buf)[: self._cursor or None] + + def can_peek_back(self, steps: int = 1) -> bool: + """ + Check if we can go back `steps` items without raising an IndexError. + """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. 
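A minimal sketch of the peek and bounds-checking API above, run on a plain Python iterable. The import path mirrors the new `lerobot/datasets/backtracking.py` module added in this patch; the toy `range` stream and the asserted values are illustrative only.

```python
from lerobot.datasets.backtracking import Backtrackable, LookAheadError, LookBackError

walker = Backtrackable(iter(range(5)), history=3, lookahead=2)

first = next(walker)   # 0
second = next(walker)  # 1

# Peeking never moves the cursor.
assert walker.peek_back(1) == 0   # previous item
assert walker.can_peek_ahead(2)   # fills the ahead buffer from the source
assert walker.peek_ahead(2) == 3  # two items ahead of the cursor

# Out-of-range peeks raise dedicated exceptions instead of returning stale data.
try:
    walker.peek_back(2)   # only two items consumed so far
except LookBackError:
    pass
try:
    walker.peek_ahead(3)  # beyond the lookahead=2 limit
except LookAheadError:
    pass

# Forward iteration drains the ahead buffer before touching the source again.
assert next(walker) == 2
```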
+ This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py new file mode 100644 index 000000000..61585a8a3 --- /dev/null +++ b/src/lerobot/datasets/dataset_metadata.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import numpy as np +import packaging.version +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from huggingface_hub import snapshot_download + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_FEATURES, + INFO_PATH, + _validate_feature_names, + check_version_compatibility, + create_empty_dataset_info, + flatten_dict, + get_file_size_in_mb, + get_safe_version, + is_valid_version, + load_episodes, + load_info, + load_stats, + load_subtasks, + load_tasks, + update_chunk_file_indices, + write_info, + write_json, + write_stats, + write_tasks, +) +from lerobot.datasets.video_utils import get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME + +CODEBASE_VERSION = "v3.0" + + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + metadata_buffer_size: int = 10, + ): + self.repo_id = repo_id + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + 
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) + self.episodes = load_episodes(self.root) + self.stats = load_stats(self.root) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def url_root(self) -> str: + return f"hf://datasets/{self.repo_id}" + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. 
Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep[f"videos/{vid_key}/chunk_index"] + file_idx = ep[f"videos/{vid_key}/file_index"] + fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def chunks_size(self) -> int: + """Max number of files per chunk.""" + return self.info["chunks_size"] + + @property + def data_files_size_in_mb(self) -> int: + """Max size of data file in mega bytes.""" + return self.info["data_files_size_in_mb"] + + @property + def video_files_size_in_mb(self) -> int: + """Max size of video file in mega bytes.""" + return self.info["video_files_size_in_mb"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise return None. 
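For reference, a sketch of how the read-only accessors and path helpers above are typically used once metadata has been loaded. The repo id is a placeholder, the import path assumes the new `dataset_metadata` module, and the returned file paths are relative (callers join them with `meta.root`).

```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

meta = LeRobotDatasetMetadata("your-name/your-dataset")  # placeholder repo_id

print(meta.fps, meta.total_episodes, meta.total_frames)
print(meta.camera_keys)  # e.g. ["observation.images.front"]

# Relative locations of the files backing episode 0 (join with meta.root for absolute paths).
data_file = meta.get_data_file_path(0)
video_files = {key: meta.get_video_file_path(0, key) for key in meta.video_keys}
```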
+ """ + if task in self.tasks.index: + return int(self.tasks.loc[task].task_index) + else: + return None + + def save_episode_tasks(self, tasks: list[str]): + if len(set(tasks)) != len(tasks): + raise ValueError(f"Tasks are not unique: {tasks}") + + if self.tasks is None: + new_tasks = tasks + task_indices = range(len(tasks)) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) + else: + new_tasks = [task for task in tasks if task not in self.tasks.index] + new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) + for task_idx, task in zip(new_task_indices, new_tasks, strict=False): + self.tasks.loc[task] = task_idx + + if len(new_tasks) > 0: + # Update on disk + write_tasks(self.tasks, self.root) + + def _save_episode_metadata(self, episode_dict: dict) -> None: + """Buffer episode metadata and write to parquet in batches for efficiency. + + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. + """ + # Convert to list format for each value + episode_dict = {key: [value] for key, value in episode_dict.items()} + num_frames = episode_dict["length"][0] + + if self.latest_episode is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] + + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] + + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) + + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] + + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 + + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() + + # Update the existing pandas dataframe with new row + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + 
episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] + + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict + + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + episode_metadata: dict, + ) -> None: + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + episode_dict.update(episode_metadata) + episode_dict.update(flatten_dict({"stats": episode_stats})) + self._save_episode_metadata(episode_dict) + + # Update info + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + self.info["total_tasks"] = len(self.tasks) + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + + write_info(self.info, self.root) + + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) + + def update_video_info(self, video_key: str | None = None) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + if video_key is not None and video_key not in self.video_keys: + raise ValueError(f"Video key {video_key} not found in dataset") + + video_keys = [video_key] if video_key is not None else self.video_keys + for key in video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) + self.info["features"][key]["info"] = get_video_info(video_path) + + def update_chunk_settings( + self, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> None: + """Update chunk and file size settings after dataset creation. + + This allows users to customize storage organization without modifying the constructor. + These settings control how episodes are chunked and how large files can grow before + creating new ones. + + Args: + chunks_size: Maximum number of files per chunk directory. If None, keeps current value. + data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. + video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. + """ + if chunks_size is not None: + if chunks_size <= 0: + raise ValueError(f"chunks_size must be positive, got {chunks_size}") + self.info["chunks_size"] = chunks_size + + if data_files_size_in_mb is not None: + if data_files_size_in_mb <= 0: + raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") + self.info["data_files_size_in_mb"] = data_files_size_in_mb + + if video_files_size_in_mb is not None: + if video_files_size_in_mb <= 0: + raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") + self.info["video_files_size_in_mb"] = video_files_size_in_mb + + # Update the info file on disk + write_info(self.info, self.root) + + def get_chunk_settings(self) -> dict[str, int]: + """Get current chunk and file size settings. 
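A short sketch of tuning the storage settings exposed above; the repo id is a placeholder. The new values are written back to the dataset's info file, so they apply to files created from that point on.

```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

meta = LeRobotDatasetMetadata("your-name/your-dataset")  # placeholder repo_id

# Raise the per-file size limits before a long recording session.
meta.update_chunk_settings(data_files_size_in_mb=200, video_files_size_in_mb=1000)
print(meta.get_chunk_settings())
# {'chunks_size': ..., 'data_files_size_in_mb': 200, 'video_files_size_in_mb': 1000}
```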
+ + Returns: + Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. + """ + return { + "chunks_size": self.chunks_size, + "data_files_size_in_mb": self.data_files_size_in_mb, + "video_files_size_in_mb": self.video_files_size_in_mb, + } + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + robot_type: str | None = None, + root: str | Path | None = None, + use_videos: bool = True, + metadata_buffer_size: int = 10, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) + + obj.tasks = None + obj.subtasks = None + obj.episodes = None + obj.stats = None + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError( + f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " + "Either remove video features from the features dict, or set 'use_videos=True'." + ) + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size + return obj diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py new file mode 100644 index 000000000..d9a3c6301 --- /dev/null +++ b/src/lerobot/datasets/feature_utils.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pprint import pformat +from typing import Any + +import datasets +import numpy as np +from PIL import Image as PILImage + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_FEATURES, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_PATH, +) +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR +from lerobot.utils.utils import is_valid_numpy_dtype_string + + +def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. 
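A rough end-to-end sketch of creating fresh metadata and registering one episode with the buffered writer above. The repo id, root path, feature spec, and the min/max/mean/std/count layout of the stats dict are assumptions (stats are normally produced by `compute_episode_stats`, and the `data/...` indices by `LeRobotDataset.save_episode`), not values taken from this patch.

```python
import numpy as np

from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

features = {
    "observation.state": {"dtype": "float32", "shape": (2,), "names": ["j1.pos", "j2.pos"]},
    "action": {"dtype": "float32", "shape": (2,), "names": ["j1.pos", "j2.pos"]},
}
meta = LeRobotDatasetMetadata.create(
    "your-name/new-dataset", fps=30, features=features, use_videos=False,
    root="/tmp/new-dataset",  # hypothetical path; must not exist yet
)

# Register tasks first, then record the episode-level bookkeeping.
meta.save_episode_tasks(["pick the cube"])
episode_stats = {  # assumed layout, as consumed by aggregate_stats
    "observation.state": {
        "min": np.zeros(2), "max": np.ones(2),
        "mean": np.full(2, 0.5), "std": np.full(2, 0.1), "count": np.array([90]),
    },
}
meta.save_episode(
    episode_index=0,
    episode_length=90,
    episode_tasks=["pick the cube"],
    episode_stats=episode_stats,
    episode_metadata={"data/chunk_index": 0, "data/file_index": 0},  # normally filled by the dataset class
)
# Episode rows are buffered and flushed every `metadata_buffer_size` episodes,
# or when the underlying parquet writer is closed.
```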
+ + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. + """ + hf_features = {} + for key, ft in features.items(): + if ft["dtype"] == "video": + continue + elif ft["dtype"] == "image": + hf_features[key] = datasets.Image() + elif ft["shape"] == (1,): + hf_features[key] = datasets.Value(dtype=ft["dtype"]) + elif len(ft["shape"]) == 1: + hf_features[key] = datasets.Sequence( + length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) + ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") + + return datasets.Features(hf_features) + + +def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ + features = {} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == ACTION: + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == OBS_STR: + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. 
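A sketch of the two helpers above working together: `hw_to_dataset_features` turns a hardware description into a feature spec, and `build_dataset_frame` packs raw readings into a frame. The `"observation"` prefix is assumed to match `OBS_STR` from `lerobot.utils.constants`, and the joint/camera names are made up.

```python
import numpy as np

from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features

# Hypothetical hardware description: two joint readings (floats) and one camera (H, W, C).
hw = {"shoulder.pos": float, "elbow.pos": float, "front": (480, 640, 3)}

obs_features = hw_to_dataset_features(hw, "observation", use_video=True)
# -> {"observation.state": {"dtype": "float32", "shape": (2,), "names": [...]},
#     "observation.images.front": {"dtype": "video", "shape": (480, 640, 3), ...}}

raw = {"shoulder.pos": 0.12, "elbow.pos": -0.40, "front": np.zeros((480, 640, 3), dtype=np.uint8)}
frame = build_dataset_frame(obs_features, raw, "observation")
# frame["observation.state"] is a float32 vector ordered like "names";
# frame["observation.images.front"] is passed through untouched.
```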
+ + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == OBS_ENV_STATE: + type = FeatureType.ENV + elif key.startswith(OBS_STR): + type = FeatureType.STATE + elif key.startswith(ACTION): + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. 
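A sketch of `dataset_to_policy_features` on a small feature spec; the import paths follow this patch's own imports, and the shapes are illustrative.

```python
from lerobot.configs.types import FeatureType
from lerobot.datasets.feature_utils import dataset_to_policy_features

features = {
    "observation.images.front": {"dtype": "video", "shape": (480, 640, 3),
                                 "names": ["height", "width", "channels"]},
    "observation.state": {"dtype": "float32", "shape": (2,), "names": ["shoulder.pos", "elbow.pos"]},
    "action": {"dtype": "float32", "shape": (2,), "names": ["shoulder.pos", "elbow.pos"]},
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
}

policy_features = dataset_to_policy_features(features)
assert policy_features["observation.images.front"].type == FeatureType.VISUAL
assert policy_features["observation.images.front"].shape == (3, 480, 640)  # channel-first
assert "timestamp" not in policy_features  # neither an observation nor an action, so it is skipped
```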
+ """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + features: dict, + use_videos: bool, + robot_type: str | None = None, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, +) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ + return { + "codebase_version": codebase_version, + "robot_type": robot_type, + "total_episodes": 0, + "total_frames": 0, + "total_tasks": 0, + "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, + "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, + "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, + "fps": fps, + "splits": {}, + "data_path": DEFAULT_DATA_PATH, + "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "features": features, + } + + +def check_delta_timestamps( + delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True +) -> bool: + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. 
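A sketch of merging two single-arm feature dicts with `combine_feature_dicts`: vector features accumulate their names and recompute the shape, while camera entries are simply carried over. The feature names are hypothetical.

```python
from lerobot.datasets.feature_utils import combine_feature_dicts

left = {
    "observation.state": {"dtype": "float32", "shape": (2,), "names": ["left_j1.pos", "left_j2.pos"]},
    "observation.images.left": {"dtype": "video", "shape": (480, 640, 3),
                                "names": ["height", "width", "channels"]},
}
right = {
    "observation.state": {"dtype": "float32", "shape": (2,), "names": ["right_j1.pos", "right_j2.pos"]},
    "observation.images.right": {"dtype": "video", "shape": (480, 640, 3),
                                 "names": ["height", "width", "channels"]},
}

merged = combine_feature_dicts(left, right)
assert merged["observation.state"]["shape"] == (4,)  # names concatenated, shape recomputed
assert set(merged) == {"observation.state", "observation.images.left", "observation.images.right"}
```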
+ """ + outside_tolerance = {} + for key, delta_ts in delta_timestamps.items(): + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + if not all(within_tolerance): + outside_tolerance[key] = [ + ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ] + + if len(outside_tolerance) > 0: + if raise_value_error: + raise ValueError( + f""" + The following delta_timestamps are found outside of tolerance range. + Please make sure they are multiples of 1/{fps} +/- tolerance and adjust + their values accordingly. + \n{pformat(outside_tolerance)} + """ + ) + return False + + return True + + +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ + delta_indices = {} + for key, delta_ts in delta_timestamps.items(): + delta_indices[key] = [round(d * fps) for d in delta_ts] + + return delta_indices + + +def validate_frame(frame: dict, features: dict) -> None: + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) + + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") + + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - expected_features + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +) -> str: + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. 
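A sketch of validating and converting delta timestamps at 30 fps; the keys and tolerance are illustrative.

```python
from lerobot.datasets.feature_utils import check_delta_timestamps, get_delta_indices

fps = 30
delta_timestamps = {
    "observation.state": [-0.1, 0.0],       # one frame of history plus the current frame
    "action": [i / fps for i in range(3)],  # a short action chunk
}

check_delta_timestamps(delta_timestamps, fps=fps, tolerance_s=1e-4)  # multiples of 1/fps -> passes
print(get_delta_indices(delta_timestamps, fps=fps))
# {'observation.state': [-3, 0], 'action': [0, 1, 2]}

# Deltas off the frame grid are rejected, or reported as False when raise_value_error=False.
ok = check_delta_timestamps({"action": [0.033]}, fps=fps, tolerance_s=1e-4, raise_value_error=False)
assert ok is False
```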
+ """ + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +) -> str: + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +) -> str: + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_string(name: str, value: str) -> str: + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: + """Validate the episode buffer before it's written to disk. 
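A sketch of `validate_frame` catching a dtype/shape mismatch; the feature spec and task string are made up.

```python
import numpy as np

from lerobot.datasets.feature_utils import validate_frame

features = {"observation.state": {"dtype": "float32", "shape": (2,), "names": ["j1.pos", "j2.pos"]}}

good = {"observation.state": np.zeros(2, dtype=np.float32), "task": "pick up the cube"}
validate_frame(good, features)  # passes silently

bad = {"observation.state": np.zeros(3, dtype=np.float64), "task": "pick up the cube"}
try:
    validate_frame(bad, features)
except ValueError as err:
    print(err)  # reports both the dtype and the shape mismatch for 'observation.state'
```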
+ + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py new file mode 100644 index 000000000..cee6cfba8 --- /dev/null +++ b/src/lerobot/datasets/io_utils.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +from pathlib import Path +from typing import Any + +import datasets +import numpy as np +import pandas +import pandas as pd +import pyarrow.dataset as pa_ds +import pyarrow.parquet as pq +import torch +from datasets import Dataset +from datasets.table import embed_table_storage +from PIL import Image as PILImage +from torchvision import transforms + +from lerobot.datasets.utils import ( + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_EPISODES_PATH, + DEFAULT_SUBTASKS_PATH, + DEFAULT_TASKS_PATH, + EPISODES_DIR, + INFO_PATH, + STATS_PATH, + flatten_dict, + serialize_dict, + unflatten_dict, +) +from lerobot.utils.utils import SuppressProgressBars + + +def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: + metadata = pq.read_metadata(parquet_path) + total_uncompressed_size = 0 + for row_group in range(metadata.num_row_groups): + rg_metadata = metadata.row_group(row_group) + for column in range(rg_metadata.num_columns): + col_metadata = rg_metadata.column(column) + total_uncompressed_size += col_metadata.total_uncompressed_size + return total_uncompressed_size / (1024**2) + + +def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: + return hf_ds.data.nbytes // (1024**2) + + +def load_nested_dataset( + pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None +) -> Dataset: + """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet + Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage + Concatenate all pyarrow references to return HF Dataset format + + Args: + pq_dir: Directory containing parquet files + features: Optional features schema to ensure consistent loading of complex types like images + episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. + """ + paths = sorted(pq_dir.glob("*/*.parquet")) + if len(paths) == 0: + raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") + + with SuppressProgressBars(): + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) + + +def get_parquet_num_frames(parquet_path: str | Path) -> int: + metadata = pq.read_metadata(parquet_path) + return metadata.num_rows + + +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. + """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) + + +def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. + + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ + # Embed image bytes into the table before saving to parquet + format = dataset.format + dataset = dataset.with_format("arrow") + dataset = dataset.map(embed_table_storage, batched=False) + dataset = dataset.with_format(**format) + return dataset + + +def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. 
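A sketch of loading the nested parquet layout with and without the episode filter; the directory path is hypothetical and the import location assumes the new `io_utils` module.

```python
from pathlib import Path

from lerobot.datasets.io_utils import load_nested_dataset

# Hypothetical local dataset root laid out as data/chunk-000/file-000.parquet, ...
data_dir = Path("/path/to/dataset/data")

full = load_nested_dataset(data_dir)                     # memory-mapped through the HF datasets cache
subset = load_nested_dataset(data_dir, episodes=[0, 3])  # filter pushed down to PyArrow
assert set(subset.unique("episode_index")) <= {0, 3}
```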
+ + Returns: + Any: The data loaded from the JSON file. + """ + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + +def write_info(info: dict, local_dir: Path) -> None: + write_json(info, local_dir / INFO_PATH) + + +def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ + info = load_json(local_dir / INFO_PATH) + for ft in info["features"].values(): + ft["shape"] = tuple(ft["shape"]) + return info + + +def write_stats(stats: dict, local_dir: Path) -> None: + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ + serialized_stats = serialize_dict(stats) + write_json(serialized_stats, local_dir / STATS_PATH) + + +def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ + stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + return unflatten_dict(stats) + + +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. + """ + if not (local_dir / STATS_PATH).exists(): + return None + stats = load_json(local_dir / STATS_PATH) + return cast_stats_to_numpy(stats) + + +def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: + path = local_dir / DEFAULT_TASKS_PATH + path.parent.mkdir(parents=True, exist_ok=True) + tasks.to_parquet(path) + + +def load_tasks(local_dir: Path) -> pandas.DataFrame: + tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) + tasks.index.name = "task" + return tasks + + +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + +def write_episodes(episodes: Dataset, local_dir: Path) -> None: + """Write episode metadata to a parquet file in the LeRobot v3.0 format. + This function writes episode-level metadata to a single parquet file. + Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. + + Args: + episodes: HuggingFace Dataset containing episode metadata + local_dir: Root directory where the dataset will be stored + """ + episode_size_mb = get_hf_dataset_size_in_mb(episodes) + if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: + raise NotImplementedError( + f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. 
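A sketch of the stats round-trip defined above: `write_stats` serializes through `serialize_dict` into the dataset's stats JSON, and `load_stats` casts the values back to numpy arrays. The root path and feature key are placeholders.

```python
from pathlib import Path

import numpy as np

from lerobot.datasets.io_utils import load_stats, write_stats

root = Path("/path/to/dataset")  # hypothetical dataset root

stats = {
    "observation.state": {
        "min": np.zeros(2, dtype=np.float32),
        "max": np.ones(2, dtype=np.float32),
        "mean": np.array([0.1, 0.2], dtype=np.float32),
        "std": np.array([0.5, 0.5], dtype=np.float32),
    }
}
write_stats(stats, root)
loaded = load_stats(root)
assert isinstance(loaded["observation.state"]["mean"], np.ndarray)
```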
" + f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " + "This function only supports single-file episode metadata. " + ) + + fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) + fpath.parent.mkdir(parents=True, exist_ok=True) + episodes.to_parquet(fpath) + + +def load_episodes(local_dir: Path) -> datasets.Dataset: + episodes = load_nested_dataset(local_dir / EPISODES_DIR) + # Select episode features/columns containing references to episode data and videos + # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) + # This is to speedup access to these data, instead of having to load episode stats. + episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) + return episodes + + +def load_image_as_numpy( + fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True +) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ + img = PILImage.open(fpath).convert("RGB") + img_array = np.array(img, dtype=dtype) + if channel_first: # (H, W, C) -> (C, H, W) + img_array = np.transpose(img_array, (2, 0, 1)) + if np.issubdtype(dtype, np.floating): + img_array /= 255.0 + return img_array + + +def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. + + Returns: + dict: The batch with items converted to torch tensors. + """ + for key in items_dict: + first_item = items_dict[key][0] + if isinstance(first_item, PILImage.Image): + to_tensor = transforms.ToTensor() + items_dict[key] = [to_tensor(img) for img in items_dict[key]] + elif first_item is None: + pass + else: + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + return items_dict + + +def to_parquet_with_hf_images( + df: pandas.DataFrame, path: Path, features: datasets.Features | None = None +) -> None: + """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. + This way, it can be loaded by HF dataset and correctly formatted images are returned. + + Args: + df: DataFrame to write to parquet. + path: Path to write the parquet file. + features: Optional HuggingFace Features schema. If provided, ensures image columns + are properly typed as Image() in the parquet schema. + """ + # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only + ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) + ds.to_parquet(path) + + +def item_to_torch(item: dict) -> dict: + """Convert all items in a dictionary to PyTorch tensors where appropriate. + + This function is used to convert an item from a streaming dataset to PyTorch tensors. 
+ + Args: + item (dict): Dictionary of items from a dataset. + + Returns: + dict: Dictionary with all tensor-like items converted to torch.Tensor. + """ + for key, val in item.items(): + if isinstance(val, (np.ndarray | list)) and key not in ["task"]: + # Convert numpy arrays and lists to torch tensors + item[key] = torch.tensor(val) + return item diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 5d1b5d042..6aecc016e 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -23,30 +23,23 @@ from pathlib import Path import datasets import numpy as np -import packaging.version import pandas as pd import PIL.Image -import pyarrow as pa import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats +from lerobot.datasets.compute_stats import compute_episode_stats +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.image_writer import AsyncImageWriter, write_image from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, - DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, - INFO_PATH, - _validate_feature_names, check_delta_timestamps, - check_version_compatibility, - create_empty_dataset_info, create_lerobot_dataset_card, embed_images, - flatten_dict, get_delta_indices, get_file_size_in_mb, get_hf_features_from_features, @@ -54,501 +47,25 @@ from lerobot.datasets.utils import ( hf_transform_to_torch, is_valid_version, load_episodes, - load_info, load_nested_dataset, - load_stats, - load_subtasks, - load_tasks, update_chunk_file_indices, validate_episode_buffer, validate_frame, write_info, - write_json, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - VideoFrame, concatenate_video_files, decode_video_frames, encode_video_frames, get_safe_default_codec, get_video_duration_in_s, - get_video_info, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME logger = logging.getLogger(__name__) -CODEBASE_VERSION = "v3.0" - - -class LeRobotDatasetMetadata: - def __init__( - self, - repo_id: str, - root: str | Path | None = None, - revision: str | None = None, - force_cache_sync: bool = False, - metadata_buffer_size: int = 10, - ): - self.repo_id = repo_id - self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None - self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size - - try: - if force_cache_sync: - raise FileNotFoundError - self.load_metadata() - except (FileNotFoundError, NotADirectoryError): - if is_valid_version(self.revision): - self.revision = get_safe_version(self.repo_id, self.revision) - - (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() - - def _flush_metadata_buffer(self) -> None: - """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: - return - - combined_dict = {} - for episode_dict in self.metadata_buffer: - for key, value in episode_dict.items(): - if key not in combined_dict: - combined_dict[key] = [] - # Extract value and serialize numpy arrays - # because PyArrow's from_pydict 
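A sketch of the two torch-conversion helpers above: the batched transform is typically installed with `datasets.Dataset.set_transform`, while `item_to_torch` handles single streaming items. Values are illustrative.

```python
import numpy as np

from lerobot.datasets.io_utils import hf_transform_to_torch, item_to_torch

# Batched transform, e.g. hf_dataset.set_transform(hf_transform_to_torch)
batch = {"observation.state": [[0.1, 0.2], [0.3, 0.4]], "task": ["pick", "place"]}
batch = hf_transform_to_torch(batch)
# numeric columns become lists of torch.Tensor, strings are left untouched

# Per-item variant used on streaming items
item = {"observation.state": np.array([0.1, 0.2]), "task": "pick"}
item = item_to_torch(item)
assert item["task"] == "pick"
```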
function doesn't support numpy arrays - val = value[0] if isinstance(value, list) else value - combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - - first_ep = self.metadata_buffer[0] - chunk_idx = first_ep["meta/episodes/chunk_index"][0] - file_idx = first_ep["meta/episodes/file_index"][0] - - table = pa.Table.from_pydict(combined_dict) - - if not self.writer: - path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) - path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( - path, schema=table.schema, compression="snappy", use_dictionary=True - ) - - self.writer.write_table(table) - - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() - - def _close_writer(self) -> None: - """Close and cleanup the parquet writer if it exists.""" - self._flush_metadata_buffer() - - writer = getattr(self, "writer", None) - if writer is not None: - writer.close() - self.writer = None - - def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() - - def load_metadata(self): - self.info = load_info(self.root) - check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) - self.tasks = load_tasks(self.root) - self.subtasks = load_subtasks(self.root) - self.episodes = load_episodes(self.root) - self.stats = load_stats(self.root) - - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - ) - - @property - def url_root(self) -> str: - return f"hf://datasets/{self.repo_id}" - - @property - def _version(self) -> packaging.version.Version: - """Codebase version used to create this dataset.""" - return packaging.version.parse(self.info["codebase_version"]) - - def get_data_file_path(self, ep_index: int) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep["data/chunk_index"] - file_idx = ep["data/file_index"] - fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. 
Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep[f"videos/{vid_key}/chunk_index"] - file_idx = ep[f"videos/{vid_key}/file_index"] - fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - @property - def data_path(self) -> str: - """Formattable string for the parquet files.""" - return self.info["data_path"] - - @property - def video_path(self) -> str | None: - """Formattable string for the video files.""" - return self.info["video_path"] - - @property - def robot_type(self) -> str | None: - """Robot type used in recording this dataset.""" - return self.info["robot_type"] - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.info["fps"] - - @property - def features(self) -> dict[str, dict]: - """All features contained in the dataset.""" - return self.info["features"] - - @property - def image_keys(self) -> list[str]: - """Keys to access visual modalities stored as images.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "image"] - - @property - def video_keys(self) -> list[str]: - """Keys to access visual modalities stored as videos.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "video"] - - @property - def camera_keys(self) -> list[str]: - """Keys to access visual modalities (regardless of their storage method).""" - return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] - - @property - def names(self) -> dict[str, list | dict]: - """Names of the various dimensions of vector modalities.""" - return {key: ft["names"] for key, ft in self.features.items()} - - @property - def shapes(self) -> dict: - """Shapes for the different features.""" - return {key: tuple(ft["shape"]) for key, ft in self.features.items()} - - @property - def total_episodes(self) -> int: - """Total number of episodes available.""" - return self.info["total_episodes"] - - @property - def total_frames(self) -> int: - """Total number of frames saved in this dataset.""" - return self.info["total_frames"] - - @property - def total_tasks(self) -> int: - """Total number of different tasks performed in this dataset.""" - return self.info["total_tasks"] - - @property - def chunks_size(self) -> int: - """Max number of files per chunk.""" - return self.info["chunks_size"] - - @property - def data_files_size_in_mb(self) -> int: - """Max size of data file in mega bytes.""" - return self.info["data_files_size_in_mb"] - - @property - def video_files_size_in_mb(self) -> int: - """Max size of video file in mega bytes.""" - return self.info["video_files_size_in_mb"] - - def get_task_index(self, task: str) -> int | None: - """ - Given a task in natural language, returns its task_index if the task already exists in the dataset, - otherwise return None. 
- """ - if task in self.tasks.index: - return int(self.tasks.loc[task].task_index) - else: - return None - - def save_episode_tasks(self, tasks: list[str]): - if len(set(tasks)) != len(tasks): - raise ValueError(f"Tasks are not unique: {tasks}") - - if self.tasks is None: - new_tasks = tasks - task_indices = range(len(tasks)) - self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) - else: - new_tasks = [task for task in tasks if task not in self.tasks.index] - new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) - for task_idx, task in zip(new_task_indices, new_tasks, strict=False): - self.tasks.loc[task] = task_idx - - if len(new_tasks) > 0: - # Update on disk - write_tasks(self.tasks, self.root) - - def _save_episode_metadata(self, episode_dict: dict) -> None: - """Buffer episode metadata and write to parquet in batches for efficiency. - - This function accumulates episode metadata in a buffer and flushes it when the buffer - reaches the configured size. This reduces I/O overhead by writing multiple episodes - at once instead of one row at a time. - - Notes: We both need to update parquet files and HF dataset: - - `pandas` loads parquet file in RAM - - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, - or loads directly from pyarrow cache. - """ - # Convert to list format for each value - episode_dict = {key: [value] for key, value in episode_dict.items()} - num_frames = episode_dict["length"][0] - - if self.latest_episode is None: - # Initialize indices and frame count for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - if self.episodes is not None and len(self.episodes) > 0: - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] - file_idx = self.episodes[-1]["meta/episodes/file_index"] - latest_num_frames = self.episodes[-1]["dataset_to_index"] - episode_dict["dataset_from_index"] = [latest_num_frames] - episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] - - # When resuming, move to the next file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - else: - episode_dict["dataset_from_index"] = [0] - episode_dict["dataset_to_index"] = [num_frames] - - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - episode_dict["meta/episodes/file_index"] = [file_idx] - else: - chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] - file_idx = self.latest_episode["meta/episodes/file_index"][0] - - latest_path = ( - self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where - ) - - if Path(latest_path).exists(): - latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) - latest_num_frames = self.latest_episode["episode_index"][0] - - av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 - - if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: - # Size limit is reached, flush buffer and prepare new parquet file - self._flush_metadata_buffer() - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - self._close_writer() - - # Update the existing pandas dataframe with new row - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - 
episode_dict["meta/episodes/file_index"] = [file_idx] - episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] - episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - - # Add to buffer - self.metadata_buffer.append(episode_dict) - self.latest_episode = episode_dict - - if len(self.metadata_buffer) >= self.metadata_buffer_size: - self._flush_metadata_buffer() - - def save_episode( - self, - episode_index: int, - episode_length: int, - episode_tasks: list[str], - episode_stats: dict[str, dict], - episode_metadata: dict, - ) -> None: - episode_dict = { - "episode_index": episode_index, - "tasks": episode_tasks, - "length": episode_length, - } - episode_dict.update(episode_metadata) - episode_dict.update(flatten_dict({"stats": episode_stats})) - self._save_episode_metadata(episode_dict) - - # Update info - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - self.info["total_tasks"] = len(self.tasks) - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - - write_info(self.info, self.root) - - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats - write_stats(self.stats, self.root) - - def update_video_info(self, video_key: str | None = None) -> None: - """ - Warning: this function writes info from first episode videos, implicitly assuming that all videos have - been encoded the same way. Also, this means it assumes the first episode exists. - """ - if video_key is not None and video_key not in self.video_keys: - raise ValueError(f"Video key {video_key} not found in dataset") - - video_keys = [video_key] if video_key is not None else self.video_keys - for key in video_keys: - if not self.features[key].get("info", None): - video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) - self.info["features"][key]["info"] = get_video_info(video_path) - - def update_chunk_settings( - self, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> None: - """Update chunk and file size settings after dataset creation. - - This allows users to customize storage organization without modifying the constructor. - These settings control how episodes are chunked and how large files can grow before - creating new ones. - - Args: - chunks_size: Maximum number of files per chunk directory. If None, keeps current value. - data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. - video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. - """ - if chunks_size is not None: - if chunks_size <= 0: - raise ValueError(f"chunks_size must be positive, got {chunks_size}") - self.info["chunks_size"] = chunks_size - - if data_files_size_in_mb is not None: - if data_files_size_in_mb <= 0: - raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") - self.info["data_files_size_in_mb"] = data_files_size_in_mb - - if video_files_size_in_mb is not None: - if video_files_size_in_mb <= 0: - raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") - self.info["video_files_size_in_mb"] = video_files_size_in_mb - - # Update the info file on disk - write_info(self.info, self.root) - - def get_chunk_settings(self) -> dict[str, int]: - """Get current chunk and file size settings. 
- - Returns: - Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. - """ - return { - "chunks_size": self.chunks_size, - "data_files_size_in_mb": self.data_files_size_in_mb, - "video_files_size_in_mb": self.video_files_size_in_mb, - } - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Total episodes: '{self.total_episodes}',\n" - f" Total frames: '{self.total_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - @classmethod - def create( - cls, - repo_id: str, - fps: int, - features: dict, - robot_type: str | None = None, - root: str | Path | None = None, - use_videos: bool = True, - metadata_buffer_size: int = 10, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" - obj = cls.__new__(cls) - obj.repo_id = repo_id - obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - - obj.root.mkdir(parents=True, exist_ok=False) - - features = {**features, **DEFAULT_FEATURES} - _validate_feature_names(features) - - obj.tasks = None - obj.subtasks = None - obj.episodes = None - obj.stats = None - obj.info = create_empty_dataset_info( - CODEBASE_VERSION, - fps, - features, - use_videos, - robot_type, - chunks_size, - data_files_size_in_mb, - video_files_size_in_mb, - ) - if len(obj.video_keys) > 0 and not use_videos: - raise ValueError( - f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " - "Either remove video features from the features dict, or set 'use_videos=True'." - ) - write_json(obj.info, obj.root / INFO_PATH) - obj.revision = None - obj.writer = None - obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size - return obj - def _encode_video_worker( video_key: str, @@ -1723,182 +1240,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return obj -class MultiLeRobotDataset(torch.utils.data.Dataset): - """A dataset consisting of multiple underlying `LeRobotDataset`s. - - The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API - structure of `LeRobotDataset`. - """ - - def __init__( - self, - repo_ids: list[str], - root: str | Path | None = None, - episodes: dict | None = None, - image_transforms: Callable | None = None, - delta_timestamps: dict[str, list[float]] | None = None, - tolerances_s: dict | None = None, - download_videos: bool = True, - video_backend: str | None = None, - ): - super().__init__() - self.repo_ids = repo_ids - self.root = Path(root) if root else HF_LEROBOT_HOME - self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) - # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which - # are handled by this class. - self._datasets = [ - LeRobotDataset( - repo_id, - root=self.root / repo_id, - episodes=episodes[repo_id] if episodes else None, - image_transforms=image_transforms, - delta_timestamps=delta_timestamps, - tolerance_s=self.tolerances_s[repo_id], - download_videos=download_videos, - video_backend=video_backend, - ) - for repo_id in repo_ids - ] - - # Disable any data keys that are not common across all of the datasets. Note: we may relax this - # restriction in future iterations of this class. 
For now, this is necessary at least for being able - # to use PyTorch's default DataLoader collate function. - self.disabled_features = set() - intersection_features = set(self._datasets[0].features) - for ds in self._datasets: - intersection_features.intersection_update(ds.features) - if len(intersection_features) == 0: - raise RuntimeError( - "Multiple datasets were provided but they had no keys common to all of them. " - "The multi-dataset functionality currently only keeps common keys." - ) - for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): - extra_keys = set(ds.features).difference(intersection_features) - if extra_keys: - logger.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) - - self.image_transforms = image_transforms - self.delta_timestamps = delta_timestamps - # TODO(rcadene, aliberts): We should not perform this aggregation for datasets - # with multiple robots of different ranges. Instead we should have one normalization - # per robot. - self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) - - @property - def repo_id_to_index(self): - """Return a mapping from dataset repo_id to a dataset index automatically created by this class. - - This index is incorporated as a data key in the dictionary returned by `__getitem__`. - """ - return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} - - @property - def fps(self) -> int: - """Frames per second used during data collection. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info["fps"] - - @property - def video(self) -> bool: - """Returns True if this dataset loads video frames from mp4 files. - - Returns False if it only loads images from png files. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info.get("video", False) - - @property - def features(self) -> datasets.Features: - features = {} - for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) - return features - - @property - def camera_keys(self) -> list[str]: - """Keys to access image and video stream from cameras.""" - keys = [] - for key, feats in self.features.items(): - if isinstance(feats, (datasets.Image | VideoFrame)): - keys.append(key) - return keys - - @property - def video_frame_keys(self) -> list[str]: - """Keys to access video frames that requires to be decoded into images. - - Note: It is empty if the dataset contains images only, - or equal to `self.cameras` if the dataset contains videos only, - or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. - """ - video_frame_keys = [] - for key, feats in self.features.items(): - if isinstance(feats, VideoFrame): - video_frame_keys.append(key) - return video_frame_keys - - @property - def num_frames(self) -> int: - """Number of samples/frames.""" - return sum(d.num_frames for d in self._datasets) - - @property - def num_episodes(self) -> int: - """Number of episodes.""" - return sum(d.num_episodes for d in self._datasets) - - @property - def tolerance_s(self) -> float: - """Tolerance in seconds used to discard loaded frames when their timestamps - are not close enough from the requested frames. 
It is only used when `delta_timestamps` - is provided or when loading video frames from mp4 files. - """ - # 1e-4 to account for possible numerical error - return 1 / self.fps - 1e-4 - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self): - raise IndexError(f"Index {idx} out of bounds.") - # Determine which dataset to get an item from based on the index. - start_idx = 0 - dataset_idx = 0 - for dataset in self._datasets: - if idx >= start_idx + dataset.num_frames: - start_idx += dataset.num_frames - dataset_idx += 1 - continue - break - else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") - item = self._datasets[dataset_idx][idx - start_idx] - item["dataset_index"] = torch.tensor(dataset_idx) - for data_key in self.disabled_features: - if data_key in item: - del item[data_key] - - return item - - def __repr__(self): - return ( - f"{self.__class__.__name__}(\n" - f" Repository IDs: '{self.repo_ids}',\n" - f" Number of Samples: {self.num_frames},\n" - f" Number of Episodes: {self.num_episodes},\n" - f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" - f" Recorded Frames per Second: {self.fps},\n" - f" Camera Keys: {self.camera_keys},\n" - f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.image_transforms},\n" - f")" - ) +# --------------------------------------------------------------------------- +# Backward-compatible re-export +# --------------------------------------------------------------------------- +from lerobot.datasets.multi_dataset import MultiLeRobotDataset # noqa: E402, F401 diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py new file mode 100644 index 000000000..917d5c5eb --- /dev/null +++ b/src/lerobot/datasets/multi_dataset.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections.abc import Callable +from pathlib import Path + +import datasets +import torch +import torch.utils + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.video_utils import VideoFrame +from lerobot.utils.constants import HF_LEROBOT_HOME + +logger = logging.getLogger(__name__) + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. 
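+
+    Example (an illustrative sketch; the repo IDs below are placeholders, not real datasets):
+    ```python
+    dataset = MultiLeRobotDataset(["user/dataset_a", "user/dataset_b"])
+    item = dataset[0]
+    item["dataset_index"]  # tensor identifying which sub-dataset the item came from
+    ```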
+
+    """
+
+    def __init__(
+        self,
+        repo_ids: list[str],
+        root: str | Path | None = None,
+        episodes: dict | None = None,
+        image_transforms: Callable | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
+        tolerances_s: dict | None = None,
+        download_videos: bool = True,
+        video_backend: str | None = None,
+    ):
+        super().__init__()
+        self.repo_ids = repo_ids
+        self.root = Path(root) if root else HF_LEROBOT_HOME
+        self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
+        # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
+        # are handled by this class.
+        self._datasets = [
+            LeRobotDataset(
+                repo_id,
+                root=self.root / repo_id,
+                episodes=episodes[repo_id] if episodes else None,
+                image_transforms=image_transforms,
+                delta_timestamps=delta_timestamps,
+                tolerance_s=self.tolerances_s[repo_id],
+                download_videos=download_videos,
+                video_backend=video_backend,
+            )
+            for repo_id in repo_ids
+        ]
+
+        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
+        # restriction in future iterations of this class. For now, this is necessary at least for being able
+        # to use PyTorch's default DataLoader collate function.
+        self.disabled_features = set()
+        intersection_features = set(self._datasets[0].features)
+        for ds in self._datasets:
+            intersection_features.intersection_update(ds.features)
+        if len(intersection_features) == 0:
+            raise RuntimeError(
+                "Multiple datasets were provided but they had no keys common to all of them. "
+                "The multi-dataset functionality currently only keeps common keys."
+            )
+        for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
+            extra_keys = set(ds.features).difference(intersection_features)
+            if extra_keys:
+                logger.warning(
+                    f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
+                    "other datasets."
+                )
+            self.disabled_features.update(extra_keys)
+
+        self.image_transforms = image_transforms
+        self.delta_timestamps = delta_timestamps
+        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+        # with multiple robots of different ranges. Instead we should have one normalization
+        # per robot.
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
+
+    @property
+    def repo_id_to_index(self):
+        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
+
+        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
+        """
+        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].meta.info["fps"]
+
+    @property
+    def video(self) -> bool:
+        """Returns True if this dataset loads video frames from mp4 files.
+
+        Returns False if it only loads images from png files.
+
+        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
+        """
+        return self._datasets[0].meta.info.get("video", False)
+
+    @property
+    def features(self) -> datasets.Features:
+        features = {}
+        for dataset in self._datasets:
+            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
+        return features
+
+    @property
+    def camera_keys(self) -> list[str]:
+        """Keys to access image and video streams from cameras."""
+        keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, (datasets.Image | VideoFrame)):
+                keys.append(key)
+        return keys
+
+    @property
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames that require being decoded into images.
+
+        Note: It is empty if the dataset contains images only,
+        or equal to `self.cameras` if the dataset contains videos only,
+        or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
+        """
+        video_frame_keys = []
+        for key, feats in self.features.items():
+            if isinstance(feats, VideoFrame):
+                video_frame_keys.append(key)
+        return video_frame_keys
+
+    @property
+    def num_frames(self) -> int:
+        """Number of samples/frames."""
+        return sum(d.num_frames for d in self._datasets)
+
+    @property
+    def num_episodes(self) -> int:
+        """Number of episodes."""
+        return sum(d.num_episodes for d in self._datasets)
+
+    @property
+    def tolerance_s(self) -> float:
+        """Tolerance in seconds used to discard loaded frames when their timestamps
+        are not close enough to the requested frames. It is only used when `delta_timestamps`
+        is provided or when loading video frames from mp4 files.
+        """
+        # 1e-4 to account for possible numerical error
+        return 1 / self.fps - 1e-4
+
+    def __len__(self):
+        return self.num_frames
+
+    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} out of bounds.")
+        # Determine which dataset to get an item from based on the index.
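+        # `idx` is a global frame index over the concatenation of all sub-datasets: accumulate
+        # each dataset's frame count into `start_idx` until the dataset containing `idx` is
+        # reached, then index into it locally with `idx - start_idx`.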
+ start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_frames: + start_idx += dataset.num_frames + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_features: + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8bc56a1bd..d1bddbf44 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,35 +17,21 @@ import contextlib import importlib.resources import json import logging -from collections import deque -from collections.abc import Iterable, Iterator -from pathlib import Path -from pprint import pformat +from collections.abc import Iterator from typing import Any import datasets import numpy as np import packaging.version -import pandas -import pandas as pd -import pyarrow.dataset as pa_ds -import pyarrow.parquet as pq import torch -from datasets import Dataset -from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError -from PIL import Image as PILImage -from torchvision import transforms -from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.backward_compatibility import ( FUTURE_MESSAGE, BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -79,21 +65,6 @@ DEFAULT_FEATURES = { } -def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: - metadata = pq.read_metadata(parquet_path) - total_uncompressed_size = 0 - for row_group in range(metadata.num_row_groups): - rg_metadata = metadata.row_group(row_group) - for column in range(rg_metadata.num_columns): - col_metadata = rg_metadata.column(column) - total_uncompressed_size += col_metadata.total_uncompressed_size - return total_uncompressed_size / (1024**2) - - -def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: - return hf_ds.data.nbytes // (1024**2) - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -103,43 +74,6 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) - return chunk_idx, file_idx -def load_nested_dataset( - pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None -) -> Dataset: - """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet - Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage - Concatenate all pyarrow 
references to return HF Dataset format - - Args: - pq_dir: Directory containing parquet files - features: Optional features schema to ensure consistent loading of complex types like images - episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. - """ - paths = sorted(pq_dir.glob("*/*.parquet")) - if len(paths) == 0: - raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") - - with SuppressProgressBars(): - # We use .from_parquet() memory-mapped loading for efficiency - filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None - return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) - - -def get_parquet_num_frames(parquet_path: str | Path) -> int: - metadata = pq.read_metadata(parquet_path) - return metadata.num_rows - - -def get_file_size_in_mb(file_path: Path) -> float: - """Get file size on disk in megabytes. - - Args: - file_path (Path): Path to the file. - """ - file_size_bytes = file_path.stat().st_size - return file_size_bytes / (1024**2) - - def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: """Flatten a nested dictionary by joining keys with a separator. @@ -222,217 +156,6 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: return unflatten_dict(serialized_dict) -def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: - """Embed image bytes into the dataset table before saving to Parquet. - - This function prepares a Hugging Face dataset for serialization by converting - image objects into an embedded format that can be stored in Arrow/Parquet. - - Args: - dataset (datasets.Dataset): The input dataset, possibly containing image features. - - Returns: - datasets.Dataset: The dataset with images embedded in the table storage. - """ - # Embed image bytes into the table before saving to parquet - format = dataset.format - dataset = dataset.with_format("arrow") - dataset = dataset.map(embed_table_storage, batched=False) - dataset = dataset.with_format(**format) - return dataset - - -def load_json(fpath: Path) -> Any: - """Load data from a JSON file. - - Args: - fpath (Path): Path to the JSON file. - - Returns: - Any: The data loaded from the JSON file. - """ - with open(fpath) as f: - return json.load(f) - - -def write_json(data: dict, fpath: Path) -> None: - """Write data to a JSON file. - - Creates parent directories if they don't exist. - - Args: - data (dict): The dictionary to write. - fpath (Path): The path to the output JSON file. - """ - fpath.parent.mkdir(exist_ok=True, parents=True) - with open(fpath, "w") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - -def write_info(info: dict, local_dir: Path) -> None: - write_json(info, local_dir / INFO_PATH) - - -def load_info(local_dir: Path) -> dict: - """Load dataset info metadata from its standard file path. - - Also converts shape lists to tuples for consistency. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - dict: The dataset information dictionary. - """ - info = load_json(local_dir / INFO_PATH) - for ft in info["features"].values(): - ft["shape"] = tuple(ft["shape"]) - return info - - -def write_stats(stats: dict, local_dir: Path) -> None: - """Serialize and write dataset statistics to their standard file path. - - Args: - stats (dict): The statistics dictionary (can contain tensors/numpy arrays). - local_dir (Path): The root directory of the dataset. 
- """ - serialized_stats = serialize_dict(stats) - write_json(serialized_stats, local_dir / STATS_PATH) - - -def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: - """Recursively cast numerical values in a stats dictionary to numpy arrays. - - Args: - stats (dict): The statistics dictionary. - - Returns: - dict: The statistics dictionary with values cast to numpy arrays. - """ - stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} - return unflatten_dict(stats) - - -def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: - """Load dataset statistics and cast numerical values to numpy arrays. - - Returns None if the stats file doesn't exist. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - A dictionary of statistics or None if the file is not found. - """ - if not (local_dir / STATS_PATH).exists(): - return None - stats = load_json(local_dir / STATS_PATH) - return cast_stats_to_numpy(stats) - - -def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: - path = local_dir / DEFAULT_TASKS_PATH - path.parent.mkdir(parents=True, exist_ok=True) - tasks.to_parquet(path) - - -def load_tasks(local_dir: Path) -> pandas.DataFrame: - tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) - tasks.index.name = "task" - return tasks - - -def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: - """Load subtasks from subtasks.parquet if it exists.""" - subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH - if subtasks_path.exists(): - return pd.read_parquet(subtasks_path) - return None - - -def write_episodes(episodes: Dataset, local_dir: Path) -> None: - """Write episode metadata to a parquet file in the LeRobot v3.0 format. - This function writes episode-level metadata to a single parquet file. - Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. - - Args: - episodes: HuggingFace Dataset containing episode metadata - local_dir: Root directory where the dataset will be stored - """ - episode_size_mb = get_hf_dataset_size_in_mb(episodes) - if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: - raise NotImplementedError( - f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " - f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " - "This function only supports single-file episode metadata. " - ) - - fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) - fpath.parent.mkdir(parents=True, exist_ok=True) - episodes.to_parquet(fpath) - - -def load_episodes(local_dir: Path) -> datasets.Dataset: - episodes = load_nested_dataset(local_dir / EPISODES_DIR) - # Select episode features/columns containing references to episode data and videos - # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) - # This is to speedup access to these data, instead of having to load episode stats. - episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) - return episodes - - -def load_image_as_numpy( - fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True -) -> np.ndarray: - """Load an image from a file into a numpy array. - - Args: - fpath (str | Path): Path to the image file. - dtype (np.dtype): The desired data type of the output array. If floating, - pixels are scaled to [0, 1]. - channel_first (bool): If True, converts the image to (C, H, W) format. - Otherwise, it remains in (H, W, C) format. 
- - Returns: - np.ndarray: The image as a numpy array. - """ - img = PILImage.open(fpath).convert("RGB") - img_array = np.array(img, dtype=dtype) - if channel_first: # (H, W, C) -> (C, H, W) - img_array = np.transpose(img_array, (2, 0, 1)) - if np.issubdtype(dtype, np.floating): - img_array /= 255.0 - return img_array - - -def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Convert a batch from a Hugging Face dataset to torch tensors. - - This transform function converts items from Hugging Face dataset format (pyarrow) - to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) - to a torch image representation (C, H, W, float32) in the range [0, 1]. Other - types are converted to torch.tensor. - - Args: - items_dict (dict): A dictionary representing a batch of data from a - Hugging Face dataset. - - Returns: - dict: The batch with items converted to torch tensors. - """ - for key in items_dict: - first_item = items_dict[key][0] - if isinstance(first_item, PILImage.Image): - to_tensor = transforms.ToTensor() - items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif first_item is None: - pass - else: - items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] - return items_dict - - def is_valid_version(version: str) -> bool: """Check if a string is a valid PEP 440 version. @@ -560,337 +283,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> raise ForwardCompatibilityError(repo_id, min(upper_versions)) -def get_hf_features_from_features(features: dict) -> datasets.Features: - """Convert a LeRobot features dictionary to a `datasets.Features` object. - - Args: - features (dict): A LeRobot-style feature dictionary. - - Returns: - datasets.Features: The corresponding Hugging Face `datasets.Features` object. - - Raises: - ValueError: If a feature has an unsupported shape. - """ - hf_features = {} - for key, ft in features.items(): - if ft["dtype"] == "video": - continue - elif ft["dtype"] == "image": - hf_features[key] = datasets.Image() - elif ft["shape"] == (1,): - hf_features[key] = datasets.Value(dtype=ft["dtype"]) - elif len(ft["shape"]) == 1: - hf_features[key] = datasets.Sequence( - length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) - ) - elif len(ft["shape"]) == 2: - hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 3: - hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 4: - hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 5: - hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) - else: - raise ValueError(f"Corresponding feature is not valid: {ft}") - - return datasets.Features(hf_features) - - -def _validate_feature_names(features: dict[str, dict]) -> None: - """Validate that feature names do not contain invalid characters. - - Args: - features (dict): The LeRobot features dictionary. - - Raises: - ValueError: If any feature name contains '/'. - """ - invalid_features = {name: ft for name, ft in features.items() if "/" in name} - if invalid_features: - raise ValueError(f"Feature names should not contain '/'. 
Found '/' in '{invalid_features}'.") - - -def hw_to_dataset_features( - hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True -) -> dict[str, dict]: - """Convert hardware-specific features to a LeRobot dataset feature dictionary. - - This function takes a dictionary describing hardware outputs (like joint states - or camera image shapes) and formats it into the standard LeRobot feature - specification. - - Args: - hw_features (dict): Dictionary mapping feature names to their type (float for - joints) or shape (tuple for images). - prefix (str): The prefix to add to the feature keys (e.g., "observation" - or "action"). - use_video (bool): If True, image features are marked as "video", otherwise "image". - - Returns: - dict: A LeRobot features dictionary. - """ - features = {} - joint_fts = { - key: ftype - for key, ftype in hw_features.items() - if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) - } - cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - - if joint_fts and prefix == ACTION: - features[prefix] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - if joint_fts and prefix == OBS_STR: - features[f"{prefix}.state"] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - for key, shape in cam_fts.items(): - features[f"{prefix}.images.{key}"] = { - "dtype": "video" if use_video else "image", - "shape": shape, - "names": ["height", "width", "channels"], - } - - _validate_feature_names(features) - return features - - -def build_dataset_frame( - ds_features: dict[str, dict], values: dict[str, Any], prefix: str -) -> dict[str, np.ndarray]: - """Construct a single data frame from raw values based on dataset features. - - A "frame" is a dictionary containing all the data for a single timestep, - formatted as numpy arrays according to the feature specification. - - Args: - ds_features (dict): The LeRobot dataset features dictionary. - values (dict): A dictionary of raw values from the hardware/environment. - prefix (str): The prefix to filter features by (e.g., "observation" - or "action"). - - Returns: - dict: A dictionary representing a single frame of data. - """ - frame = {} - for key, ft in ds_features.items(): - if key in DEFAULT_FEATURES or not key.startswith(prefix): - continue - elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: - frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) - elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] - - return frame - - -def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: - """Convert dataset features to policy features. - - This function transforms the dataset's feature specification into a format - that a policy can use, classifying features by type (e.g., visual, state, - action) and ensuring correct shapes (e.g., channel-first for images). - - Args: - features (dict): The LeRobot dataset features dictionary. - - Returns: - dict: A dictionary mapping feature keys to `PolicyFeature` objects. - - Raises: - ValueError: If an image feature does not have a 3D shape. 
- """ - # TODO(aliberts): Implement "type" in dataset features and simplify this - policy_features = {} - for key, ft in features.items(): - shape = ft["shape"] - if ft["dtype"] in ["image", "video"]: - type = FeatureType.VISUAL - if len(shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") - - names = ft["names"] - # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) - shape = (shape[2], shape[0], shape[1]) - elif key == OBS_ENV_STATE: - type = FeatureType.ENV - elif key.startswith(OBS_STR): - type = FeatureType.STATE - elif key.startswith(ACTION): - type = FeatureType.ACTION - else: - continue - - policy_features[key] = PolicyFeature( - type=type, - shape=shape, - ) - - return policy_features - - -def combine_feature_dicts(*dicts: dict) -> dict: - """Merge LeRobot grouped feature dicts. - - - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. - - For others (e.g. `observation.images.*`), the last one wins (if they are identical). - - Args: - *dicts: A variable number of LeRobot feature dictionaries to merge. - - Returns: - dict: A single merged feature dictionary. - - Raises: - ValueError: If there's a dtype mismatch for a feature being merged. - """ - out: dict = {} - for d in dicts: - for key, value in d.items(): - if not isinstance(value, dict): - out[key] = value - continue - - dtype = value.get("dtype") - shape = value.get("shape") - is_vector = ( - dtype not in ("image", "video", "string") - and isinstance(shape, tuple) - and len(shape) == 1 - and "names" in value - ) - - if is_vector: - # Initialize or retrieve the accumulating dict for this feature key - target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) - # Ensure consistent data types across merged entries - if "dtype" in target and dtype != target["dtype"]: - raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") - - # Merge feature names: append only new ones to preserve order without duplicates - seen = set(target["names"]) - for n in value["names"]: - if n not in seen: - target["names"].append(n) - seen.add(n) - # Recompute the shape to reflect the updated number of features - target["shape"] = (len(target["names"]),) - else: - # For images/videos and non-1D entries: override with the latest definition - out[key] = value - return out - - -def create_empty_dataset_info( - codebase_version: str, - fps: int, - features: dict, - use_videos: bool, - robot_type: str | None = None, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, -) -> dict: - """Create a template dictionary for a new dataset's `info.json`. - - Args: - codebase_version (str): The version of the LeRobot codebase. - fps (int): The frames per second of the data. - features (dict): The LeRobot features dictionary for the dataset. - use_videos (bool): Whether the dataset will store videos. - robot_type (str | None): The type of robot used, if any. - - Returns: - dict: A dictionary with the initial dataset metadata. 
- """ - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": 0, - "total_frames": 0, - "total_tasks": 0, - "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, - "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, - "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, - "fps": fps, - "splits": {}, - "data_path": DEFAULT_DATA_PATH, - "video_path": DEFAULT_VIDEO_PATH if use_videos else None, - "features": features, - } - - -def check_delta_timestamps( - delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True -) -> bool: - """Check if delta timestamps are multiples of 1/fps +/- tolerance. - - This ensures that adding these delta timestamps to any existing timestamp in - the dataset will result in a value that aligns with the dataset's frame rate. - - Args: - delta_timestamps (dict): A dictionary where values are lists of time - deltas in seconds. - fps (int): The frames per second of the dataset. - tolerance_s (float): The allowed tolerance in seconds. - raise_value_error (bool): If True, raises an error on failure. - - Returns: - bool: True if all deltas are valid, False otherwise. - - Raises: - ValueError: If any delta is outside the tolerance and `raise_value_error` is True. - """ - outside_tolerance = {} - for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] - if not all(within_tolerance): - outside_tolerance[key] = [ - ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within - ] - - if len(outside_tolerance) > 0: - if raise_value_error: - raise ValueError( - f""" - The following delta_timestamps are found outside of tolerance range. - Please make sure they are multiples of 1/{fps} +/- tolerance and adjust - their values accordingly. - \n{pformat(outside_tolerance)} - """ - ) - return False - - return True - - -def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: - """Convert delta timestamps in seconds to delta indices in frames. - - Args: - delta_timestamps (dict): A dictionary of time deltas in seconds. - fps (int): The frames per second of the dataset. - - Returns: - dict: A dictionary of frame delta indices. - """ - delta_indices = {} - for key, delta_ts in delta_timestamps.items(): - delta_indices[key] = [round(d * fps) for d in delta_ts] - - return delta_indices - - def cycle(iterable: Any) -> Iterator[Any]: """Create a dataloader-safe cyclical iterator. 
@@ -982,229 +374,6 @@ def create_lerobot_dataset_card( ) -def validate_frame(frame: dict, features: dict) -> None: - expected_features = set(features) - set(DEFAULT_FEATURES) - actual_features = set(frame) - - # task is a special required field that's not part of regular features - if "task" not in actual_features: - raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") - - # Remove task from actual_features for regular feature validation - actual_features_for_validation = actual_features - {"task"} - - error_message = validate_features_presence(actual_features_for_validation, expected_features) - - common_features = actual_features_for_validation & expected_features - for name in common_features: - error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) - - if error_message: - raise ValueError(error_message) - - -def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: - """Check for missing or extra features in a frame. - - Args: - actual_features (set[str]): The set of feature names present in the frame. - expected_features (set[str]): The set of feature names expected in the frame. - - Returns: - str: An error message string if there's a mismatch, otherwise an empty string. - """ - error_message = "" - missing_features = expected_features - actual_features - extra_features = actual_features - expected_features - - if missing_features or extra_features: - error_message += "Feature mismatch in `frame` dictionary:\n" - if missing_features: - error_message += f"Missing features: {missing_features}\n" - if extra_features: - error_message += f"Extra features: {extra_features}\n" - - return error_message - - -def validate_feature_dtype_and_shape( - name: str, feature: dict, value: np.ndarray | PILImage.Image | str -) -> str: - """Validate the dtype and shape of a single feature's value. - - Args: - name (str): The name of the feature. - feature (dict): The feature specification from the LeRobot features dictionary. - value: The value of the feature to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - - Raises: - NotImplementedError: If the feature dtype is not supported for validation. - """ - expected_dtype = feature["dtype"] - expected_shape = feature["shape"] - if is_valid_numpy_dtype_string(expected_dtype): - return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) - elif expected_dtype in ["image", "video"]: - return validate_feature_image_or_video(name, expected_shape, value) - elif expected_dtype == "string": - return validate_feature_string(name, value) - else: - raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") - - -def validate_feature_numpy_array( - name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray -) -> str: - """Validate a feature that is expected to be a numpy array. - - Args: - name (str): The name of the feature. - expected_dtype (str): The expected numpy dtype as a string. - expected_shape (list[int]): The expected shape. - value (np.ndarray): The numpy array to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. 
- """ - error_message = "" - if isinstance(value, np.ndarray): - actual_dtype = value.dtype - actual_shape = value.shape - - if actual_dtype != np.dtype(expected_dtype): - error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" - - if actual_shape != expected_shape: - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" - else: - error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_image_or_video( - name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image -) -> str: - """Validate a feature that is expected to be an image or video frame. - - Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. - - Args: - name (str): The name of the feature. - expected_shape (list[str]): The expected shape (C, H, W). - value: The image data to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. - error_message = "" - if isinstance(value, np.ndarray): - actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" - elif isinstance(value, PILImage.Image): - pass - else: - error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_string(name: str, value: str) -> str: - """Validate a feature that is expected to be a string. - - Args: - name (str): The name of the feature. - value (str): The value to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - if not isinstance(value, str): - return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" - return "" - - -def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: - """Validate the episode buffer before it's written to disk. - - Ensures the buffer has the required keys, contains at least one frame, and - has features consistent with the dataset's specification. - - Args: - episode_buffer (dict): The buffer containing data for a single episode. - total_episodes (int): The current total number of episodes in the dataset. - features (dict): The LeRobot features dictionary for the dataset. - - Raises: - ValueError: If the buffer is invalid. - NotImplementedError: If the episode index is manually set and doesn't match. - """ - if "size" not in episode_buffer: - raise ValueError("size key not found in episode_buffer") - - if "task" not in episode_buffer: - raise ValueError("task key not found in episode_buffer") - - if episode_buffer["episode_index"] != total_episodes: - # TODO(aliberts): Add option to use existing episode_index - raise NotImplementedError( - "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes already in the dataset. This is not supported for now." 
- ) - - if episode_buffer["size"] == 0: - raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") - - buffer_keys = set(episode_buffer.keys()) - {"task", "size"} - if not buffer_keys == set(features): - raise ValueError( - f"Features from `episode_buffer` don't match the ones in `features`." - f"In episode_buffer not in features: {buffer_keys - set(features)}" - f"In features not in episode_buffer: {set(features) - buffer_keys}" - ) - - -def to_parquet_with_hf_images( - df: pandas.DataFrame, path: Path, features: datasets.Features | None = None -) -> None: - """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. - This way, it can be loaded by HF dataset and correctly formatted images are returned. - - Args: - df: DataFrame to write to parquet. - path: Path to write the parquet file. - features: Optional HuggingFace Features schema. If provided, ensures image columns - are properly typed as Image() in the parquet schema. - """ - # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only - ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) - ds.to_parquet(path) - - -def item_to_torch(item: dict) -> dict: - """Convert all items in a dictionary to PyTorch tensors where appropriate. - - This function is used to convert an item from a streaming dataset to PyTorch tensors. - - Args: - item (dict): Dictionary of items from a dataset. - - Returns: - dict: Dictionary with all tensor-like items converted to torch.Tensor. - """ - for key, val in item.items(): - if isinstance(val, (np.ndarray | list)) and key not in ["task"]: - # Convert numpy arrays and lists to torch tensors - item[key] = torch.tensor(val) - return item - - def is_float_in_list(target, float_list, threshold=1e-6): return any(abs(target - x) <= threshold for x in float_list) @@ -1216,164 +385,6 @@ def find_float_index(target, float_list, threshold=1e-6): return -1 -class LookBackError(Exception): - """ - Exception raised when trying to look back in the history of a Backtrackable object. - """ - - pass - - -class LookAheadError(Exception): - """ - Exception raised when trying to look ahead in the future of a Backtrackable object. - """ - - pass - - -class Backtrackable[T]: - """ - Wrap any iterator/iterable so you can step back up to `history` items - and look ahead up to `lookahead` items. - - This is useful for streaming datasets where you need to access previous and future items - but can't load the entire dataset into memory. 
- - Example: - ------- - ```python - ds = load_dataset("c4", "en", streaming=True, split="train") - rev = Backtrackable(ds, history=3, lookahead=2) - - x0 = next(rev) # forward - x1 = next(rev) - x2 = next(rev) - - # Look ahead - x3_peek = rev.peek_ahead(1) # next item without moving cursor - x4_peek = rev.peek_ahead(2) # two items ahead - - # Look back - x1_again = rev.peek_back(1) # previous item without moving cursor - x0_again = rev.peek_back(2) # two items back - - # Move backward - x1_back = rev.prev() # back one step - next(rev) # returns x2, continues forward from where we were - ``` - """ - - __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") - - def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): - if history < 1: - raise ValueError("history must be >= 1") - if lookahead <= 0: - raise ValueError("lookahead must be > 0") - - self._source: Iterator[T] = iter(iterable) - self._back_buf: deque[T] = deque(maxlen=history) - self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() - self._cursor: int = 0 - self._history = history - self._lookahead = lookahead - - def __iter__(self) -> "Backtrackable[T]": - return self - - def __next__(self) -> T: - # If we've stepped back, consume from back buffer first - if self._cursor < 0: # -1 means "last item", etc. - self._cursor += 1 - return self._back_buf[self._cursor] - - # If we have items in the ahead buffer, use them first - item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) - - # Add current item to back buffer and reset cursor - self._back_buf.append(item) - self._cursor = 0 - return item - - def prev(self) -> T: - """ - Step one item back in history and return it. - Raises IndexError if already at the oldest buffered item. - """ - if len(self._back_buf) + self._cursor <= 1: - raise LookBackError("At start of history") - - self._cursor -= 1 - return self._back_buf[self._cursor] - - def peek_back(self, n: int = 1) -> T: - """ - Look `n` items back (n=1 == previous item) without moving the cursor. - """ - if n < 0 or n + 1 > len(self._back_buf) + self._cursor: - raise LookBackError("peek_back distance out of range") - - return self._back_buf[self._cursor - (n + 1)] - - def peek_ahead(self, n: int = 1) -> T: - """ - Look `n` items ahead (n=1 == next item) without moving the cursor. - Fills the ahead buffer if necessary. - """ - if n < 1: - raise LookAheadError("peek_ahead distance must be 1 or more") - elif n > self._lookahead: - raise LookAheadError("peek_ahead distance exceeds lookahead limit") - - # Fill ahead buffer if we don't have enough items - while len(self._ahead_buf) < n: - try: - item = next(self._source) - self._ahead_buf.append(item) - - except StopIteration as err: - raise LookAheadError("peek_ahead: not enough items in source") from err - - return self._ahead_buf[n - 1] - - def history(self) -> list[T]: - """ - Return a copy of the buffered history (most recent last). - The list length ≤ `history` argument passed at construction. - """ - if self._cursor == 0: - return list(self._back_buf) - - # When cursor<0, slice so the order remains chronological - return list(self._back_buf)[: self._cursor or None] - - def can_peek_back(self, steps: int = 1) -> bool: - """ - Check if we can go back `steps` items without raising an IndexError. - """ - return steps <= len(self._back_buf) + self._cursor - - def can_peek_ahead(self, steps: int = 1) -> bool: - """ - Check if we can peek ahead `steps` items. 
- This may involve trying to fill the ahead buffer. - """ - if self._lookahead > 0 and steps > self._lookahead: - return False - - # Try to fill ahead buffer to check if we can peek that far - try: - while len(self._ahead_buf) < steps: - if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: - return False - item = next(self._source) - self._ahead_buf.append(item) - return True - except StopIteration: - return False - - def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: """ Safe shards the dataset. @@ -1381,3 +392,52 @@ def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) - shard_idx = min(dataset.num_shards, index + 1) - 1 return dataset.shard(num_shards, index=shard_idx) + + +# --------------------------------------------------------------------------- +# Backward-compatible re-exports: symbols moved to focused submodules. +# Existing ``from lerobot.datasets.utils import `` will keep working. +# --------------------------------------------------------------------------- +from lerobot.datasets.backtracking import Backtrackable, LookAheadError, LookBackError # noqa: E402, F401 +from lerobot.datasets.feature_utils import ( # noqa: E402, F401 + _validate_feature_names, + build_dataset_frame, + check_delta_timestamps, + combine_feature_dicts, + create_empty_dataset_info, + dataset_to_policy_features, + get_delta_indices, + get_hf_features_from_features, + hw_to_dataset_features, + validate_episode_buffer, + validate_feature_dtype_and_shape, + validate_feature_image_or_video, + validate_feature_numpy_array, + validate_feature_string, + validate_features_presence, + validate_frame, +) +from lerobot.datasets.io_utils import ( # noqa: E402, F401 + cast_stats_to_numpy, + embed_images, + get_file_size_in_mb, + get_hf_dataset_size_in_mb, + get_parquet_file_size_in_mb, + get_parquet_num_frames, + hf_transform_to_torch, + item_to_torch, + load_episodes, + load_image_as_numpy, + load_info, + load_json, + load_nested_dataset, + load_stats, + load_subtasks, + load_tasks, + to_parquet_with_hf_images, + write_episodes, + write_info, + write_json, + write_stats, + write_tasks, +) diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 3609bac24..4ac7e001a 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -260,8 +260,8 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "test_aggr") @@ -311,8 +311,8 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): 
mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "small_aggr") @@ -367,8 +367,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") @@ -492,8 +492,8 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): # Load the aggregated dataset with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "image_aggr") @@ -562,8 +562,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_ab") @@ -590,8 +590,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_abc") diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 1de199630..24daed91e 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -67,8 +67,8 @@ def test_delete_single_episode(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -93,8 +93,8 @@ def test_delete_multiple_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + 
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -150,8 +150,8 @@ def test_split_by_episodes(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -193,8 +193,8 @@ def test_split_by_fractions(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -270,8 +270,8 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -310,8 +310,8 @@ def test_add_features_with_values(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -346,8 +346,8 @@ def test_add_features_with_callable(sample_dataset, tmp_path): "reward": (compute_reward, feature_info), } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -401,8 +401,8 @@ def test_modify_features_add_and_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): 
mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -434,8 +434,8 @@ def test_modify_features_only_add(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -457,8 +457,8 @@ def test_modify_features_only_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -494,8 +494,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -521,8 +521,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): def test_remove_multiple_features(sample_dataset, tmp_path): """Test removing multiple features at once.""" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -576,8 +576,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): camera_to_remove = camera_keys[0] with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "without_camera") @@ -598,8 +598,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): def test_complex_workflow_integration(sample_dataset, tmp_path): """Test a complex workflow combining multiple operations.""" with ( - 
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -647,8 +647,8 @@ def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -671,8 +671,8 @@ def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -699,8 +699,8 @@ def test_split_three_ways(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -732,8 +732,8 @@ def test_split_preserves_stats(sample_dataset, tmp_path): splits = {"train": [0, 1, 2], "val": [3, 4]} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -790,8 +790,8 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa datasets.append(dataset) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -832,8 +832,8 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - 
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -866,8 +866,8 @@ def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -890,8 +890,8 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -919,8 +919,8 @@ def test_delete_consecutive_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -943,8 +943,8 @@ def test_delete_first_and_last_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -971,8 +971,8 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -999,8 +999,8 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - 
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -1229,8 +1229,8 @@ def test_convert_image_to_video_dataset(tmp_path): output_dir = tmp_path / "pusht_video" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -1292,8 +1292,8 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path): output_dir = tmp_path / "pusht_video_subset" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index f8dd01fec..f53a16924 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -453,8 +453,8 @@ def lerobot_dataset_metadata_factory( episodes=episodes, ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version_patch, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download_patch, ): mock_get_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index ace0aea49..772588467 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -71,8 +71,8 @@ def test_record_and_resume(tmp_path): cfg.resume = True # Mock the revision to prevent Hub calls during resume with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "record") @@ -115,8 +115,8 @@ def test_record_and_replay(tmp_path): # Mock the revision to prevent Hub calls during replay with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): 
mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
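
The hunks above rely on two mechanical points worth spelling out. First, the re-export block keeps the legacy `from lerobot.datasets.utils import ...` path importable even though the implementations now live in the focused submodules. Second, every test moves its patch target to `lerobot.datasets.dataset_metadata`, because `unittest.mock.patch` has to replace a name in the module where it is looked up, which after this refactor is the new metadata module for `get_safe_version` and `snapshot_download`. A minimal sketch of both points, assuming the patch applies as written (the alias names below are illustrative only and not part of the patch):

```python
from unittest.mock import patch

# Old and new import paths should resolve to the very same class object,
# because utils.py re-exports the symbol instead of defining a copy.
from lerobot.datasets.backtracking import Backtrackable as NewBacktrackable
from lerobot.datasets.utils import Backtrackable as LegacyBacktrackable

assert LegacyBacktrackable is NewBacktrackable

# Mocks must follow the symbol to its new home: patching the old
# lerobot.datasets.lerobot_dataset path would no longer intercept the call.
with patch("lerobot.datasets.dataset_metadata.get_safe_version", return_value="v3.0"):
    ...  # code that loads or builds dataset metadata would run here
```

Because the re-exports bind the same objects, downstream imports keep working unchanged; only the mock patch targets in the tests need to move, which is exactly what the test-file hunks above do.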