From 017ff73fbfe46bf9a673cd9b402988dcb79151f7 Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Mon, 23 Mar 2026 13:57:53 -0700
Subject: [PATCH 01/47] chore(docs): add rename map and empty cam guide (#3065)

* add blog/guide

* add to tree

* chore(docs): rephrase rename_map docs for clarity and simplicity

---------

Co-authored-by: Steven Palma
Co-authored-by: Steven Palma
---
 docs/source/_toctree.yml   |   2 +
 docs/source/rename_map.mdx | 114 +++++++++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+)
 create mode 100644 docs/source/rename_map.mdx

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 1055975d7..09d94d28c 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -19,6 +19,8 @@
     title: Multi GPU training
   - local: peft_training
     title: Training with PEFT (e.g., LoRA)
+  - local: rename_map
+    title: Using Rename Map and Empty Cameras
   title: "Tutorials"
 - sections:
   - local: lerobot-dataset-v3
diff --git a/docs/source/rename_map.mdx b/docs/source/rename_map.mdx
new file mode 100644
index 000000000..6249faaca
--- /dev/null
+++ b/docs/source/rename_map.mdx
@@ -0,0 +1,114 @@
+# Rename Map and Empty Cameras
+
+When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
+
+> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
+
+## Why observation keys don't always match
+
+Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
+
+- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
+- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
+
+Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
+
+## Using the rename map
+
+Pass the mapping as a JSON string on the command line. The convention is always:
+
+```
+--rename_map='{"source_key": "policy_key", ...}'
+```
+
+where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
+
+Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
+
+Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
+
+### Training
+
+Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
+
+```bash
+lerobot-train \
+    --dataset.repo_id=YOUR_DATASET \
+    --output_dir=./outputs/xvla_training \
+    --job_name=xvla_training \
+    --policy.path="lerobot/xvla-base" \
+    --policy.repo_id="HF_USER/xvla-your-robot" \
+    --policy.dtype=bfloat16 \
+    --policy.action_mode=auto \
+    --steps=20000 \
+    --policy.device=cuda \
+    --policy.freeze_vision_encoder=false \
+    --policy.freeze_language_encoder=false \
+    --policy.train_policy_transformer=true \
+    --policy.train_soft_prompts=true \
+    --rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
+```
+
+### Evaluation
+
+Suppose a policy expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), while the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
+
+```bash
+lerobot-eval \
+    --policy.path=lerobot/pi0fast-libero \
+    --env.type=libero \
+    ... \
+    --rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
+```
+
+### Recording
+
+When running inference, `lerobot-record` also supports rename maps, nested under the dataset config:
+
+```bash
+lerobot-record \
+    --policy.path="/smolVLA_finetuned" \
+    ... \
+    --dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
+```
+
+## Alternative: edit the policy config directly
+
+If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
+
+The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
+
+## Empty cameras: fewer views than the policy expects
+
+Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
+
+### How it works
+
+Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
+
+```
+observation.images.empty_camera_0
+observation.images.empty_camera_1
+...
+```
+
+At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
+
+### Example
+
+XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
+
+1. Set `--policy.empty_cameras=1`.
+2. The config adds a third key: `observation.images.empty_camera_0`.
+3. Use the rename map for your two real cameras as usual.
+4. The third slot is masked out — no fake images needed in your dataset.
+
+## Quick reference
+
+| Goal                                      | What to do                                                                  |
+| ----------------------------------------- | --------------------------------------------------------------------------- |
+| Dataset keys ≠ policy keys                | `--rename_map='{"dataset_key": "policy_key", ...}'`                         |
+| Env keys ≠ policy keys (eval)             | `--rename_map='{"env_key": "policy_key", ...}'`                             |
+| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`                  |
+| Fewer cameras than policy expects         | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
+| Avoid passing a rename map                | Edit the policy's `config.json` so its keys match your data source          |

From 123495250b029f5f4bc4d8c91f8ac705a7e18426 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Thu, 26 Mar 2026 19:09:25 +0100
Subject: [PATCH 02/47] refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)

* refactor(dataset): split reader and writer
* chore(dataset): remove proxys
* refactor(dataset): better reader & writer encapsulation
* refactor(datasets): clean API + reduce leaky implementations
* refactor(dataset): API cleaning for writer, reader and meta
* refactor(dataset): expose writer & reader + other minor improvements
* refactor(dataset): improve teardown routine
* refactor(dataset): add hf_dataset property at the facade level
* chore(dataset): add init for datasset module
* docs(dataset): add docstrings for public API of the dataset classes
* tests(dataset): add tests for new classes
* fix(dataset): remove circular dependecy
---
 docs/source/il_robots.mdx                     |    2 +-
 examples/backward_compatibility/replay.py     |    2 +-
 examples/dataset/load_lerobot_dataset.py      |    5 +-
 examples/lekiwi/replay.py                     |    6 +-
 examples/phone_to_so100/replay.py             |    6 +-
 examples/so100_to_so100_EE/replay.py          |    6 +-
 src/lerobot/datasets/__init__.py              |   33 +
 src/lerobot/datasets/dataset_metadata.py      |  172 ++-
 src/lerobot/datasets/dataset_reader.py        |  288 ++++
 src/lerobot/datasets/dataset_tools.py         |    2 +-
 src/lerobot/datasets/dataset_writer.py        |  625 ++++++++
 src/lerobot/datasets/image_writer.py          |    6 +-
 src/lerobot/datasets/lerobot_dataset.py       | 1375 ++++++-----------
 src/lerobot/datasets/multi_dataset.py         |    9 +-
 src/lerobot/datasets/video_utils.py           |   50 +-
 src/lerobot/rl/buffer.py                      |    6 +-
 src/lerobot/rl/gym_manipulator.py             |    3 +-
 src/lerobot/scripts/lerobot_record.py         |   13 +-
 src/lerobot/scripts/lerobot_replay.py         |    6 +-
 .../scripts/lerobot_train_tokenizer.py        |    8 +-
 .../policies/save_policy_to_safetensors.py    |    2 +-
 tests/datasets/test_dataset_metadata.py       |  385 +++++
 tests/datasets/test_dataset_reader.py         |  168 ++
 tests/datasets/test_dataset_writer.py         |  226 +++
 tests/datasets/test_datasets.py               |  170 +-
 tests/datasets/test_image_writer.py           |    8 +-
 tests/datasets/test_lerobot_dataset.py        |  314 ++++
 .../datasets/test_streaming_video_encoder.py  |    4 +-
 28 files changed, 2742 insertions(+), 1158 deletions(-)
 create mode 100644 src/lerobot/datasets/__init__.py
 create mode 100644 src/lerobot/datasets/dataset_reader.py
 create mode 100644 src/lerobot/datasets/dataset_writer.py
 create mode 100644 tests/datasets/test_dataset_metadata.py
 create mode 100644 tests/datasets/test_dataset_reader.py
 create mode 100644 tests/datasets/test_dataset_writer.py
 create mode 100644 tests/datasets/test_lerobot_dataset.py

diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx
index 245634382..8e50a2aec 100644
--- a/docs/source/il_robots.mdx
+++ b/docs/source/il_robots.mdx
@@ -424,7 +424,7 @@ robot = SO100Follower(robot_config)
 robot.connect()
 
 dataset = LeRobotDataset("/", episodes=[episode_idx])
-actions = dataset.hf_dataset.select_columns("action")
+actions = dataset.select_columns("action")
 
 log_say(f"Replaying episode {episode_idx}")
 for idx in range(dataset.num_frames):
diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py
index 13fdfd5f5..e999b5913 100644
--- a/examples/backward_compatibility/replay.py
+++ b/examples/backward_compatibility/replay.py
@@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig):
     robot = make_robot_from_config(cfg.robot)
 
     dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
-    actions = dataset.hf_dataset.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     robot.connect()
     try:
diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py
index ea3516710..44ae21a11 100644
--- a/examples/dataset/load_lerobot_dataset.py
+++ b/examples/dataset/load_lerobot_dataset.py
@@ -88,9 +88,8 @@ def main():
     # The previous metadata class is contained in the 'meta' attribute of the dataset:
     print(dataset.meta)
 
-    # LeRobotDataset actually wraps an underlying Hugging Face dataset
-    # (see https://huggingface.co/docs/datasets for more information).
-    print(dataset.hf_dataset)
+    # You can inspect the dataset using its repr:
+    print(dataset)
 
     # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
     # with the latter, like iterating through the dataset.
diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py
index cf89aea16..0cfd4811c 100644
--- a/examples/lekiwi/replay.py
+++ b/examples/lekiwi/replay.py
@@ -35,9 +35,7 @@ def main():
     # Fetch the dataset to replay
     dataset = LeRobotDataset("/", episodes=[EPISODE_IDX])
 
-    # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
-    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
-    actions = episode_frames.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     # Connect to the robot
     robot.connect()
@@ -48,7 +46,7 @@ def main():
     print("Starting replay loop...")
     log_say(f"Replaying episode {EPISODE_IDX}")
 
-    for idx in range(len(episode_frames)):
+    for idx in range(dataset.num_frames):
         t0 = time.perf_counter()
 
         # Get recorded action from dataset
diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py
index 7b955cdb7..c544614a7 100644
--- a/examples/phone_to_so100/replay.py
+++ b/examples/phone_to_so100/replay.py
@@ -67,9 +67,7 @@ def main():
     # Fetch the dataset to replay
     dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
 
-    # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
-    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
-    actions = episode_frames.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     # Connect to the robot
     robot.connect()
@@ -80,7 +78,7 @@ def main():
     print("Starting replay loop...")
     log_say(f"Replaying episode {EPISODE_IDX}")
 
-    for idx in range(len(episode_frames)):
+    for idx in range(dataset.num_frames):
         t0 = time.perf_counter()
 
         # Get recorded action from dataset
diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py
index b042e02dd..7caa560f0 100644
--- a/examples/so100_to_so100_EE/replay.py
+++ b/examples/so100_to_so100_EE/replay.py
@@ -68,9 +68,7 @@ def main():
     # Fetch the dataset to replay
     dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
 
-    # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
-    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
-    actions = episode_frames.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     # Connect to the robot
     robot.connect()
@@ -81,7 +79,7 @@ def main():
     print("Starting replay loop...")
     log_say(f"Replaying episode {EPISODE_IDX}")
 
-    for idx in range(len(episode_frames)):
+    for idx in range(dataset.num_frames):
         t0 = time.perf_counter()
 
         # Get recorded action from dataset
diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py
new file mode 100644
index 000000000..42c4ab810
--- /dev/null
+++ b/src/lerobot/datasets/__init__.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.multi_dataset import MultiLeRobotDataset +from lerobot.datasets.sampler import EpisodeAwareSampler +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig + +__all__ = [ + "EpisodeAwareSampler", + "ImageTransforms", + "ImageTransformsConfig", + "LeRobotDataset", + "LeRobotDatasetMetadata", + "MultiLeRobotDataset", + "StreamingLeRobotDataset", +] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 560a90a6e..a43ba07b4 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from pathlib import Path import numpy as np @@ -53,6 +54,13 @@ CODEBASE_VERSION = "v3.0" class LeRobotDatasetMetadata: + """Metadata container for a LeRobot dataset. + + Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and + ``episodes/`` parquet files that describe a dataset's structure, content, + and statistics. 
+ """ + def __init__( self, repo_id: str, @@ -61,33 +69,51 @@ class LeRobotDatasetMetadata: force_cache_sync: bool = False, metadata_buffer_size: int = 10, ): + """Load or download metadata for an existing LeRobot dataset. + + Attempts to load metadata from local disk. If files are missing or + ``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from + the Hub. + + Args: + repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``). + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + revision: Git revision (branch, tag, or commit hash). Defaults to + the current codebase version. + force_cache_sync: If ``True``, re-download metadata from the Hub + even when local files exist. + metadata_buffer_size: Number of episode metadata records to buffer + in memory before flushing to parquet. + """ self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None + self._pq_writer = None self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size + self._metadata_buffer: list[dict] = [] + self._metadata_buffer_size = metadata_buffer_size + self._finalized = False try: if force_cache_sync: raise FileNotFoundError - self.load_metadata() + self._load_metadata() except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() + self._pull_from_repo(allow_patterns="meta/") + self._load_metadata() def _flush_metadata_buffer(self) -> None: """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0: return 
combined_dict = {} - for episode_dict in self.metadata_buffer: + for episode_dict in self._metadata_buffer: for key, value in episode_dict.items(): if key not in combined_dict: combined_dict[key] = [] @@ -96,40 +122,50 @@ class LeRobotDatasetMetadata: val = value[0] if isinstance(value, list) else value combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - first_ep = self.metadata_buffer[0] + first_ep = self._metadata_buffer[0] chunk_idx = first_ep["meta/episodes/chunk_index"][0] file_idx = first_ep["meta/episodes/file_index"][0] table = pa.Table.from_pydict(combined_dict) - if not self.writer: + if not self._pq_writer: path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( + self._pq_writer = pq.ParquetWriter( path, schema=table.schema, compression="snappy", use_dictionary=True ) - self.writer.write_table(table) + self._pq_writer.write_table(table) - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() + self.latest_episode = self._metadata_buffer[-1] + self._metadata_buffer.clear() def _close_writer(self) -> None: """Close and cleanup the parquet writer if it exists.""" self._flush_metadata_buffer() - writer = getattr(self, "writer", None) + writer = getattr(self, "_pq_writer", None) if writer is not None: writer.close() - self.writer = None + self._pq_writer = None + + def finalize(self) -> None: + """Flush metadata buffer and close the parquet writer. + + Idempotent — safe to call multiple times. 
+ """ + if getattr(self, "_finalized", False): + return + self._close_writer() + self._finalized = True def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() + """Safety net: flush and close parquet writer on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. + with contextlib.suppress(Exception): + self.finalize() - def load_metadata(self): + def _load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) @@ -137,7 +173,7 @@ class LeRobotDatasetMetadata: self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) - def pull_from_repo( + def _pull_from_repo( self, allow_patterns: list[str] | str | None = None, ignore_patterns: list[str] | str | None = None, @@ -153,6 +189,7 @@ class LeRobotDatasetMetadata: @property def url_root(self) -> str: + """Hugging Face Hub URL root for this dataset.""" return f"hf://datasets/{self.repo_id}" @property @@ -161,6 +198,17 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + """Return the relative parquet file path for the given episode index. + + Args: + ep_index: Zero-based episode index. + + Returns: + Path to the parquet file containing this episode's data. + + Raises: + IndexError: If ``ep_index`` is out of range. + """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -174,6 +222,19 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + """Return the relative video file path for the given episode and video key. + + Args: + ep_index: Zero-based episode index. + vid_key: Feature key identifying the video stream + (e.g. 
``'observation.images.laptop'``). + + Returns: + Path to the video file containing this episode's frames. + + Raises: + IndexError: If ``ep_index`` is out of range. + """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -277,6 +338,17 @@ class LeRobotDatasetMetadata: return None def save_episode_tasks(self, tasks: list[str]): + """Register tasks for the current episode and persist to disk. + + New tasks that do not already exist in the dataset are assigned + sequential task indices and appended to the tasks parquet file. + + Args: + tasks: List of unique task descriptions in natural language. + + Raises: + ValueError: If ``tasks`` contains duplicates. + """ if len(set(tasks)) != len(tasks): raise ValueError(f"Tasks are not unique: {tasks}") @@ -336,8 +408,8 @@ class LeRobotDatasetMetadata: latest_path = ( self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where + if self._pq_writer is None + else self._pq_writer.where ) if Path(latest_path).exists(): @@ -359,10 +431,10 @@ class LeRobotDatasetMetadata: episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] # Add to buffer - self.metadata_buffer.append(episode_dict) + self._metadata_buffer.append(episode_dict) self.latest_episode = episode_dict - if len(self.metadata_buffer) >= self.metadata_buffer_size: + if len(self._metadata_buffer) >= self._metadata_buffer_size: self._flush_metadata_buffer() def save_episode( @@ -373,6 +445,20 @@ class LeRobotDatasetMetadata: episode_stats: dict[str, dict], episode_metadata: dict, ) -> None: + """Persist episode metadata, update dataset info, and aggregate stats. + + Writes the episode's metadata to the buffered parquet writer, increments + the total episode/frame counters in ``info.json``, and merges the + episode's statistics into the running dataset statistics. 
+ + Args: + episode_index: Zero-based index of the episode being saved. + episode_length: Number of frames in this episode. + episode_tasks: List of task descriptions for this episode. + episode_stats: Per-feature statistics for this episode. + episode_metadata: Additional metadata (chunk/file indices, frame + ranges, video timestamps, etc.). + """ episode_dict = { "episode_index": episode_index, "tasks": episode_tasks, @@ -479,7 +565,32 @@ class LeRobotDatasetMetadata: data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" + """Create metadata for a new LeRobot dataset from scratch. + + Initializes the ``info.json`` file on disk with the provided feature + schema and dataset settings. No episode data is written yet. + + Args: + repo_id: Repository identifier (e.g. ``'user/my_dataset'``). + fps: Frames per second used during data collection. + features: Feature specification dict mapping feature names to their + type/shape metadata. + robot_type: Optional robot type string stored in metadata. + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist. + use_videos: If ``True``, visual modalities are encoded as MP4 videos. + metadata_buffer_size: Number of episode metadata records to buffer + before flushing to parquet. + chunks_size: Max number of files per chunk directory. ``None`` uses + the default. + data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the + default. + video_files_size_in_mb: Max video file size in MB. ``None`` uses the + default. + + Returns: + A new :class:`LeRobotDatasetMetadata` instance. 
+ """ obj = cls.__new__(cls) obj.repo_id = repo_id obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id @@ -510,8 +621,9 @@ class LeRobotDatasetMetadata: ) write_json(obj.info, obj.root / INFO_PATH) obj.revision = None - obj.writer = None + obj._pq_writer = None obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size + obj._metadata_buffer = [] + obj._metadata_buffer_size = metadata_buffer_size + obj._finalized = False return obj diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py new file mode 100644 index 000000000..0233a3cf6 --- /dev/null +++ b/src/lerobot/datasets/dataset_reader.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private reader component for LeRobotDataset. 
Handles random-access reading (HF dataset, delta indices, video decoding).""" + +from collections.abc import Callable +from pathlib import Path + +import datasets +import torch + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( + check_delta_timestamps, + get_delta_indices, + get_hf_features_from_features, +) +from lerobot.datasets.io_utils import ( + hf_transform_to_torch, + load_nested_dataset, +) +from lerobot.datasets.video_utils import decode_video_frames + + +class DatasetReader: + """Encapsulates read-side state and methods for LeRobotDataset. + + Owns: hf_dataset, _absolute_to_relative_idx, delta_indices. + """ + + def __init__( + self, + meta: LeRobotDatasetMetadata, + root: Path, + episodes: list[int] | None, + tolerance_s: float, + video_backend: str, + delta_timestamps: dict[str, list[float]] | None, + image_transforms: Callable | None, + ): + """Initialize the reader with metadata, filtering, and transform config. + + The HF dataset is not loaded here — call :meth:`try_load` or + :meth:`load_and_activate` afterward. + + Args: + meta: Dataset metadata instance. + root: Local dataset root directory. + episodes: Optional list of episode indices to select. ``None`` + means all episodes. + tolerance_s: Timestamp synchronization tolerance in seconds. + video_backend: Video decoding backend identifier. + delta_timestamps: Optional dict mapping feature keys to lists of + relative timestamp offsets for temporal context windows. + image_transforms: Optional torchvision v2 transform applied to + visual features. 
+ """ + self._meta = meta + self._root = root + self.episodes = episodes + self._tolerance_s = tolerance_s + self._video_backend = video_backend + self._image_transforms = image_transforms + + self.hf_dataset: datasets.Dataset | None = None + self._absolute_to_relative_idx: dict[int, int] | None = None + + # Setup delta_indices (doesn't depend on hf_dataset) + self.delta_indices = None + if delta_timestamps is not None: + check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) + self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + + def try_load(self) -> bool: + """Attempt to load from local cache. Returns True if data is sufficient.""" + try: + self.hf_dataset = self._load_hf_dataset() + except (FileNotFoundError, NotADirectoryError): + self.hf_dataset = None + return False + if not self._check_cached_episodes_sufficient(): + self.hf_dataset = None + return False + self._build_index_mapping() + return True + + def load_and_activate(self) -> None: + """Load HF dataset from disk and build index mapping. 
Call after data is on disk.""" + self.hf_dataset = self._load_hf_dataset() + self._build_index_mapping() + + def _build_index_mapping(self) -> None: + """Build absolute-to-relative index mapping from loaded hf_dataset.""" + self._absolute_to_relative_idx = None + if self.episodes is not None and self.hf_dataset is not None: + self._absolute_to_relative_idx = { + abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx + for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) + } + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self._meta.total_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + return len(self.episodes) if self.episodes is not None else self._meta.total_episodes + + def _load_hf_dataset(self) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + features = get_hf_features_from_features(self._meta.features) + hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def _check_cached_episodes_sufficient(self) -> bool: + """Check if the cached dataset contains all requested episodes and their video files.""" + if self.hf_dataset is None or len(self.hf_dataset) == 0: + return False + + available_episodes = { + ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx + for ep_idx in self.hf_dataset.unique("episode_index") + } + + if self.episodes is None: + requested_episodes = set(range(self._meta.total_episodes)) + else: + requested_episodes = set(self.episodes) + + if not requested_episodes.issubset(available_episodes): + return False + + if len(self._meta.video_keys) > 0: + for ep_idx in requested_episodes: + for vid_key in self._meta.video_keys: + video_path = self._root / 
self._meta.get_video_file_path(ep_idx, vid_key) + if not video_path.exists(): + return False + + return True + + def get_episodes_file_paths(self) -> list[Path]: + """Return deduplicated file paths (data + video) for selected episodes. + + Used to build the ``allow_patterns`` list for ``snapshot_download``. + """ + episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes)) + fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self._meta.video_keys) > 0: + video_files = [ + str(self._meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self._meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + # episodes are stored in the same files, so we return unique paths only + fpaths = list(set(fpaths)) + return fpaths + + def _get_query_indices( + self, abs_idx: int, ep_idx: int + ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: + """Compute query indices for delta timestamps.""" + ep = self._meta.episodes[ep_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + query_indices = { + key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() + } + padding = { + f"{key}_is_pad": torch.BoolTensor( + [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] + ) + for key, delta_idx in self.delta_indices.items() + } + return query_indices, padding + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self._meta.video_keys: + if query_indices is not None and key in query_indices: + if self._absolute_to_relative_idx is not None: + relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] + timestamps = self.hf_dataset[relative_indices]["timestamp"] + else: + timestamps = 
self.hf_dataset[query_indices[key]]["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + """Query dataset for indices across keys, skipping video keys.""" + result: dict = {} + for key, q_idx in query_indices.items(): + if key in self._meta.video_keys: + continue + relative_indices = ( + q_idx + if self._absolute_to_relative_idx is None + else [self._absolute_to_relative_idx[idx] for idx in q_idx] + ) + try: + result[key] = torch.stack(self.hf_dataset[key][relative_indices]) + except (KeyError, TypeError, IndexError): + result[key] = torch.stack(self.hf_dataset[relative_indices][key]) + return result + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. + """ + ep = self._meta.episodes[ep_idx] + item = {} + for vid_key, query_ts in query_timestamps.items(): + from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] + shifted_query_ts = [from_timestamp + ts for ts in query_ts] + + video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) + item[vid_key] = frames.squeeze(0) + + return item + + def get_item(self, idx) -> dict: + """Core __getitem__ logic. Assumes hf_dataset is loaded. + + ``idx`` is a *relative* index into the (possibly episode-filtered) + HF dataset, **not** the absolute frame index stored in the ``index`` + column. The absolute index is retrieved from the row itself. 
+ """ + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + abs_idx = item["index"].item() + + query_indices = None + if self.delta_indices is not None: + query_indices, padding = self._get_query_indices(abs_idx, ep_idx) + query_result = self._query_hf_dataset(query_indices) + item = {**item, **padding} + for key, val in query_result.items(): + item[key] = val + + if len(self._meta.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(current_ts, query_indices) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + if self._image_transforms is not None: + image_keys = self._meta.camera_keys + for cam in image_keys: + item[cam] = self._image_transforms(item[cam]) + + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self._meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self._meta.features and self._meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name + + return item diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 87cdc18e5..cd2b9fc7c 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -891,7 +891,7 @@ def _copy_and_reindex_episodes_metadata( total_frames += src_episode["length"] - dst_meta._close_writer() + dst_meta.finalize() dst_meta.info.update( { diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py new file mode 100644 index 000000000..b74b18e0c --- /dev/null +++ b/src/lerobot/datasets/dataset_writer.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private writer component for LeRobotDataset. Handles sequential recording (episode buffer, ParquetWriter, image writer, video encoding).""" + +from __future__ import annotations + +import concurrent.futures +import contextlib +import logging +import shutil +import tempfile +from pathlib import Path + +import datasets +import numpy as np +import pandas as pd +import PIL.Image +import pyarrow.parquet as pq +import torch + +from lerobot.datasets.compute_stats import compute_episode_stats +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( + get_hf_features_from_features, + validate_episode_buffer, + validate_frame, +) +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.io_utils import ( + embed_images, + get_file_size_in_mb, + load_episodes, + write_info, +) +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_IMAGE_PATH, + update_chunk_file_indices, +) +from lerobot.datasets.video_utils import ( + StreamingVideoEncoder, + concatenate_video_files, + encode_video_frames, + get_video_duration_in_s, +) + +logger = logging.getLogger(__name__) + + +def _encode_video_worker( + video_key: str, + episode_index: int, + root: Path, + fps: int, + vcodec: str = "libsvtav1", + encoder_threads: int | None = None, +) -> Path: + temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) + img_dir = (root / fpath).parent 
+ encode_video_frames( + img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads + ) + shutil.rmtree(img_dir) + return temp_path + + +class DatasetWriter: + """Encapsulates write-side state and methods for LeRobotDataset. + + Owns: episode_buffer, image_writer, _pq_writer (ParquetWriter), _latest_episode, + _current_file_start_frame, _streaming_encoder, _episodes_since_last_encoding, _recorded_frames. + """ + + def __init__( + self, + meta: LeRobotDatasetMetadata, + root: Path, + vcodec: str, + encoder_threads: int | None, + batch_encoding_size: int, + streaming_encoder: StreamingVideoEncoder | None = None, + initial_frames: int = 0, + ): + """Initialize the writer with metadata, codec, and encoding config. + + Args: + meta: Dataset metadata instance (used for feature schema, chunk + settings, and episode persistence). + root: Local dataset root directory. + vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``). + encoder_threads: Threads per encoder instance. ``None`` for auto. + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. + streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder` + for real-time encoding. ``None`` disables streaming mode. + initial_frames: Starting frame count (non-zero when resuming). 
+ """ + self._meta = meta + self._root = root + self._vcodec = vcodec + self._encoder_threads = encoder_threads + self._batch_encoding_size = batch_encoding_size + self._streaming_encoder = streaming_encoder + + # Writer state + self.image_writer: AsyncImageWriter | None = None + self.episode_buffer: dict = self._create_episode_buffer() + self._pq_writer: pq.ParquetWriter | None = None + self._latest_episode: dict | None = None + self._current_file_start_frame: int | None = None + self._episodes_since_last_encoding: int = 0 + self._recorded_frames: int = initial_frames + self._finalized = False + + def _create_episode_buffer(self, episode_index: int | None = None) -> dict: + current_ep_idx = self._meta.total_episodes if episode_index is None else episode_index + ep_buffer = {} + ep_buffer["size"] = 0 + ep_buffer["task"] = [] + for key in self._meta.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer + + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self._root / fpath + + def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: + return self._get_image_file_path(episode_index, image_key, frame_index=0).parent + + def _save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1 + ) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + write_image(image, fpath, compress_level=compress_level) + else: + self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) + + def add_frame(self, frame: dict) -> None: + """Add a frame to the episode_buffer. 
Images are written to a temporary directory.""" + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self._meta.features) + + if self.episode_buffer is None: + self.episode_buffer = self._create_episode_buffer() + + # Automatically add frame_index and timestamp to episode buffer + frame_index = self.episode_buffer["size"] + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self._meta.fps + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(frame.pop("task")) + + # Start streaming encoder on first frame of episode + if frame_index == 0 and self._streaming_encoder is not None: + self._streaming_encoder.start_episode( + video_keys=list(self._meta.video_keys), + temp_dir=self._root, + ) + + # Add frame features to episode_buffer + for key in frame: + if key not in self._meta.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'." 
+ ) + + if self._meta.features[key]["dtype"] == "video" and self._streaming_encoder is not None: + self._streaming_encoder.feed_frame(key, frame[key]) + self.episode_buffer[key].append(None) + elif self._meta.features[key]["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6 + self._save_image(frame[key], img_path, compress_level) + self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + def save_episode( + self, + episode_data: dict | None = None, + parallel_encoding: bool = True, + ) -> None: + """Save the current episode in self.episode_buffer to disk.""" + episode_buffer = episode_data if episode_data is not None else self.episode_buffer + + validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features) + + # size and task are special cases that won't be added to hf_dataset + episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) + episode_index = episode_buffer["episode_index"] + + episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) + + # Update tasks and task indices with new tasks if any + self._meta.save_episode_tasks(episode_tasks) + + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks]) + + for key, ft in self._meta.features.items(): + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + episode_buffer[key] = 
np.stack(episode_buffer[key]) + + # Wait for image writer to end, so that episode stats over images can be computed + self._wait_image_writer() + + has_video_keys = len(self._meta.video_keys) > 0 + use_streaming = self._streaming_encoder is not None and has_video_keys + use_batched_encoding = self._batch_encoding_size > 1 + + if use_streaming: + non_video_buffer = { + k: v + for k, v in episode_buffer.items() + if self._meta.features.get(k, {}).get("dtype") not in ("video",) + } + non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"} + ep_stats = compute_episode_stats(non_video_buffer, non_video_features) + else: + ep_stats = compute_episode_stats(episode_buffer, self._meta.features) + + ep_metadata = self._save_episode_data(episode_buffer) + + if use_streaming: + streaming_results = self._streaming_encoder.finish_episode() + for video_key in self._meta.video_keys: + temp_path, video_stats = streaming_results[video_key] + if video_stats is not None: + ep_stats[video_key] = { + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + for k, v in video_stats.items() + } + ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) + elif has_video_keys and not use_batched_encoding: + num_cameras = len(self._meta.video_keys) + if parallel_encoding and num_cameras > 1: + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: + future_to_key = { + executor.submit( + _encode_video_worker, + video_key, + episode_index, + self._root, + self._meta.fps, + self._vcodec, + self._encoder_threads, + ): video_key + for video_key in self._meta.video_keys + } + + results = {} + for future in concurrent.futures.as_completed(future_to_key): + video_key = future_to_key[future] + try: + temp_path = future.result() + results[video_key] = temp_path + except Exception as exc: + logger.error(f"Video encoding failed for {video_key}: {exc}") + raise exc + + for video_key in 
self._meta.video_keys:
+                    temp_path = results[video_key]
+                    ep_metadata.update(
+                        self._save_episode_video(video_key, episode_index, temp_path=temp_path)
+                    )
+            else:
+                for video_key in self._meta.video_keys:
+                    ep_metadata.update(self._save_episode_video(video_key, episode_index))
+
+        # `meta.save_episode` needs to be executed after encoding the videos
+        self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
+
+        if has_video_keys and use_batched_encoding:
+            self._episodes_since_last_encoding += 1
+            if self._episodes_since_last_encoding == self._batch_encoding_size:
+                start_ep = self._meta.total_episodes - self._batch_encoding_size
+                end_ep = self._meta.total_episodes
+                self._batch_save_episode_video(start_ep, end_ep)
+                self._episodes_since_last_encoding = 0
+
+        if episode_data is None:
+            self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
+
+    def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
+        """Batch save videos for multiple episodes."""
+        if end_episode is None:
+            end_episode = self._meta.total_episodes
+
+        logger.info(
+            f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
+        )
+
+        chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
+        file_idx = self._meta.episodes[start_episode]["data/file_index"]
+        episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
+            chunk_index=chunk_idx, file_index=file_idx
+        )
+        episode_df = pd.read_parquet(episode_df_path)
+
+        for ep_idx in range(start_episode, end_episode):
+            logger.info(f"Encoding videos for episode {ep_idx}")
+
+            if (
+                self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
+                or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
+            ):
+                episode_df.to_parquet(episode_df_path)
+                self._meta.episodes = load_episodes(self._root)
+
+                chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
+                file_idx = 
self._meta.episodes[ep_idx]["data/file_index"] + episode_df_path = self._root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + video_ep_metadata = {} + for video_key in self._meta.video_keys: + video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) + video_ep_metadata.pop("episode_index") + video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) + + episode_df = episode_df.combine_first(video_ep_df) + episode_df.to_parquet(episode_df_path) + self._meta.episodes = load_episodes(self._root) + + def _save_episode_data(self, episode_buffer: dict) -> dict: + """Save episode data to a parquet file.""" + # Use metadata features as the authoritative schema + hf_features = get_hf_features_from_features(self._meta.features) + ep_dict = {key: episode_buffer[key] for key in hf_features} + ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + ep_num_frames = len(ep_dataset) + + if self._latest_episode is None: + chunk_idx, file_idx = 0, 0 + global_frame_index = 0 + self._current_file_start_frame = 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + latest_ep = self._meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self._current_file_start_frame = global_frame_index + else: + latest_ep = self._latest_episode + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 + + latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = 
global_frame_index - self._current_file_start_frame + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) + + if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self.close_writer() + self._current_file_start_frame = global_frame_index + + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx + + path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + table = ep_dataset.with_format("arrow")[:] + if not self._pq_writer: + self._pq_writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self._pq_writer.write_table(table) + + metadata = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, + } + + self._latest_episode = {**ep_dict, **metadata} + self._recorded_frames += ep_num_frames + + return metadata + + def _save_episode_video( + self, + video_key: str, + episode_index: int, + temp_path: Path | None = None, + ) -> dict: + if temp_path is None: + ep_path = self._encode_temporary_episode_video(video_key, episode_index) + else: + ep_path = temp_path + + ep_size_in_mb = get_file_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + if ( + episode_index == 0 + or self._meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode + ): + chunk_idx, file_idx = 0, 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + 
old_chunk_idx, old_file_idx, self._meta.chunks_size + ) + latest_duration_in_s = 0.0 + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + latest_ep = self._meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] + + latest_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] + + if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + latest_duration_in_s = 0.0 + else: + concatenate_video_files( + [latest_path, ep_path], + latest_path, + ) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + # Update video info (only needed when first episode is encoded) + if episode_index == 0: + self._meta.update_video_info(video_key) + write_info(self._meta.info, self._meta.root) + + metadata = { + "episode_index": episode_index, + f"videos/{video_key}/chunk_index": chunk_idx, + f"videos/{video_key}/file_index": file_idx, + f"videos/{video_key}/from_timestamp": latest_duration_in_s, + f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self, delete_images: bool = True) -> None: + """Discard the current episode buffer and optionally delete temp images. 
+ + Args: + delete_images: If ``True``, remove temporary image directories + written for the current episode. + """ + # Cancel streaming encoder if active + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + if delete_images: + if self.image_writer is not None: + self._wait_image_writer() + episode_index = self.episode_buffer["episode_index"] + # episode_index is `int` when freshly created, but becomes `np.ndarray` after + # save_episode() mutates the buffer. Handle both types here. + if isinstance(episode_index, np.ndarray): + episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] + for cam_key in self._meta.image_keys: + img_dir = self._get_image_file_dir(episode_index, cam_key) + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + self.episode_buffer = self._create_episode_buffer() + + def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: + """Start an :class:`AsyncImageWriter` for background image persistence. + + Args: + num_processes: Number of subprocesses. ``0`` means threads only. + num_threads: Number of threads per process. + """ + if isinstance(self.image_writer, AsyncImageWriter): + logger.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." 
+ ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """Stop the image writer (needed before pickling the dataset for DataLoader).""" + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: + """Use ffmpeg to convert frames stored as png into mp4 videos.""" + return _encode_video_worker( + video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads + ) + + def close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + if self._pq_writer is not None: + self._pq_writer.close() + self._pq_writer = None + + def flush_pending_videos(self) -> None: + """Flush any pending video encoding (streaming or batch). + + For streaming encoding: closes the encoder. + For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet. 
+ """ + if self._streaming_encoder is not None: + self._streaming_encoder.close() + elif self._episodes_since_last_encoding > 0: + start_ep = self._meta.total_episodes - self._episodes_since_last_encoding + end_ep = self._meta.total_episodes + logger.info( + f"Encoding remaining {self._episodes_since_last_encoding} episodes, " + f"from episode {start_ep} to {end_ep - 1}" + ) + self._batch_save_episode_video(start_ep, end_ep) + + def cancel_pending_videos(self) -> None: + """Cancel any in-progress streaming encoding without flushing.""" + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + def cleanup_interrupted_episode(self, episode_index: int) -> None: + """Remove temporary image directories for an interrupted episode.""" + for key in self._meta.video_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=key, frame_index=0 + ).parent + if img_dir.exists(): + logger.debug( + f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}" + ) + shutil.rmtree(img_dir) + + def finalize(self) -> None: + """Flush all pending work and release all resources. + + Idempotent — safe to call multiple times. + """ + if getattr(self, "_finalized", False): + return + # 1. Wait for async image writes to complete, then stop + if self.image_writer is not None: + self.image_writer.wait_until_done() + self.image_writer.stop() + self.image_writer = None + # 2. Flush pending video encoding (streaming or batch) + self.flush_pending_videos() + # 3. Close own parquet writer + self.close_writer() + # 4. Finalize metadata (idempotent) + self._meta.finalize() + self._finalized = True + + def __del__(self): + """Safety net: release resources on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. 
+ with contextlib.suppress(Exception): + self.finalize() diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 9f40394de..603067757 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -32,10 +32,10 @@ def safe_stop_image_writer(func): return func(*args, **kwargs) except Exception as e: dataset = kwargs.get("dataset") - image_writer = getattr(dataset, "image_writer", None) if dataset else None - if image_writer is not None: + writer = getattr(dataset, "writer", None) if dataset else None + if writer is not None and writer.image_writer is not None: logger.warning("Waiting for image writer to terminate...") - image_writer.stop() + writer.image_writer.stop() raise e return wrapper diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 8f0600ba8..cba0c1cba 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -13,57 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import concurrent.futures import contextlib import logging -import shutil -import tempfile from collections.abc import Callable from pathlib import Path import datasets -import numpy as np -import pandas as pd -import PIL.Image -import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import compute_episode_stats from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import ( - check_delta_timestamps, - get_delta_indices, - get_hf_features_from_features, - validate_episode_buffer, - validate_frame, -) -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.io_utils import ( - embed_images, - get_file_size_in_mb, - hf_transform_to_torch, - load_episodes, - load_nested_dataset, - write_info, -) +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.utils import ( - DEFAULT_EPISODES_PATH, - DEFAULT_IMAGE_PATH, create_lerobot_dataset_card, get_safe_version, is_valid_version, - update_chunk_file_indices, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - concatenate_video_files, - decode_video_frames, - encode_video_frames, get_safe_default_codec, - get_video_duration_in_s, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME @@ -71,24 +42,6 @@ from lerobot.utils.constants import HF_LEROBOT_HOME logger = logging.getLogger(__name__) -def _encode_video_worker( - video_key: str, - episode_index: int, - root: Path, - fps: int, - vcodec: str = "libsvtav1", - encoder_threads: int | None = None, -) -> Path: - temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" - fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) - img_dir 
= (root / fpath).parent - encode_video_frames( - img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads - ) - shutil.rmtree(img_dir) - return temp_path - - class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, @@ -136,7 +89,7 @@ class LeRobotDataset(torch.utils.data.Dataset): - stats stores the dataset statistics of the different modalities for normalization - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - data (backed by datasets.Dataset), which reads values from parquet files. - videos (optional) from which frames are loaded to be synchronous with data from parquet files. A typical LeRobotDataset looks like this from its root path: @@ -229,6 +182,11 @@ class LeRobotDataset(torch.utils.data.Dataset): encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc. + + Note: + Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to + ``__init__`` are deprecated. Use :meth:`create` for new datasets or :meth:`resume` + to append to existing ones. 
""" super().__init__() self.repo_id = repo_id @@ -238,21 +196,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else get_safe_default_codec() - self.delta_indices = None - self.batch_encoding_size = batch_encoding_size - self.episodes_since_last_encoding = 0 - self.vcodec = resolve_vcodec(vcodec) + self._video_backend = video_backend if video_backend else get_safe_default_codec() + self._batch_encoding_size = batch_encoding_size + self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - # Unused attributes - self.image_writer = None - self.episode_buffer = None - self.writer = None - self.latest_episode = None - self._current_file_start_frame = None # Track the starting frame index of the current parquet file - self._streaming_encoder = None - self.root.mkdir(exist_ok=True, parents=True) # Load metadata @@ -260,64 +208,270 @@ class LeRobotDataset(torch.utils.data.Dataset): self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - # Track dataset state for efficient incremental writing - self._lazy_loading = False - self._recorded_frames = self.meta.total_frames - self._writer_closed_for_reading = False + # Create reader (hf_dataset loaded below) + self.reader = DatasetReader( + meta=self.meta, + root=self.root, + episodes=episodes, + tolerance_s=tolerance_s, + video_backend=self._video_backend, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + ) # Load actual data - try: - if force_cache_sync: - raise FileNotFoundError - self.hf_dataset = self.load_hf_dataset() - # Check if cached dataset contains all requested episodes - if not self._check_cached_episodes_sufficient(): - raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (FileNotFoundError, NotADirectoryError): + if force_cache_sync or 
not self.reader.try_load(): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) - self.download(download_videos) - self.hf_dataset = self.load_hf_dataset() + self._download(download_videos) + self.reader.load_and_activate() - # Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded - # Build a mapping: absolute_index -> relative_index_in_filtered_dataset - self._absolute_to_relative_idx = None - if self.episodes is not None: - self._absolute_to_relative_idx = { - abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx - for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) - } + # Detect write-mode params for backward compatibility + _has_write_params = streaming_encoding or batch_encoding_size != 1 + if _has_write_params: + import warnings - # Setup delta_indices - if self.delta_timestamps is not None: - check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) - self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - - # Initialize streaming encoder for resumed recording - if streaming_encoding and len(self.meta.video_keys) > 0: - self._streaming_encoder = StreamingVideoEncoder( - fps=self.meta.fps, - vcodec=self.vcodec, - pix_fmt="yuv420p", - g=2, - crf=30, - preset=None, - queue_maxsize=encoder_queue_maxsize, - encoder_threads=encoder_threads, + warnings.warn( + "Passing write-mode parameters (streaming_encoding, batch_encoding_size) to " + "LeRobotDataset.__init__() is deprecated. 
Use LeRobotDataset.resume() instead.", + DeprecationWarning, + stacklevel=2, ) - - def _close_writer(self) -> None: - """Close and cleanup the parquet writer if it exists.""" - writer = getattr(self, "writer", None) - if writer is not None: - writer.close() + streaming_enc = None + if streaming_encoding and len(self.meta.video_keys) > 0: + streaming_enc = self._build_streaming_encoder( + self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads + ) + self.writer = DatasetWriter( + meta=self.meta, + root=self.root, + vcodec=self._vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + initial_frames=self.meta.total_frames, + ) + else: self.writer = None - def __del__(self): + self._is_finalized = False + + # ── Writer guard ────────────────────────────────────────────────── + + def _require_writer(self, method_name: str) -> None: + if self.writer is None: + raise RuntimeError( + f"Cannot call '{method_name}()' on a read-only dataset. " + f"Use LeRobotDataset.create() for new recording or " + f"LeRobotDataset.resume() for resume recording." 
+ ) + + # ── Reader guard ────────────────────────────────────────────────── + + def _ensure_reader(self) -> DatasetReader: + """Lazily create the reader on first access.""" + if self.reader is None: + self.reader = DatasetReader( + meta=self.meta, + root=self.root, + episodes=self.episodes, + tolerance_s=self.tolerance_s, + video_backend=self._video_backend, + delta_timestamps=self.delta_timestamps, + image_transforms=self.image_transforms, + ) + return self.reader + + @staticmethod + def _build_streaming_encoder( + fps: int, + vcodec: str, + encoder_queue_maxsize: int, + encoder_threads: int | None, + ) -> StreamingVideoEncoder: + return StreamingVideoEncoder( + fps=fps, + vcodec=vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + + # ── Metadata properties ─────────────────────────────────────────── + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.meta.fps + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + # Check directly instead of using _ensure_reader(): in write-only mode + # (create/resume) we rely on metadata rather than initializing a reader. + if self.reader is None: + return self.meta.total_frames + return self.reader.num_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + # Check directly instead of using _ensure_reader(): in write-only mode + # (create/resume) we rely on metadata rather than initializing a reader. 
+ if self.reader is None: + return self.meta.total_episodes + return self.reader.num_episodes + + @property + def features(self) -> dict[str, dict]: + """Feature specification dict mapping feature names to their type/shape metadata.""" + return self.meta.features + + @property + def hf_dataset(self) -> datasets.Dataset: + """The underlying Hugging Face Dataset object""" + self.reader = self._ensure_reader() + if self.reader.hf_dataset is None: + self.reader.load_and_activate() + return self.reader.hf_dataset + + # ── Writer-delegated methods ────────────────────────────────────── + + def add_frame(self, frame: dict) -> None: + """Add a single frame to the current episode buffer. + + Delegates to :meth:`DatasetWriter.add_frame`. The dataset must be in + write mode (created via :meth:`create` or :meth:`resume`). + + Args: + frame: Dict mapping feature names to their values for this frame. + Must include a ``'task'`` key. Torch tensors are converted to numpy. + + Raises: + RuntimeError: If the dataset is read-only (no writer). """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + self._require_writer("add_frame") + self.writer.add_frame(frame) + + def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None: + """Save the current episode buffer to disk. + + Delegates to :meth:`DatasetWriter.save_episode`. Encodes videos, writes + parquet data, and updates metadata. The episode buffer is reset afterward. + + Args: + episode_data: Optional pre-built episode dict. If ``None``, uses the + internal episode buffer populated by :meth:`add_frame`. + parallel_encoding: If ``True`` and multiple cameras exist, encode + videos in parallel using a process pool. + + Raises: + RuntimeError: If the dataset is read-only (no writer). 
""" - self._close_writer() + self._require_writer("save_episode") + self.writer.save_episode(episode_data, parallel_encoding) + + def clear_episode_buffer(self, delete_images: bool = True) -> None: + """Discard the current episode buffer without saving. + + Delegates to :meth:`DatasetWriter.clear_episode_buffer`. Useful for + discarding a failed or interrupted recording episode. + + Args: + delete_images: If ``True``, also remove temporary image files written + to disk for the current episode. + + Raises: + RuntimeError: If the dataset is read-only (no writer). + """ + self._require_writer("clear_episode_buffer") + self.writer.clear_episode_buffer(delete_images) + + def has_pending_frames(self) -> bool: + """Check if there are unsaved frames in the episode buffer.""" + if self.writer is None: + return False + return self.writer.episode_buffer is not None and self.writer.episode_buffer["size"] > 0 + + def finalize(self): + """Flush all pending work and close writers. + + Must be called after data collection/conversion, otherwise footer metadata + won't be written to the parquet files and the dataset will be invalid. + + Idempotent — safe to call multiple times. DatasetWriter.__del__ acts as a + safety net if this is never called explicitly. + """ + if self._is_finalized: + return + if self.writer is not None: + self.writer.finalize() + self._is_finalized = True + + # ── Core Dataset methods ────────────────────────────────────────── + + def __len__(self): + """Return the number of frames in the selected episodes.""" + return self.num_frames + + def __getitem__(self, idx) -> dict: + """Return a single frame by index, with all transforms applied. + + Loads the frame from the underlying HF dataset, expands delta-timestamp + windows, decodes video frames, and applies image transforms. Delegates + the core logic to :meth:`DatasetReader.get_item`. + + Args: + idx: Index into the (possibly episode-filtered) dataset. 
+ + Returns: + Dict mapping feature names to their tensor values for this frame. + + Raises: + RuntimeError: If the dataset is currently being recorded and + :meth:`finalize` has not been called yet. + """ + if self.writer is not None and not self._is_finalized: + raise RuntimeError( + "Cannot read from a dataset that is being recorded. Call finalize() first, then access items." + ) + reader = self._ensure_reader() + if reader.hf_dataset is None: + # One-shot load after finalize() + reader.load_and_activate() + return reader.get_item(idx) + + def select_columns(self, column_names: str | list[str]): + """Select specific columns from the underlying dataset. + + Useful for extracting action sequences during replay without loading all features. + Returns a ``datasets.Dataset`` containing only the requested columns. + """ + return self.hf_dataset.select_columns(column_names) + + def get_raw_item(self, idx) -> dict: + """Get a raw frame without image transforms applied. + + Unlike ``__getitem__``, this returns the raw HF dataset row at the given + index with no delta-timestamp expansion, video decoding, or image transforms. + """ + return self.hf_dataset[idx] + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Number of selected episodes: '{self.num_episodes}',\n" + f" Number of selected samples: '{self.num_frames}',\n" + f" Features: '{feature_keys}',\n" + f"}})" + ) + + # ── Hub methods (stay on facade) ────────────────────────────────── def push_to_hub( self, @@ -331,6 +485,27 @@ class LeRobotDataset(torch.utils.data.Dataset): upload_large_folder: bool = False, **card_kwargs, ) -> None: + """Upload the dataset to the Hugging Face Hub. + + Creates the repository if it does not exist, uploads all dataset files + (optionally excluding videos), generates a dataset card, and tags the + revision with the current codebase version. 
+ + Args: + branch: Optional branch to push to. Created from the current + revision if it does not exist. + tags: Optional list of tags for the dataset card. + license: License identifier for the dataset card. + tag_version: If ``True``, create a Git tag for the current codebase + version. + push_videos: If ``False``, skip uploading the ``videos/`` directory. + private: If ``True``, create a private repository. + allow_patterns: Glob pattern(s) restricting which files to upload. + upload_large_folder: If ``True``, use ``upload_large_folder`` instead + of ``upload_folder`` for very large datasets. + **card_kwargs: Additional keyword arguments forwarded to dataset card + creation. + """ ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") @@ -374,795 +549,23 @@ class LeRobotDataset(torch.utils.data.Dataset): hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: + def _download(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version.""" + ignore_patterns = None if download_videos else "videos/" + files = None + if self.episodes is not None: + # Reader is guaranteed to exist here (created in __init__ before _download) + files = self.reader.get_episodes_file_paths() snapshot_download( self.repo_id, repo_type="dataset", revision=self.revision, local_dir=self.root, - allow_patterns=allow_patterns, + allow_patterns=files, ignore_patterns=ignore_patterns, ) - def download(self, download_videos: bool = True) -> None: - """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this - will only download those episodes (selected by their episode_index). 
If 'episodes' is None, the whole - dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present - in 'local_dir', they won't be downloaded again. - """ - # TODO(rcadene, aliberts): implement faster transfer - # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - ignore_patterns = None if download_videos else "videos/" - files = None - if self.episodes is not None: - files = self.get_episodes_file_paths() - self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) - - def get_episodes_file_paths(self) -> list[Path]: - episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) - fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] - if len(self.meta.video_keys) > 0: - video_files = [ - str(self.meta.get_video_file_path(ep_idx, vid_key)) - for vid_key in self.meta.video_keys - for ep_idx in episodes - ] - fpaths += video_files - # episodes are stored in the same files, so we return unique paths only - fpaths = list(set(fpaths)) - return fpaths - - def load_hf_dataset(self) -> datasets.Dataset: - """hf_dataset contains all the observations, states, actions, rewards, etc.""" - features = get_hf_features_from_features(self.features) - hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - def _check_cached_episodes_sufficient(self) -> bool: - """Check if the cached dataset contains all requested episodes and their video files.""" - if self.hf_dataset is None or len(self.hf_dataset) == 0: - return False - - # Get available episode indices from cached dataset - available_episodes = { - ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx - for ep_idx in self.hf_dataset.unique("episode_index") - } - - # Determine requested episodes - if self.episodes is None: - requested_episodes = 
set(range(self.meta.total_episodes)) - else: - requested_episodes = set(self.episodes) - - # Check if all requested episodes are available in cached data - if not requested_episodes.issubset(available_episodes): - return False - - # Check if all required video files exist - if len(self.meta.video_keys) > 0: - for ep_idx in requested_episodes: - for vid_key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - if not video_path.exists(): - return False - - return True - - def create_hf_dataset(self) -> datasets.Dataset: - features = get_hf_features_from_features(self.features) - ft_dict = {col: [] for col in features} - hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.meta.fps - - @property - def num_frames(self) -> int: - """Number of frames in selected episodes. - - Note: When episodes a subset of the full dataset is requested, we must return the - actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. - self.meta.total_frames is the total number of frames in the full dataset. 
- """ - if self.episodes is not None and self.hf_dataset is not None: - return len(self.hf_dataset) - return self.meta.total_frames - - @property - def num_episodes(self) -> int: - """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.meta.total_episodes - - @property - def features(self) -> dict[str, dict]: - return self.meta.features - - @property - def hf_features(self) -> datasets.Features: - """Features of the hf_dataset.""" - if self.hf_dataset is not None: - return self.hf_dataset.features - else: - return get_hf_features_from_features(self.features) - - def _get_query_indices( - self, abs_idx: int, ep_idx: int - ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: - """Compute query indices for delta timestamps. - - Args: - abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). - ep_idx: The episode index. - - Returns: - A tuple of (query_indices, padding) where: - - query_indices: Dict mapping keys to lists of absolute indices to query - - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions - """ - ep = self.meta.episodes[ep_idx] - ep_start = ep["dataset_from_index"] - ep_end = ep["dataset_to_index"] - query_indices = { - key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] - for key, delta_idx in self.delta_indices.items() - } - padding = { # Pad values outside of current episode range - f"{key}_is_pad": torch.BoolTensor( - [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] - ) - for key, delta_idx in self.delta_indices.items() - } - return query_indices, padding - - def _get_query_timestamps( - self, - current_ts: float, - query_indices: dict[str, list[int]] | None = None, - ) -> dict[str, list[float]]: - query_timestamps = {} - for key in self.meta.video_keys: - if query_indices is not None and key in query_indices: - if self._absolute_to_relative_idx is not None: - 
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] - timestamps = self.hf_dataset[relative_indices]["timestamp"] - else: - timestamps = self.hf_dataset[query_indices[key]]["timestamp"] - query_timestamps[key] = torch.stack(timestamps).tolist() - else: - query_timestamps[key] = [current_ts] - - return query_timestamps - - def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: - """ - Query dataset for indices across keys, skipping video keys. - - Tries column-first [key][indices] for speed, falls back to row-first. - - Args: - query_indices: Dict mapping keys to index lists to retrieve - - Returns: - Dict with stacked tensors of queried data (video keys excluded) - """ - result: dict = {} - for key, q_idx in query_indices.items(): - if key in self.meta.video_keys: - continue - # Map absolute indices to relative indices if needed - relative_indices = ( - q_idx - if self._absolute_to_relative_idx is None - else [self._absolute_to_relative_idx[idx] for idx in q_idx] - ) - try: - result[key] = torch.stack(self.hf_dataset[key][relative_indices]) - except (KeyError, TypeError, IndexError): - result[key] = torch.stack(self.hf_dataset[relative_indices][key]) - return result - - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: - """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function - in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a - Segmentation Fault. This probably happens because a memory reference to the video loader is created in - the main process and a subprocess fails to access it. - """ - ep = self.meta.episodes[ep_idx] - item = {} - for vid_key, query_ts in query_timestamps.items(): - # Episodes are stored sequentially on a single mp4 to reduce the number of files. 
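The comment above notes that episodes are stored back-to-back in a single mp4, so a query timestamp expressed relative to the episode must be offset by the episode's start time within that file. A sketch with invented key name and numbers:

```python
# One episode starts 12 s into the shared mp4 file.
episode = {"videos/observation.images.cam/from_timestamp": 12.0}
query_ts = [0.0, 0.5, 1.0]  # seconds relative to the episode start

# Shift into file-level timestamps before decoding, as _query_videos does.
from_timestamp = episode["videos/observation.images.cam/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]

assert shifted_query_ts == [12.0, 12.5, 13.0]
```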
- # Thus we load the start timestamp of the episode on this mp4 and, - # shift the query timestamp accordingly. - from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] - shifted_query_ts = [from_timestamp + ts for ts in query_ts] - - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend) - item[vid_key] = frames.squeeze(0) - - return item - - def _ensure_hf_dataset_loaded(self): - """Lazy load the HF dataset only when needed for reading.""" - if self._lazy_loading or self.hf_dataset is None: - # Close the writer before loading to ensure parquet file is properly finalized - if self.writer is not None: - self._close_writer() - self._writer_closed_for_reading = True - self.hf_dataset = self.load_hf_dataset() - self._lazy_loading = False - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx) -> dict: - # Ensure dataset is loaded when we actually need to read from it - self._ensure_hf_dataset_loaded() - item = self.hf_dataset[idx] - ep_idx = item["episode_index"].item() - # Use the absolute index from the dataset for delta timestamp calculations - abs_idx = item["index"].item() - - query_indices = None - if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(abs_idx, ep_idx) - query_result = self._query_hf_dataset(query_indices) - item = {**item, **padding} - for key, val in query_result.items(): - item[key] = val - - if len(self.meta.video_keys) > 0: - current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) - video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} - - if self.image_transforms is not None: - image_keys = self.meta.camera_keys - for cam in image_keys: - item[cam] = self.image_transforms(item[cam]) - - # Add task as a string - task_idx = item["task_index"].item() - item["task"] = 
self.meta.tasks.iloc[task_idx].name - - # add subtask information if available - if "subtask_index" in self.features and self.meta.subtasks is not None: - subtask_idx = item["subtask_index"].item() - item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name - - return item - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Number of selected episodes: '{self.num_episodes}',\n" - f" Number of selected samples: '{self.num_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - def finalize(self): - """ - Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. - The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) - """ - self._close_writer() - self.meta._close_writer() - if self._streaming_encoder is not None: - self._streaming_encoder.close() - - def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index - ep_buffer = {} - # size and task are special cases that are not in self.features - ep_buffer["size"] = 0 - ep_buffer["task"] = [] - for key in self.features: - ep_buffer[key] = current_ep_idx if key == "episode_index" else [] - return ep_buffer - - # TODO(Steven): consider move this to utils - def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: - fpath = DEFAULT_IMAGE_PATH.format( - image_key=image_key, episode_index=episode_index, frame_index=frame_index - ) - return self.root / fpath - - def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: - return self._get_image_file_path(episode_index, image_key, frame_index=0).parent - - def _save_image( - self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, 
compress_level: int = 1 - ) -> None: - if self.image_writer is None: - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - write_image(image, fpath, compress_level=compress_level) - else: - self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) - - def add_frame(self, frame: dict) -> None: - """ - This function only adds the frame to the episode_buffer. Apart from images — which are written in a - temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method - then needs to be called. - """ - # Convert torch to numpy if needed - for name in frame: - if isinstance(frame[name], torch.Tensor): - frame[name] = frame[name].numpy() - - validate_frame(frame, self.features) - - if self.episode_buffer is None: - self.episode_buffer = self.create_episode_buffer() - - # Automatically add frame_index and timestamp to episode buffer - frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps - self.episode_buffer["frame_index"].append(frame_index) - self.episode_buffer["timestamp"].append(timestamp) - self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing - - # Start streaming encoder on first frame of episode (once, before iterating keys) - if frame_index == 0 and self._streaming_encoder is not None: - self._streaming_encoder.start_episode( - video_keys=list(self.meta.video_keys), - temp_dir=self.root, - ) - - # Add frame features to episode_buffer - for key in frame: - if key not in self.features: - raise ValueError( - f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." 
- ) - - if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None: - self._streaming_encoder.feed_frame(key, frame[key]) - self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet) - elif self.features[key]["dtype"] in ["image", "video"]: - img_path = self._get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index - ) - if frame_index == 0: - img_path.parent.mkdir(parents=True, exist_ok=True) - compress_level = 1 if self.features[key]["dtype"] == "video" else 6 - self._save_image(frame[key], img_path, compress_level) - self.episode_buffer[key].append(str(img_path)) - else: - self.episode_buffer[key].append(frame[key]) - - self.episode_buffer["size"] += 1 - - def save_episode( - self, - episode_data: dict | None = None, - parallel_encoding: bool = True, - ) -> None: - """ - This will save to disk the current episode in self.episode_buffer. - - Video encoding is handled automatically based on batch_encoding_size: - - If batch_encoding_size == 1: Videos are encoded immediately after each episode - - If batch_encoding_size > 1: Videos are encoded in batches. - - Args: - episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will - save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to - None. - parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor. - Defaults to True on Linux, False on macOS as it tends to use all the CPU available already. 
- """ - episode_buffer = episode_data if episode_data is not None else self.episode_buffer - - validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) - - # size and task are special cases that won't be added to hf_dataset - episode_length = episode_buffer.pop("size") - tasks = episode_buffer.pop("task") - episode_tasks = list(set(tasks)) - episode_index = episode_buffer["episode_index"] - - episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) - episode_buffer["episode_index"] = np.full((episode_length,), episode_index) - - # Update tasks and task indices with new tasks if any - self.meta.save_episode_tasks(episode_tasks) - - # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) - - for key, ft in self.features.items(): - # index, episode_index, task_index are already processed above, and image and video - # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: - continue - episode_buffer[key] = np.stack(episode_buffer[key]) - - # Wait for image writer to end, so that episode stats over images can be computed - self._wait_image_writer() - - has_video_keys = len(self.meta.video_keys) > 0 - use_streaming = self._streaming_encoder is not None and has_video_keys - use_batched_encoding = self.batch_encoding_size > 1 - - if use_streaming: - # Compute stats for non-video features only (video stats come from encoder) - non_video_buffer = { - k: v - for k, v in episode_buffer.items() - if self.features.get(k, {}).get("dtype") not in ("video",) - } - non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"} - ep_stats = compute_episode_stats(non_video_buffer, non_video_features) - else: - ep_stats = compute_episode_stats(episode_buffer, self.features) - 
- ep_metadata = self._save_episode_data(episode_buffer) - - if use_streaming: - # Finish streaming encoding and collect results - streaming_results = self._streaming_encoder.finish_episode() - for video_key in self.meta.video_keys: - temp_path, video_stats = streaming_results[video_key] - if video_stats is not None: - # Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1) - ep_stats[video_key] = { - k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) - for k, v in video_stats.items() - } - ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) - elif has_video_keys and not use_batched_encoding: - num_cameras = len(self.meta.video_keys) - if parallel_encoding and num_cameras > 1: - # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: - # num_cameras * num_threads = (total_cpu -1) - with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: - future_to_key = { - executor.submit( - _encode_video_worker, - video_key, - episode_index, - self.root, - self.fps, - self.vcodec, - self._encoder_threads, - ): video_key - for video_key in self.meta.video_keys - } - - results = {} - for future in concurrent.futures.as_completed(future_to_key): - video_key = future_to_key[future] - try: - temp_path = future.result() - results[video_key] = temp_path - except Exception as exc: - logger.error(f"Video encoding failed for {video_key}: {exc}") - raise exc - - for video_key in self.meta.video_keys: - temp_path = results[video_key] - ep_metadata.update( - self._save_episode_video(video_key, episode_index, temp_path=temp_path) - ) - else: - for video_key in self.meta.video_keys: - ep_metadata.update(self._save_episode_video(video_key, episode_index)) - - # `meta.save_episode` need to be executed after encoding the videos - self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - - if 
has_video_keys and use_batched_encoding: - # Check if we should trigger batch encoding - self.episodes_since_last_encoding += 1 - if self.episodes_since_last_encoding == self.batch_encoding_size: - start_ep = self.num_episodes - self.batch_encoding_size - end_ep = self.num_episodes - self._batch_save_episode_video(start_ep, end_ep) - self.episodes_since_last_encoding = 0 - - if not episode_data: - # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) - self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) - - def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: - """ - Batch save videos for multiple episodes. - - Args: - start_episode: Starting episode index (inclusive) - end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode. - """ - if end_episode is None: - end_episode = self.num_episodes - - logger.info( - f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" - ) - - chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"] - file_idx = self.meta.episodes[start_episode]["data/file_index"] - episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - episode_df = pd.read_parquet(episode_df_path) - - for ep_idx in range(start_episode, end_episode): - logger.info(f"Encoding videos for episode {ep_idx}") - - if ( - self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx - or self.meta.episodes[ep_idx]["data/file_index"] != file_idx - ): - # The current episode is in a new chunk or file. - # Save previous episode dataframe and update the Hugging Face dataset by reloading it. 
-                episode_df.to_parquet(episode_df_path)
-                self.meta.episodes = load_episodes(self.root)
-
-                # Load new episode dataframe
-                chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
-                file_idx = self.meta.episodes[ep_idx]["data/file_index"]
-                episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
-                    chunk_index=chunk_idx, file_index=file_idx
-                )
-                episode_df = pd.read_parquet(episode_df_path)
-
-            # Save the current episode's video metadata to the dataframe
-            video_ep_metadata = {}
-            for video_key in self.meta.video_keys:
-                video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
-            video_ep_metadata.pop("episode_index")
-            video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
-                dtype_backend="pyarrow"
-            )  # allows NaN values along with integers
-
-            episode_df = episode_df.combine_first(video_ep_df)
-            episode_df.to_parquet(episode_df_path)
-        self.meta.episodes = load_episodes(self.root)
-
-    def _save_episode_data(self, episode_buffer: dict) -> dict:
-        """Save episode data to a parquet file and update the Hugging Face dataset of frames data.
-
-        This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
-        and saves it as a parquet file. It handles both the creation of new parquet files and the
-        updating of existing ones based on size constraints. After saving the data, it reloads
-        the Hugging Face dataset to ensure it is up-to-date.
-
-        Notes: We both need to update parquet files and HF dataset:
-        - `pandas` loads parquet file in RAM
-        - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
-          or loads directly from pyarrow cache.
-        """
-        # Convert buffer into HF Dataset
-        ep_dict = {key: episode_buffer[key] for key in self.hf_features}
-        ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
-        ep_dataset = embed_images(ep_dataset)
-        ep_num_frames = len(ep_dataset)
-
-        if self.latest_episode is None:
-            # Initialize indices and frame count for a new dataset made of the first episode data
-            chunk_idx, file_idx = 0, 0
-            global_frame_index = 0
-            self._current_file_start_frame = 0
-            # However, if the episodes already exists
-            # It means we are resuming recording, so we need to load the latest episode
-            # Update the indices to avoid overwriting the latest episode
-            if self.meta.episodes is not None and len(self.meta.episodes) > 0:
-                latest_ep = self.meta.episodes[-1]
-                global_frame_index = latest_ep["dataset_to_index"]
-                chunk_idx = latest_ep["data/chunk_index"]
-                file_idx = latest_ep["data/file_index"]
-
-                # When resuming, move to the next file
-                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-                self._current_file_start_frame = global_frame_index
-        else:
-            # Retrieve information from the latest parquet file
-            latest_ep = self.latest_episode
-            chunk_idx = latest_ep["data/chunk_index"]
-            file_idx = latest_ep["data/file_index"]
-            global_frame_index = latest_ep["index"][-1] + 1
-
-            latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
-            latest_size_in_mb = get_file_size_in_mb(latest_path)
-
-            frames_in_current_file = global_frame_index - self._current_file_start_frame
-            av_size_per_frame = (
-                latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
-            )
-
-            # Determine if a new parquet file is needed
-            if (
-                latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb
-                or self._writer_closed_for_reading
-            ):
-                # Size limit is reached or writer was closed for reading, prepare new parquet file
-                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-                self._close_writer()
-                self._writer_closed_for_reading = False
-                self._current_file_start_frame = global_frame_index
-
-        ep_dict["data/chunk_index"] = chunk_idx
-        ep_dict["data/file_index"] = file_idx
-
-        # Write the resulting dataframe from RAM to disk
-        path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
-        path.parent.mkdir(parents=True, exist_ok=True)
-
-        table = ep_dataset.with_format("arrow")[:]
-        if not self.writer:
-            self.writer = pq.ParquetWriter(
-                path, schema=table.schema, compression="snappy", use_dictionary=True
-            )
-        self.writer.write_table(table)
-
-        metadata = {
-            "data/chunk_index": chunk_idx,
-            "data/file_index": file_idx,
-            "dataset_from_index": global_frame_index,
-            "dataset_to_index": global_frame_index + ep_num_frames,
-        }
-
-        # Store metadata with episode data for next episode
-        self.latest_episode = {**ep_dict, **metadata}
-
-        # Mark that the HF dataset needs reloading (lazy loading approach)
-        # This avoids expensive reloading during sequential recording
-        self._lazy_loading = True
-        # Update recorded frames count for efficient length tracking
-        self._recorded_frames += ep_num_frames
-
-        return metadata
-
-    def _save_episode_video(
-        self,
-        video_key: str,
-        episode_index: int,
-        temp_path: Path | None = None,
-    ) -> dict:
-        # Encode episode frames into a temporary video
-        if temp_path is None:
-            ep_path = self._encode_temporary_episode_video(video_key, episode_index)
-        else:
-            ep_path = temp_path
-
-        ep_size_in_mb = get_file_size_in_mb(ep_path)
-        ep_duration_in_s = get_video_duration_in_s(ep_path)
-
-        if (
-            episode_index == 0
-            or self.meta.latest_episode is None
-            or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode
-        ):
-            # Initialize indices for a new dataset made of the first episode data
-            chunk_idx, file_idx = 0, 0
-            if self.meta.episodes is not None and len(self.meta.episodes) > 0:
-                # It means we are resuming recording, so we need to load the latest episode
-                # Update the indices to avoid overwriting the latest episode
-                old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"]
-                old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"]
-                chunk_idx, file_idx = update_chunk_file_indices(
-                    old_chunk_idx, old_file_idx, self.meta.chunks_size
-                )
-            latest_duration_in_s = 0.0
-            new_path = self.root / self.meta.video_path.format(
-                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
-            )
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-            shutil.move(str(ep_path), str(new_path))
-        else:
-            # Retrieve information from the latest updated video file using latest_episode
-            latest_ep = self.meta.latest_episode
-            chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
-            file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
-
-            latest_path = self.root / self.meta.video_path.format(
-                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
-            )
-            latest_size_in_mb = get_file_size_in_mb(latest_path)
-            latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
-
-            if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
-                # Move temporary episode video to a new video file in the dataset
-                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-                new_path = self.root / self.meta.video_path.format(
-                    video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
-                )
-                new_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.move(str(ep_path), str(new_path))
-                latest_duration_in_s = 0.0
-            else:
-                # Update latest video file
-                concatenate_video_files(
-                    [latest_path, ep_path],
-                    latest_path,
-                )
-
-        # Remove temporary directory
-        shutil.rmtree(str(ep_path.parent))
-
-        # Update video info (only needed when first episode is encoded since it reads from episode 0)
-        if episode_index == 0:
-            self.meta.update_video_info(video_key)
-            write_info(self.meta.info, self.meta.root)  # ensure video info always written properly
-
-        metadata = {
-            "episode_index": episode_index,
-            f"videos/{video_key}/chunk_index": chunk_idx,
-            f"videos/{video_key}/file_index": file_idx,
-            f"videos/{video_key}/from_timestamp": latest_duration_in_s,
-            f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
-        }
-        return metadata
-
-    def clear_episode_buffer(self, delete_images: bool = True) -> None:
-        # Cancel streaming encoder if active
-        if self._streaming_encoder is not None:
-            self._streaming_encoder.cancel_episode()
-
-        # Clean up image files for the current episode buffer
-        if delete_images:
-            # Wait for the async image writer to finish
-            if self.image_writer is not None:
-                self._wait_image_writer()
-            episode_index = self.episode_buffer["episode_index"]
-            if isinstance(episode_index, np.ndarray):
-                episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
-            for cam_key in self.meta.image_keys:
-                img_dir = self._get_image_file_dir(episode_index, cam_key)
-                if img_dir.is_dir():
-                    shutil.rmtree(img_dir)
-
-        # Reset the buffer
-        self.episode_buffer = self.create_episode_buffer()
-
-    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
-        if isinstance(self.image_writer, AsyncImageWriter):
-            logger.warning(
-                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
-            )
-
-        self.image_writer = AsyncImageWriter(
-            num_processes=num_processes,
-            num_threads=num_threads,
-        )
-
-    def stop_image_writer(self) -> None:
-        """
-        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
-        remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
-        """
-        if self.image_writer is not None:
-            self.image_writer.stop()
-            self.image_writer = None
-
-    def _wait_image_writer(self) -> None:
-        """Wait for asynchronous image writer to finish."""
-        if self.image_writer is not None:
-            self.image_writer.wait_until_done()
-
-    def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
-        """
-        Use ffmpeg to convert frames stored as png into mp4 videos.
-        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-        since video encoding with ffmpeg is already using multithreading.
-        """
-        return _encode_video_worker(
-            video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
-        )
+    # ── Class constructors ────────────────────────────────────────────
 
     @classmethod
     def create(
@@ -1184,7 +587,42 @@ class LeRobotDataset(torch.utils.data.Dataset):
         encoder_queue_maxsize: int = 30,
         encoder_threads: int | None = None,
     ) -> "LeRobotDataset":
-        """Create a LeRobot Dataset from scratch in order to record data."""
+        """Create a new LeRobotDataset from scratch for recording data.
+
+        Returns a write-mode dataset with an active :class:`DatasetWriter`. Use
+        :meth:`add_frame` / :meth:`save_episode` to populate it, then
+        :meth:`finalize` when done.
+
+        Args:
+            repo_id: Repository identifier, typically ``'{hf_user}/{dataset_name}'``.
+            fps: Frames per second used during data collection.
+            features: Feature specification dict mapping feature names to their
+                type/shape metadata.
+            root: Local directory for dataset storage. Defaults to
+                ``$HF_LEROBOT_HOME/{repo_id}``.
+            robot_type: Optional robot type string stored in metadata.
+            use_videos: If ``True``, visual modalities are stored as MP4 videos.
+                If ``False``, they are stored as images.
+            tolerance_s: Timestamp synchronization tolerance in seconds.
+            image_writer_processes: Number of subprocesses for async image
+                writing. ``0`` means use threads only.
+            image_writer_threads: Number of threads for async image writing.
+            video_backend: Video decoding backend (used when reading back).
+            batch_encoding_size: Number of episodes to accumulate before
+                batch-encoding videos. ``1`` means encode immediately.
+            vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
+                ``'h264'``, ``'hevc'``, ``'auto'``.
+            metadata_buffer_size: Number of episode metadata records to buffer
+                before flushing to parquet.
+            streaming_encoding: If ``True``, encode video frames in real-time
+                during capture instead of writing images first.
+            encoder_queue_maxsize: Max buffered frames per camera when using
+                streaming encoding.
+            encoder_threads: Threads per encoder instance. ``None`` for auto.
+
+        Returns:
+            A new :class:`LeRobotDataset` in write mode.
+        """
         vcodec = resolve_vcodec(vcodec)
         obj = cls.__new__(cls)
         obj.meta = LeRobotDatasetMetadata.create(
@@ -1200,45 +638,126 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.root = obj.meta.root
         obj.revision = None
         obj.tolerance_s = tolerance_s
-        obj.image_writer = None
-        obj.batch_encoding_size = batch_encoding_size
-        obj.episodes_since_last_encoding = 0
-        obj.vcodec = vcodec
-        obj._encoder_threads = encoder_threads
-
-        if image_writer_processes or image_writer_threads:
-            obj.start_image_writer(image_writer_processes, image_writer_threads)
-
-        obj.episode_buffer = obj.create_episode_buffer()
-
-        obj.episodes = None
-        obj.hf_dataset = obj.create_hf_dataset()
         obj.image_transforms = None
         obj.delta_timestamps = None
-        obj.delta_indices = None
-        obj._absolute_to_relative_idx = None
-        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
-        obj.writer = None
-        obj.latest_episode = None
-        obj._current_file_start_frame = None
-        # Initialize tracking for incremental recording
-        obj._lazy_loading = False
-        obj._recorded_frames = 0
-        obj._writer_closed_for_reading = False
+        obj.episodes = None
+        obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj._batch_encoding_size = batch_encoding_size
+        obj._vcodec = vcodec
+        obj._encoder_threads = encoder_threads
 
-        # Initialize streaming encoder
+        # Reader is lazily created on first access (write-only mode)
+        obj.reader = None
+
+        # Create writer
+        streaming_enc = None
         if streaming_encoding and len(obj.meta.video_keys) > 0:
-            obj._streaming_encoder = StreamingVideoEncoder(
-                fps=fps,
-                vcodec=vcodec,
-                pix_fmt="yuv420p",
-                g=2,
-                crf=30,
-                preset=None,
-                queue_maxsize=encoder_queue_maxsize,
-                encoder_threads=encoder_threads,
-            )
-        else:
-            obj._streaming_encoder = None
+            streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
+        obj.writer = DatasetWriter(
+            meta=obj.meta,
+            root=obj.root,
+            vcodec=vcodec,
+            encoder_threads=encoder_threads,
+            batch_encoding_size=batch_encoding_size,
+            streaming_encoder=streaming_enc,
+        )
+
+        if image_writer_processes or image_writer_threads:
+            obj.writer.start_image_writer(image_writer_processes, image_writer_threads)
+
+        obj._is_finalized = False
+
+        return obj
+
+    @classmethod
+    def resume(
+        cls,
+        repo_id: str,
+        root: str | Path | None = None,
+        tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
+        video_backend: str | None = None,
+        batch_encoding_size: int = 1,
+        vcodec: str = "libsvtav1",
+        image_writer_processes: int = 0,
+        image_writer_threads: int = 0,
+        streaming_encoding: bool = False,
+        encoder_queue_maxsize: int = 30,
+        encoder_threads: int | None = None,
+    ) -> "LeRobotDataset":
+        """Resume recording on an existing dataset.
+
+        Loads metadata from an existing dataset (local or Hub) and creates a
+        :class:`DatasetWriter` for appending new episodes. The underlying HF
+        dataset is not loaded until :meth:`finalize` is called and data is
+        subsequently read.
+
+        Args:
+            repo_id: Repository identifier of the existing dataset.
+            root: Local directory of the dataset. Defaults to
+                ``$HF_LEROBOT_HOME/{repo_id}``.
+            tolerance_s: Timestamp synchronization tolerance in seconds.
+            revision: Git revision (branch, tag, or commit hash). Defaults to
+                current codebase version tag.
+            force_cache_sync: If ``True``, re-download metadata from the Hub even
+                if a local cache exists.
+            video_backend: Video decoding backend for reading back data.
+            batch_encoding_size: Number of episodes to accumulate before
+                batch-encoding videos.
+            vcodec: Video codec for encoding.
+            image_writer_processes: Subprocesses for async image writing.
+            image_writer_threads: Threads for async image writing.
+            streaming_encoding: If ``True``, encode video in real-time during
+                capture.
+            encoder_queue_maxsize: Max buffered frames per camera for streaming.
+            encoder_threads: Threads per encoder instance. ``None`` for auto.
+
+        Returns:
+            A :class:`LeRobotDataset` in write mode, ready to append episodes.
+        """
+        vcodec = resolve_vcodec(vcodec)
+        obj = cls.__new__(cls)
+        obj.repo_id = repo_id
+        obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
+        obj.root.mkdir(exist_ok=True, parents=True)
+        obj.revision = revision if revision else CODEBASE_VERSION
+        obj.tolerance_s = tolerance_s
+        obj.image_transforms = None
+        obj.delta_timestamps = None
+        obj.episodes = None
+        obj._video_backend = video_backend if video_backend else get_safe_default_codec()
+        obj._batch_encoding_size = batch_encoding_size
+        obj._vcodec = vcodec
+        obj._encoder_threads = encoder_threads
+
+        # Load metadata
+        obj.meta = LeRobotDatasetMetadata(
+            obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync
+        )
+
+        # Reader is lazily created on first access (write-only mode)
+        obj.reader = None
+
+        # Create writer for appending
+        streaming_enc = None
+        if streaming_encoding and len(obj.meta.video_keys) > 0:
+            streaming_enc = cls._build_streaming_encoder(
+                obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
+            )
+        obj.writer = DatasetWriter(
+            meta=obj.meta,
+            root=obj.root,
+            vcodec=vcodec,
+            encoder_threads=encoder_threads,
+            batch_encoding_size=batch_encoding_size,
+            streaming_encoder=streaming_enc,
+            initial_frames=obj.meta.total_frames,
+        )
+
+        if image_writer_processes or image_writer_threads:
+            obj.writer.start_image_writer(image_writer_processes, image_writer_threads)
+
+        obj._is_finalized = False
 
         return obj
 
diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py
index 917d5c5eb..d16c5bb07 100644
--- a/src/lerobot/datasets/multi_dataset.py
+++ b/src/lerobot/datasets/multi_dataset.py
@@ -22,6 +22,7 @@ import torch
 import torch.utils
 
 from lerobot.datasets.compute_stats import aggregate_stats
+from lerobot.datasets.feature_utils import get_hf_features_from_features
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.video_utils import VideoFrame
 from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -125,7 +126,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
+            features.update(
+                {
+                    k: v
+                    for k, v in get_hf_features_from_features(dataset.features).items()
+                    if k not in self.disabled_features
+                }
+            )
         return features
 
     @property
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index e465b79b4..59c8c7d3e 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -741,6 +741,7 @@ class StreamingVideoEncoder:
         self._video_paths: dict[str, Path] = {}
         self._dropped_frames: dict[str, int] = {}
         self._episode_active = False
+        self._closed = False
 
     def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
         """Start encoder threads for a new episode.
@@ -895,8 +896,11 @@ class StreamingVideoEncoder:
 
     def close(self) -> None:
         """Close the encoder, canceling any in-progress episode."""
+        if self._closed:
+            return
         if self._episode_active:
             self.cancel_episode()
+        self._closed = True
 
     def _cleanup(self) -> None:
         """Clean up queues and thread tracking dicts."""
@@ -1063,43 +1067,19 @@ class VideoEncodingManager:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
+        writer = self.dataset.writer
+        if writer is not None:
+            if exc_type is not None and writer._streaming_encoder is not None:
+                writer.cancel_pending_videos()
-
-        if streaming_encoder is not None:
-            # Handle streaming encoder cleanup
-            if exc_type is not None:
-                streaming_encoder.cancel_episode()
-            streaming_encoder.close()
-        elif self.dataset.episodes_since_last_encoding > 0:
-            # Handle any remaining episodes that haven't been batch encoded
-            if exc_type is not None:
-                logger.info("Exception occurred. Encoding remaining episodes before exit...")
-            else:
-                logger.info("Recording stopped. Encoding remaining episodes...")
+            # finalize() handles flush_pending_videos + parquet + metadata
+            self.dataset.finalize()
-
-            start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
-            end_ep = self.dataset.num_episodes
-            logger.info(
-                f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
-                f"from episode {start_ep} to {end_ep - 1}"
-            )
-            self.dataset._batch_save_episode_video(start_ep, end_ep)
-
-        # Finalize the dataset to properly close all writers
-        self.dataset.finalize()
-
-        # Clean up episode images if recording was interrupted (only for non-streaming mode)
-        if exc_type is not None and streaming_encoder is None:
-            interrupted_episode_index = self.dataset.num_episodes
-            for key in self.dataset.meta.video_keys:
-                img_dir = self.dataset._get_image_file_path(
-                    episode_index=interrupted_episode_index, image_key=key, frame_index=0
-                ).parent
-                if img_dir.exists():
-                    logger.debug(
-                        f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
-                    )
-                    shutil.rmtree(img_dir)
+            # Clean up episode images if recording was interrupted (only for non-streaming mode)
+            if exc_type is not None and writer._streaming_encoder is None:
+                writer.cleanup_interrupted_episode(self.dataset.num_episodes)
+        else:
+            self.dataset.finalize()
 
         # Clean up any remaining images directory if it's empty
         img_dir = self.dataset.root / "images"
diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py
index 81aa29c48..68954162d 100644
--- a/src/lerobot/rl/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -563,7 +563,7 @@ class ReplayBuffer:
         )
 
         # Start writing images if needed
-        lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
+        lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3)
 
         # Convert transitions into episodes and frames
 
@@ -603,10 +603,10 @@ class ReplayBuffer:
                 lerobot_dataset.save_episode()
 
         # Save any remaining frames in the buffer
-        if lerobot_dataset.episode_buffer["size"] > 0:
+        if lerobot_dataset.has_pending_frames():
             lerobot_dataset.save_episode()
 
-        lerobot_dataset.stop_image_writer()
+        lerobot_dataset.writer.stop_image_writer()
 
         lerobot_dataset.finalize()
         return lerobot_dataset
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
index f5fcb7437..bd64d205f 100644
--- a/src/lerobot/rl/gym_manipulator.py
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -752,8 +752,7 @@ def replay_trajectory(
         episodes=[cfg.dataset.replay_episode],
         download_videos=False,
     )
-    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
-    actions = episode_frames.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     _, info = env.reset()
 
diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py
index 819634ba2..ac01c9319 100644
--- a/src/lerobot/scripts/lerobot_record.py
+++ b/src/lerobot/scripts/lerobot_record.py
@@ -468,7 +468,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
 
     try:
         if cfg.resume:
-            dataset = LeRobotDataset(
+            num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
+            dataset = LeRobotDataset.resume(
                 cfg.dataset.repo_id,
                 root=cfg.dataset.root,
                 batch_encoding_size=cfg.dataset.video_encoding_batch_size,
@@ -476,13 +477,11 @@
                 streaming_encoding=cfg.dataset.streaming_encoding,
                 encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
                 encoder_threads=cfg.dataset.encoder_threads,
+                image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
+                image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
+                if num_cameras > 0
+                else 0,
             )
-
-            if hasattr(robot, "cameras") and len(robot.cameras) > 0:
-                dataset.start_image_writer(
-                    num_processes=cfg.dataset.num_image_writer_processes,
-                    num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
-                )
 
             sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
         else:
             # Create empty dataset or load existing saved episodes
diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py
index 7c0b5b96b..09e7d4e8b 100644
--- a/src/lerobot/scripts/lerobot_replay.py
+++ b/src/lerobot/scripts/lerobot_replay.py
@@ -104,15 +104,13 @@ def replay(cfg: ReplayConfig):
     robot = make_robot_from_config(cfg.robot)
     dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
 
-    # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
-    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
-    actions = episode_frames.select_columns(ACTION)
+    actions = dataset.select_columns(ACTION)
 
     robot.connect()
 
     try:
         log_say("Replaying episode", cfg.play_sounds, blocking=True)
-        for idx in range(len(episode_frames)):
+        for idx in range(dataset.num_frames):
             start_episode_t = time.perf_counter()
 
             action_array = actions[idx][ACTION]
diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py
index 807d48333..70185fc51 100644
--- a/src/lerobot/scripts/lerobot_train_tokenizer.py
+++ b/src/lerobot/scripts/lerobot_train_tokenizer.py
@@ -204,15 +204,15 @@ def process_episode(args):
 
     for abs_idx in range(from_idx, to_idx):
         # map absolute index to relative index if needed
-        if dataset._absolute_to_relative_idx is not None:
-            if abs_idx not in dataset._absolute_to_relative_idx:
+        if dataset.reader._absolute_to_relative_idx is not None:
+            if abs_idx not in dataset.reader._absolute_to_relative_idx:
                 # this episode's frames aren't in the filtered dataset
                 return None
-            rel_idx = dataset._absolute_to_relative_idx[abs_idx]
+            rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx]
         else:
             rel_idx = abs_idx
 
-        frame = dataset.hf_dataset[rel_idx]
+        frame = dataset.get_raw_item(rel_idx)
 
         # get state (could be from observation.state or other state key)
         if state_key in frame:
diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py
index 64b125cc9..7359f6169 100644
--- a/tests/artifacts/policies/save_policy_to_safetensors.py
+++ b/tests/artifacts/policies/save_policy_to_safetensors.py
@@ -80,7 +80,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     # HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
     # We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
    # indicating padding (those ending with "_is_pad")
-    dataset.delta_indices = None
+    dataset.reader.delta_indices = None
     batch = next(iter(dataloader))
     obs = {}
     for k in batch:
diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py
new file mode 100644
index 000000000..3f3971e15
--- /dev/null
+++ b/tests/datasets/test_dataset_metadata.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contract tests for LeRobotDatasetMetadata."""
+
+import json
+
+import numpy as np
+import pytest
+
+from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
+from lerobot.datasets.utils import INFO_PATH
+from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
+
+# ── helpers ──────────────────────────────────────────────────────────
+
+SIMPLE_FEATURES = {
+    "state": {"dtype": "float32", "shape": (6,), "names": None},
+    "action": {"dtype": "float32", "shape": (6,), "names": None},
+}
+
+VIDEO_FEATURES = {
+    **SIMPLE_FEATURES,
+    "observation.images.laptop": {
+        "dtype": "video",
+        "shape": (64, 96, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+    },
+}
+
+IMAGE_FEATURES = {
+    **SIMPLE_FEATURES,
+    "observation.images.laptop": {
+        "dtype": "image",
+        "shape": (64, 96, 3),
+        "names": ["height", "width", "channels"],
+        "info": None,
+    },
+}
+
+
+def _make_dummy_stats(features: dict) -> dict:
+    """Create minimal episode stats matching the given features."""
+    stats = {}
+    for key, ft in features.items():
+        if ft["dtype"] in ("image", "video"):
+            stats[key] = {
+                "max": np.ones((3, 1, 1), dtype=np.float32),
+                "mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
+                "min": np.zeros((3, 1, 1), dtype=np.float32),
+                "std": np.full((3, 1, 1), 0.25, dtype=np.float32),
+                "count": np.array([5]),
+            }
+        elif ft["dtype"] in ("float32", "float64", "int64"):
+            stats[key] = {
+                "max": np.ones(ft["shape"], dtype=np.float32),
+                "mean": np.full(ft["shape"], 0.5, dtype=np.float32),
+                "min": np.zeros(ft["shape"], dtype=np.float32),
+                "std": np.full(ft["shape"], 0.25, dtype=np.float32),
+                "count": np.array([5]),
+            }
+    return stats
+
+
+# ── Construction contracts ───────────────────────────────────────────
+
+
+def test_create_produces_valid_info_on_disk(tmp_path):
+    """create() writes info.json and the returned object reflects the provided settings."""
+    root = tmp_path / "new_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/meta",
+        fps=DEFAULT_FPS,
+        features=SIMPLE_FEATURES,
+        robot_type=DUMMY_ROBOT_TYPE,
+        root=root,
+        use_videos=False,
+    )
+
+    # info.json was written to disk
+    assert (root / INFO_PATH).exists()
+    with open(root / INFO_PATH) as f:
+        info_on_disk = json.load(f)
+
+    assert meta.fps == DEFAULT_FPS
+    assert meta.robot_type == DUMMY_ROBOT_TYPE
+    assert "state" in meta.features
+    assert "action" in meta.features
+    assert info_on_disk["fps"] == DEFAULT_FPS
+
+
+def test_create_starts_with_zero_counts(tmp_path):
+    """A freshly created metadata has zero episode/frame/task counts."""
+    root = tmp_path / "empty_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/empty", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+
+    assert meta.total_episodes == 0
+    assert meta.total_frames == 0
+    assert meta.total_tasks == 0
+    assert meta.tasks is None
+    assert meta.episodes is None
+    assert meta.stats is None
+
+
+def test_create_with_videos_sets_video_path(tmp_path):
+    """When features include video-dtype keys, create() produces a non-None video_path."""
+    root = tmp_path / "video_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/video", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=root, use_videos=True
+    )
+
+    assert meta.video_path is not None
+    assert len(meta.video_keys) == 1
+    assert "observation.images.laptop" in meta.video_keys
+
+
+def test_create_without_videos_has_no_video_path(tmp_path):
+    """When use_videos=False and no video features, video_path is None."""
+    root = tmp_path / "no_video"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/novid", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+
+    assert meta.video_path is None
+    assert meta.video_keys == []
+
+
+def test_create_raises_on_existing_directory(tmp_path):
+    """create() raises if root directory already exists."""
+    root = tmp_path / "existing"
+    root.mkdir()
+
+    with pytest.raises(FileExistsError):
+        LeRobotDatasetMetadata.create(
+            repo_id="test/exists", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+        )
+
+
+def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory):
+    """When metadata files exist on disk, __init__ loads them correctly."""
+    root = tmp_path / "load_test"
+    info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False)
+    meta = lerobot_dataset_metadata_factory(root=root, info=info)
+
+    assert meta.total_episodes == 3
+    assert meta.total_frames == 150
+    assert meta.fps == info["fps"]
+
+
+# ── Property accessors ───────────────────────────────────────────────
+
+
+def test_property_accessors_reflect_info(tmp_path):
+    """Properties return values consistent with the info dict."""
+    root = tmp_path / "props_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/props",
+        fps=DEFAULT_FPS,
+        features=IMAGE_FEATURES,
+        robot_type=DUMMY_ROBOT_TYPE,
+        root=root,
+        use_videos=False,
+    )
+
+    assert meta.fps == DEFAULT_FPS
+    assert meta.robot_type == DUMMY_ROBOT_TYPE
+    # shapes should be tuples
+    for _key, shape in meta.shapes.items():
+        assert isinstance(shape, tuple)
+    # image_keys should contain the image feature
+    assert "observation.images.laptop" in meta.image_keys
+    # camera_keys is a superset of image_keys and video_keys
+    assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys)
+
+
+def test_data_path_is_formattable(tmp_path):
+    """data_path contains format placeholders that can be .format()-ed."""
+    root = tmp_path / "fmt_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+
+    formatted = meta.data_path.format(chunk_index=0, file_index=0)
+    assert "chunk" in formatted.lower() or "0" in formatted
+
+
+# ── Task management ──────────────────────────────────────────────────
+
+
+def test_save_episode_tasks_creates_tasks_dataframe(tmp_path):
+    """On a fresh metadata, save_episode_tasks() creates the tasks DataFrame."""
+    root = tmp_path / "task_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/task", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+    assert meta.tasks is None
+
+    meta.save_episode_tasks(["Pick up the cube"])
+
+    assert meta.tasks is not None
+    assert len(meta.tasks) == 1
+    assert "Pick up the cube" in meta.tasks.index
+
+
+def test_save_episode_tasks_is_additive(tmp_path):
+    """New tasks are added; existing tasks keep their original index."""
+    root = tmp_path / "additive_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/add", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+
+    meta.save_episode_tasks(["Task A"])
+    idx_a = meta.get_task_index("Task A")
+
+    meta.save_episode_tasks(["Task A", "Task B"])
+    assert meta.get_task_index("Task A") == idx_a  # unchanged
+    assert meta.get_task_index("Task B") is not None
+    assert len(meta.tasks) == 2
+
+
+def test_get_task_index_returns_none_for_unknown(tmp_path):
+    """get_task_index() returns None for an unknown task."""
+    root = tmp_path / "unknown_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/unknown", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+    meta.save_episode_tasks(["Known task"])
+
+    assert meta.get_task_index("Known task") == 0
+    assert meta.get_task_index("Unknown task") is None
+
+
+def test_save_episode_tasks_rejects_duplicates(tmp_path):
+    """save_episode_tasks() raises ValueError on duplicate task strings."""
+    root = tmp_path / "dup_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/dup", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+
+    with pytest.raises(ValueError):
+        meta.save_episode_tasks(["Same task", "Same task"])
+
+
+# ── Episode saving ───────────────────────────────────────────────────
+
+
+def test_save_episode_increments_counters(tmp_path):
+    """After save_episode(), total_episodes and total_frames increase."""
+    root = tmp_path / "ep_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/ep", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+    meta.save_episode_tasks(["Task 1"])
+    stats = _make_dummy_stats(meta.features)
+
+    meta.save_episode(
+        episode_index=0,
+        episode_length=10,
+        episode_tasks=["Task 1"],
+        episode_stats=stats,
+        episode_metadata={},
+    )
+
+    assert meta.total_episodes == 1
+    assert meta.total_frames == 10
+
+
+def test_save_episode_updates_stats(tmp_path):
+    """After save_episode(), .stats is non-None and has feature keys."""
+    root = tmp_path / "stats_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/stats", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+    meta.save_episode_tasks(["Task 1"])
+    stats = _make_dummy_stats(meta.features)
+
+    meta.save_episode(
+        episode_index=0,
+        episode_length=5,
+        episode_tasks=["Task 1"],
+        episode_stats=stats,
+        episode_metadata={},
+    )
+
+    assert meta.stats is not None
+    # Stats should contain at least the user-defined feature keys
+    for key in SIMPLE_FEATURES:
+        assert key in meta.stats
+
+
+# ── Chunk settings ───────────────────────────────────────────────────
+
+
+def test_update_chunk_settings_persists(tmp_path):
+    """update_chunk_settings() changes values and writes info.json."""
+    root = tmp_path / "chunk_ds"
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
+    )
+    original = meta.get_chunk_settings()
+
+    meta.update_chunk_settings(chunks_size=500)
+    assert meta.chunks_size == 500
+    assert meta.chunks_size != original["chunks_size"] or original["chunks_size"] == 500
+
+    # Verify persisted
+    with open(root / INFO_PATH) as f:
+        info_on_disk = json.load(f)
+    assert info_on_disk["chunks_size"] == 500
+
+
+def test_update_chunk_settings_rejects_non_positive(tmp_path):
+    """update_chunk_settings() raises ValueError for <= 0 values."""
+    root =
tmp_path / "bad_chunk" + meta = LeRobotDatasetMetadata.create( + repo_id="test/bad", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + with pytest.raises(ValueError): + meta.update_chunk_settings(chunks_size=0) + with pytest.raises(ValueError): + meta.update_chunk_settings(data_files_size_in_mb=-1) + + +# ── Finalization ───────────────────────────────────────────────────── + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() multiple times does not raise.""" + root = tmp_path / "fin_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/fin", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + meta.finalize() + meta.finalize() # second call should not raise + + +def test_finalize_flushes_buffered_metadata(tmp_path): + """Episodes saved before finalize() are written to parquet.""" + root = tmp_path / "flush_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/flush", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=root, + use_videos=False, + metadata_buffer_size=100, # large buffer so nothing auto-flushes + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + # Save a few episodes (won't auto-flush since buffer_size=100) + for i in range(3): + meta.save_episode( + episode_index=i, + episode_length=5, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + # Before finalize, the parquet might not exist yet + meta.finalize() + + # After finalize, episodes parquet should exist + episodes_dir = root / "meta" / "episodes" + assert episodes_dir.exists() + parquet_files = list(episodes_dir.rglob("*.parquet")) + assert len(parquet_files) > 0 diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py new file mode 100644 index 000000000..4c8a8b23f --- /dev/null +++ b/tests/datasets/test_dataset_reader.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetReader.""" + +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.video_utils import get_safe_default_codec + +# ── Loading ────────────────────────────────────────────────────────── + + +def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory): + """Given a fully written dataset, try_load() returns True.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + reader = DatasetReader( + meta=dataset.meta, + root=dataset.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, + ) + assert reader.try_load() is True + assert reader.hf_dataset is not None + + +def test_try_load_returns_false_when_no_data(tmp_path): + """When only metadata exists (no data/ parquets), try_load() returns False.""" + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + + root = tmp_path / "meta_only" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + meta = LeRobotDatasetMetadata.create( + repo_id="test/meta_only", fps=30, features=features, root=root, use_videos=False + ) + + reader = DatasetReader( + meta=meta, + root=meta.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, 
+ ) + assert reader.try_load() is False + assert reader.hf_dataset is None + + +# ── Counts ─────────────────────────────────────────────────────────── + + +def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_frames equals total_frames.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_frames == dataset.meta.total_frames + + +def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_episodes equals total_episodes.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_episodes == dataset.meta.total_episodes + + +def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory): + """When filtering to a subset, only those episodes' frames are counted.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=5, total_frames=100, episodes=[0, 2], use_videos=False + ) + # Filtered frames should be less than total + assert dataset.reader.num_frames <= dataset.meta.total_frames + assert dataset.reader.num_episodes == 2 + + +# ── get_item ───────────────────────────────────────────────────────── + + +def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory): + """get_item(0) returns a dict with expected keys.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + item = dataset.reader.get_item(0) + + # Standard keys that must always be present + for key in ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]: + assert key in item, f"Missing key: {key}" + + +def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory): + """get_item() returns correct index and episode_index.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", 
total_episodes=2, total_frames=20, use_videos=False + ) + item_0 = dataset.reader.get_item(0) + + assert item_0["index"].item() == 0 + assert item_0["episode_index"].item() == 0 + + +# ── Transforms ─────────────────────────────────────────────────────── + + +def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory): + """When image_transforms is provided, get_item() applies it to camera keys.""" + transform_called = {"count": 0} + + def sentinel_transform(img): + transform_called["count"] += 1 + return img + + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", + total_episodes=1, + total_frames=5, + use_videos=False, + image_transforms=sentinel_transform, + ) + item = dataset[0] # noqa: F841 + + # Should have been called once per camera key per frame + num_cameras = len(dataset.meta.camera_keys) + if num_cameras > 0: + assert transform_called["count"] >= 1 + + +# ── File paths ─────────────────────────────────────────────────────── + + +def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory): + """get_episodes_file_paths() returns paths including data/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + paths = dataset.reader.get_episodes_file_paths() + + assert len(paths) > 0 + assert any("data/" in str(p) for p in paths) + + +def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory): + """When dataset has video keys, file paths include video/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=True + ) + + if len(dataset.meta.video_keys) > 0: + paths = dataset.reader.get_episodes_file_paths() + assert any("video" in str(p).lower() for p in paths) diff --git a/tests/datasets/test_dataset_writer.py b/tests/datasets/test_dataset_writer.py new file mode 100644 index 000000000..8c6ee68bd --- /dev/null +++ b/tests/datasets/test_dataset_writer.py 
@@ -0,0 +1,226 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetWriter.""" + +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image + +from lerobot.datasets.dataset_writer import _encode_video_worker +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import DEFAULT_IMAGE_PATH +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (6,), "names": None}, + "action": {"dtype": "float32", "shape": (6,), "names": None}, +} + + +def _make_frame(features: dict, task: str = "Dummy task") -> dict: + """Build a valid frame dict for the given features.""" + frame = {"task": task} + for key, ft in features.items(): + if ft["dtype"] in ("image", "video"): + frame[key] = np.random.randint(0, 256, size=ft["shape"], dtype=np.uint8) + elif ft["dtype"] in ("float32", "float64"): + frame[key] = torch.randn(ft["shape"]) + elif ft["dtype"] == "int64": + frame[key] = torch.zeros(ft["shape"], dtype=torch.int64) + return frame + + +# ── Existing encode_video_worker tests ─────────────────────────────── + + +def test_encode_video_worker_forwards_vcodec(tmp_path): + """_encode_video_worker correctly forwards the vcodec parameter.""" + video_key = "observation.images.laptop" + fpath = 
DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264") + + assert captured_kwargs["vcodec"] == "h264" + + +def test_encode_video_worker_default_vcodec(tmp_path): + """_encode_video_worker uses libsvtav1 as the default codec.""" + video_key = "observation.images.laptop" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30) + + assert captured_kwargs["vcodec"] == "libsvtav1" + + +# ── add_frame contracts ────────────────────────────────────────────── + + +def test_add_frame_increments_buffer_size(tmp_path): + """Each add_frame() call increases episode_buffer['size'] by 1.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.writer.episode_buffer["size"] == 0 + + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + 
dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 2 + + +def test_add_frame_rejects_missing_feature(tmp_path): + """add_frame() raises ValueError when a required feature is missing.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + with pytest.raises(ValueError, match="Missing features"): + dataset.add_frame({"task": "Dummy task", "state": torch.randn(6)}) + # missing 'action' + + +# ── save_episode contracts ─────────────────────────────────────────── + + +def test_save_episode_writes_parquet(tmp_path): + """After save_episode(), at least one .parquet file exists under data/.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + parquet_files = list((tmp_path / "ds" / "data").rglob("*.parquet")) + assert len(parquet_files) > 0 + + +def test_save_episode_updates_counters(tmp_path): + """After save_episode(), metadata counters are updated.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(5): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == 5 + + +def test_save_episode_resets_buffer(tmp_path): + """After save_episode(), the episode buffer is reset.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_save_multiple_episodes(tmp_path): + """Recording 3 episodes results in correct total counts.""" + dataset = 
LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + total_frames = 0 + for ep in range(3): + n_frames = ep + 2 # 2, 3, 4 + for _ in range(n_frames): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + total_frames += n_frames + + assert dataset.meta.total_episodes == 3 + assert dataset.meta.total_frames == total_frames + + +# ── clear / lifecycle ──────────────────────────────────────────────── + + +def test_clear_resets_buffer(tmp_path): + """clear_episode_buffer() resets the buffer size to 0.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + dataset.clear_episode_buffer() + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() twice does not raise.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + dataset.finalize() + dataset.finalize() # second call should not raise + + +def test_finalize_then_read_roundtrip(tmp_path): + """Write data, finalize, re-open, and verify data matches.""" + root = tmp_path / "roundtrip" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=features, root=root) + + # Record known values + known_states = [] + for i in range(5): + state = torch.tensor([float(i), float(i * 10)]) + known_states.append(state) + dataset.add_frame({"task": "Test task", "state": state}) + dataset.save_episode() + dataset.finalize() + + # Read back + for i in range(5): + item = dataset[i] + assert torch.allclose(item["state"], known_states[i], 
atol=1e-5) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 67878d8f6..b2518149f 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.io_utils import hf_transform_to_torch -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - _encode_video_worker, -) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, @@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated - objects have the same sets of attributes defined. + objects have the same sets of facade-level attributes defined. 
""" # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) @@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) + # Facade-level attributes should match between __init__ and create() init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) @@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert len(dataset) == 1 assert dataset[0]["task"] == "Dummy task" @@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2]) @@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) 
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].ndim == 0 @@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["caption"] == "Dummy caption" @@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert 
dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_img.save_episode() - img_dir = ds_img._get_image_file_dir(0, image_key) + img_dir = ds_img.writer._get_image_file_dir(0, image_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" @@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): } ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) - ds_vid.batch_encoding_size = 1 + ds_vid.writer._batch_encoding_size = 1 ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_vid.save_episode() - vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key) assert not vid_img_dir.exists(), ( "Temporary image directory should be removed when batch_encoding_size == 1" ) @@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): } ) ds_mixed.save_episode() - img_dir = ds_mixed._get_image_file_dir(0, image_key) - vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + img_dir = ds_mixed.writer._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" assert vid_img_dir.exists(), ( "Temporary image directory should not be removed for video features when batch_encoding_size == 
2" @@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Test hf_dataset is None - dataset.hf_dataset = None - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset = None + assert dataset.reader._check_cached_episodes_sufficient() is False # Test hf_dataset is empty import datasets empty_features = get_hf_features_from_features(dataset.features) - dataset.hf_dataset = datasets.Dataset.from_dict( + dataset.reader.hf_dataset = datasets.Dataset.from_dict( {key: [] for key in empty_features}, features=empty_features ) - dataset.hf_dataset.set_transform(hf_transform_to_torch) - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset.set_transform(hf_transform_to_torch) + assert dataset.reader._check_cached_episodes_sufficient() is False # Restore the original dataset for remaining tests - dataset.hf_dataset = dataset.load_hf_dataset() + dataset.reader.hf_dataset = dataset.reader._load_hf_dataset() # Test all episodes requested (self.episodes = None) and all are available - dataset.episodes = None - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = None + assert dataset.reader._check_cached_episodes_sufficient() is True # Test specific episodes requested that are all available - dataset.episodes = [0, 2, 4] - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = [0, 2, 4] + assert dataset.reader._check_cached_episodes_sufficient() is True # Test request episodes that don't exist in the cached dataset # Create a dataset with only episodes 0, 1, 2 @@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Request episodes that include non-existent ones - limited_dataset.episodes = [0, 1, 2, 3, 4] - assert limited_dataset._check_cached_episodes_sufficient() is False + limited_dataset.reader.episodes = [0, 1, 2, 3, 4] + assert 
limited_dataset.reader._check_cached_episodes_sufficient() is False

     # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
     # First create the full dataset structure
@@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
             filtered_data[key] = filtered_values

-    sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
+    sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict(
         filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
     )
-    sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
+    sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)

     # Test requesting all episodes when only some are cached
-    sparse_dataset.episodes = None
-    assert sparse_dataset._check_cached_episodes_sufficient() is False
+    sparse_dataset.reader.episodes = None
+    assert sparse_dataset.reader._check_cached_episodes_sufficient() is False

     # Test requesting only the available episodes
-    sparse_dataset.episodes = [0, 2, 4]
-    assert sparse_dataset._check_cached_episodes_sufficient() is True
+    sparse_dataset.reader.episodes = [0, 2, 4]
+    assert sparse_dataset.reader._check_cached_episodes_sufficient() is True

     # Test requesting a mix of available and unavailable episodes
-    sparse_dataset.episodes = [0, 1, 2]
-    assert sparse_dataset._check_cached_episodes_sufficient() is False
+    sparse_dataset.reader.episodes = [0, 1, 2]
+    assert sparse_dataset.reader._check_cached_episodes_sufficient() is False


 def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
@@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
     del dataset_verify

     # Phase 3: Resume recording - add more episodes
-    dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
+    dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0")
     assert dataset_resumed.meta.total_episodes == initial_episodes
     assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
-    assert dataset_resumed.latest_episode is None  # Not recording yet
-    assert dataset_resumed.writer is None
-    assert dataset_resumed.meta.writer is None
+    assert dataset_resumed.writer._latest_episode is None  # Not recording yet
+    assert dataset_resumed.writer._pq_writer is None
+    assert dataset_resumed.meta._pq_writer is None

     additional_episodes = 2
     for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
@@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
     dataset.meta.update_chunk_settings(data_files_size_in_mb=100)

-    assert dataset._current_file_start_frame is None
+    assert dataset.writer._current_file_start_frame is None

     frames_per_episode = 10
     for _ in range(frames_per_episode):
@@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
     )
     dataset.save_episode()

-    assert dataset._current_file_start_frame == 0
+    assert dataset.writer._current_file_start_frame == 0
     assert dataset.meta.total_episodes == 1
     assert dataset.meta.total_frames == frames_per_episode

@@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
     )
     dataset.save_episode()

-    assert dataset._current_file_start_frame == 0
+    assert dataset.writer._current_file_start_frame == 0
     assert dataset.meta.total_episodes == 2
     assert dataset.meta.total_frames == 2 * frames_per_episode

-    ep1_chunk = dataset.latest_episode["data/chunk_index"]
-    ep1_file = dataset.latest_episode["data/file_index"]
+    ep1_chunk = dataset.writer._latest_episode["data/chunk_index"]
+    ep1_file = dataset.writer._latest_episode["data/file_index"]
     assert ep1_chunk == 0
     assert ep1_file == 0

@@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
     )
     dataset.save_episode()

-    assert dataset._current_file_start_frame == 0
+    assert dataset.writer._current_file_start_frame == 0
     assert dataset.meta.total_episodes == 3
     assert dataset.meta.total_frames == 3 * frames_per_episode

-    ep2_chunk = dataset.latest_episode["data/chunk_index"]
-    ep2_file = dataset.latest_episode["data/file_index"]
+    ep2_chunk = dataset.writer._latest_episode["data/chunk_index"]
+    ep2_file = dataset.writer._latest_episode["data/file_index"]
     assert ep2_chunk == 0
     assert ep2_file == 0

@@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
         assert frame["episode_index"].item() == expected_ep


-def test_encode_video_worker_forwards_vcodec(tmp_path):
-    """Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
-    from unittest.mock import patch
-
-    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
-
-    # Create the expected directory structure
-    video_key = "observation.images.laptop"
-    episode_index = 0
-    frame_index = 0
-
-    fpath = DEFAULT_IMAGE_PATH.format(
-        image_key=video_key, episode_index=episode_index, frame_index=frame_index
-    )
-    img_dir = tmp_path / Path(fpath).parent
-    img_dir.mkdir(parents=True, exist_ok=True)
-
-    # Create a dummy image file
-    dummy_img = Image.new("RGB", (64, 64), color="red")
-    dummy_img.save(img_dir / "frame-000000.png")
-
-    # Track what vcodec was passed to encode_video_frames
-    captured_kwargs = {}
-
-    def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
-        captured_kwargs.update(kwargs)
-        # Create a dummy output file so the worker doesn't fail
-        Path(video_path).parent.mkdir(parents=True, exist_ok=True)
-        Path(video_path).touch()
-
-    with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
-        # Test with h264 codec
-        _encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
-
-    assert "vcodec" in captured_kwargs
-    assert captured_kwargs["vcodec"] == "h264"
-
-
-def test_encode_video_worker_default_vcodec(tmp_path):
-    """Test that _encode_video_worker uses libsvtav1 as the default codec."""
-    from unittest.mock import patch
-
-    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
-
-    # Create the expected directory structure
-    video_key = "observation.images.laptop"
-    episode_index = 0
-    frame_index = 0
-
-    fpath = DEFAULT_IMAGE_PATH.format(
-        image_key=video_key, episode_index=episode_index, frame_index=frame_index
-    )
-    img_dir = tmp_path / Path(fpath).parent
-    img_dir.mkdir(parents=True, exist_ok=True)
-
-    # Create a dummy image file
-    dummy_img = Image.new("RGB", (64, 64), color="red")
-    dummy_img.save(img_dir / "frame-000000.png")
-
-    # Track what vcodec was passed to encode_video_frames
-    captured_kwargs = {}
-
-    def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
-        captured_kwargs.update(kwargs)
-        # Create a dummy output file so the worker doesn't fail
-        Path(video_path).parent.mkdir(parents=True, exist_ok=True)
-        Path(video_path).touch()
-
-    with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
-        # Test with default codec (no vcodec specified)
-        _encode_video_worker(video_key, episode_index, tmp_path, fps=30)
-
-    assert "vcodec" in captured_kwargs
-    assert captured_kwargs["vcodec"] == "libsvtav1"
-
-
 def test_lerobot_dataset_vcodec_validation():
     """Test that LeRobotDataset validates the vcodec parameter."""
     # Test that invalid vcodec raises ValueError
diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py
index e02755171..55419473f 100644
--- a/tests/datasets/test_image_writer.py
+++ b/tests/datasets/test_image_writer.py
@@ -352,10 +352,14 @@ def test_with_different_image_formats(tmp_path, img_array_factory):

 def test_safe_stop_image_writer_decorator():
-    class MockDataset:
+    class MockWriter:
         def __init__(self):
             self.image_writer = MagicMock(spec=AsyncImageWriter)

+    class MockDataset:
+        def __init__(self):
+            self.writer = MockWriter()
+
     @safe_stop_image_writer
     def function_that_raises_exception(dataset=None):
         raise Exception("Test exception")
@@ -366,7 +370,7 @@ def test_safe_stop_image_writer_decorator():
         function_that_raises_exception(dataset=dataset)

     assert str(exc_info.value) == "Test exception"
-    dataset.image_writer.stop.assert_called_once()
+    dataset.writer.image_writer.stop.assert_called_once()


 def test_main_process_time(tmp_path, img_tensor_factory):
diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py
new file mode 100644
index 000000000..d7ce54a15
--- /dev/null
+++ b/tests/datasets/test_lerobot_dataset.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contract tests for the LeRobotDataset facade.
+
+Tests focus on mode contracts (read-only, write-only, resume), guards,
+property delegation, and the full create-record-finalize-read lifecycle.
+"""
+
+import pytest
+import torch
+
+from lerobot.datasets.dataset_reader import DatasetReader
+from lerobot.datasets.dataset_writer import DatasetWriter
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
+
+SIMPLE_FEATURES = {
+    "state": {"dtype": "float32", "shape": (2,), "names": None},
+}
+
+
+def _make_frame(task: str = "Dummy task") -> dict:
+    return {"task": task, "state": torch.randn(2)}
+
+
+# ── Read-only mode (via __init__) ────────────────────────────────────
+
+
+def test_init_creates_reader_no_writer(tmp_path, lerobot_dataset_factory):
+    """__init__() sets reader to a DatasetReader and writer to None."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
+    )
+    assert isinstance(dataset.reader, DatasetReader)
+    assert dataset.writer is None
+
+
+def test_init_loads_data(tmp_path, lerobot_dataset_factory):
+    """After __init__(), the dataset has data and len > 0."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
+    )
+    assert len(dataset) > 0
+
+
+def test_getitem_works_in_read_mode(tmp_path, lerobot_dataset_factory):
+    """dataset[0] returns a dict with expected keys in read-only mode."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
+    )
+    item = dataset[0]
+    assert isinstance(item, dict)
+    assert "index" in item
+    assert "task" in item
+
+
+def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
+    """len(dataset) equals dataset.num_frames."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=2, total_frames=30, use_videos=False
+    )
+    assert len(dataset) == dataset.num_frames
+
+
+# ── Write-only mode (via create()) ──────────────────────────────────
+
+
+def test_create_sets_writer_no_reader(tmp_path):
+    """create() sets writer to a DatasetWriter and reader to None."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    assert isinstance(dataset.writer, DatasetWriter)
+    assert dataset.reader is None
+
+
+def test_create_initial_counts_zero(tmp_path):
+    """After create(), num_episodes == 0 and num_frames == 0."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    assert dataset.num_episodes == 0
+    assert dataset.num_frames == 0
+
+
+def test_add_frame_works_in_write_mode(tmp_path):
+    """add_frame() succeeds on a dataset created via create()."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    dataset.add_frame(_make_frame())  # should not raise
+
+
+# ── Resume mode ──────────────────────────────────────────────────────
+
+
+def test_resume_creates_writer(tmp_path):
+    """After resume(), writer is a DatasetWriter."""
+    root = tmp_path / "resume_ds"
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
+    )
+    for _ in range(3):
+        dataset.add_frame(_make_frame())
+    dataset.save_episode()
+    dataset.finalize()
+
+    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
+    assert isinstance(resumed.writer, DatasetWriter)
+
+
+def test_resume_preserves_episode_count(tmp_path):
+    """After resume(), existing episodes are counted."""
+    root = tmp_path / "resume_ds"
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
+    )
+    for _ in range(3):
+        dataset.add_frame(_make_frame())
+    dataset.save_episode()
+    dataset.finalize()
+
+    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
+    assert resumed.meta.total_episodes == 1
+
+
+def test_resume_can_add_more_episodes(tmp_path):
+    """After resume(), new episodes can be added."""
+    root = tmp_path / "resume_ds"
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
+    )
+    for _ in range(3):
+        dataset.add_frame(_make_frame())
+    dataset.save_episode()
+    dataset.finalize()
+
+    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
+    for _ in range(2):
+        resumed.add_frame(_make_frame())
+    resumed.save_episode()
+
+    assert resumed.meta.total_episodes == 2
+
+
+# ── Writer guard ─────────────────────────────────────────────────────
+
+
+def test_add_frame_raises_without_writer(tmp_path, lerobot_dataset_factory):
+    """add_frame() raises RuntimeError on a read-only dataset."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
+    )
+    with pytest.raises(RuntimeError, match="read-only"):
+        dataset.add_frame(_make_frame())
+
+
+def test_save_episode_raises_without_writer(tmp_path, lerobot_dataset_factory):
+    """save_episode() raises RuntimeError on a read-only dataset."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
+    )
+    with pytest.raises(RuntimeError, match="read-only"):
+        dataset.save_episode()
+
+
+def test_clear_episode_buffer_raises_without_writer(tmp_path, lerobot_dataset_factory):
+    """clear_episode_buffer() raises RuntimeError on a read-only dataset."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
+    )
+    with pytest.raises(RuntimeError, match="read-only"):
+        dataset.clear_episode_buffer()
+
+
+# ── Reader guard ─────────────────────────────────────────────────────
+
+
+def test_getitem_raises_before_finalize(tmp_path):
+    """dataset[0] raises RuntimeError while recording (before finalize)."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    for _ in range(3):
+        dataset.add_frame(_make_frame())
+    dataset.save_episode()
+
+    with pytest.raises(RuntimeError, match="finalize"):
+        dataset[0]
+
+
+def test_getitem_works_after_finalize(tmp_path):
+    """After finalize(), dataset[0] returns data."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    for _ in range(3):
+        dataset.add_frame(_make_frame())
+    dataset.save_episode()
+    dataset.finalize()
+
+    item = dataset[0]
+    assert "state" in item
+    assert "task" in item
+
+
+# ── Property delegation ──────────────────────────────────────────────
+
+
+def test_fps_delegates_to_meta(tmp_path, lerobot_dataset_factory):
+    """dataset.fps == dataset.meta.fps."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
+    )
+    assert dataset.fps == dataset.meta.fps
+
+
+def test_features_delegates_to_meta(tmp_path, lerobot_dataset_factory):
+    """dataset.features is dataset.meta.features."""
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
+    )
+    assert dataset.features is dataset.meta.features
+
+
+def test_num_frames_uses_meta_in_write_mode(tmp_path):
+    """In write-only mode (reader=None), num_frames comes from metadata."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    assert dataset.reader is None
+    assert dataset.num_frames == dataset.meta.total_frames
+
+
+# ── Lifecycle ────────────────────────────────────────────────────────
+
+
+def test_finalize_is_idempotent(tmp_path):
+    """Calling finalize() twice does not raise."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    dataset.finalize()
+    dataset.finalize()
+
+
+def test_has_pending_frames_lifecycle(tmp_path):
+    """has_pending_frames: False -> True (add_frame) -> False (save_episode)."""
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
+    )
+    assert dataset.has_pending_frames() is False
+
+    dataset.add_frame(_make_frame())
+    assert dataset.has_pending_frames() is True
+
+    dataset.save_episode()
+    assert dataset.has_pending_frames() is False
+
+
+def test_create_record_finalize_read_roundtrip(tmp_path):
+    """End-to-end: create, record 2 episodes, finalize, re-open, verify data."""
+    root = tmp_path / "roundtrip"
+    dataset = LeRobotDataset.create(
+        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
+    )
+
+    # Episode 0: 3 frames with known values
+    ep0_states = []
+    for i in range(3):
+        state = torch.tensor([float(i), float(i * 2)])
+        ep0_states.append(state)
+        dataset.add_frame({"task": "Task A", "state": state})
+    dataset.save_episode()
+
+    # Episode 1: 2 frames
+    ep1_states = []
+    for i in range(2):
+        state = torch.tensor([float(i + 100), float(i + 200)])
+        ep1_states.append(state)
+        dataset.add_frame({"task": "Task B", "state": state})
+    dataset.save_episode()
+
+    dataset.finalize()
+
+    # Re-open as read-only
+    reopened = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root)
+    assert len(reopened) == 5
+    assert reopened.num_episodes == 2
+
+    # Verify episode 0
+    for i in range(3):
+        item = reopened[i]
+        assert torch.allclose(item["state"], ep0_states[i], atol=1e-5)
+        assert item["episode_index"].item() == 0
+
+    # Verify episode 1
+    for i in range(2):
+        item = reopened[3 + i]
+        assert torch.allclose(item["state"], ep1_states[i], atol=1e-5)
+        assert item["episode_index"].item() == 1
diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py
index a85db6a8d..f7e63b06f 100644
--- a/tests/datasets/test_streaming_video_encoder.py
+++ b/tests/datasets/test_streaming_video_encoder.py
@@ -534,7 +534,7 @@ class TestStreamingEncoderIntegration:
             streaming_encoding=True,
         )

-        assert dataset._streaming_encoder is not None
+        assert dataset.writer._streaming_encoder is not None
        num_frames = 20
        for _ in range(num_frames):
@@ -580,7 +580,7 @@ class TestStreamingEncoderIntegration:
             streaming_encoding=False,
         )

-        assert dataset._streaming_encoder is None
+        assert dataset.writer._streaming_encoder is None

         num_frames = 5
         for _ in range(num_frames):

From aa9cc9bd43e9eba92a32aaadde2a09c90b5836cf Mon Sep 17 00:00:00 2001
From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com>
Date: Thu, 26 Mar 2026 20:05:15 +0000
Subject: [PATCH 03/47] fix(logging): suppress noisy httpx INFO logs (#3173)

Set httpx logger level to WARNING in init_logging to prevent HTTP request
traces from flooding the terminal during train and eval scripts.

Co-authored-by: Steven Palma
---
 src/lerobot/utils/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index b9f8441d6..f6aa93bea 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -95,6 +95,8 @@ def init_logging(
         file_handler.setLevel(file_level.upper())
         logger.addHandler(file_handler)

+    logging.getLogger("httpx").setLevel(logging.WARNING)
+

 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]

From 07502868e58095b437e5dd5a480fecc65a6f29dc Mon Sep 17 00:00:00 2001
From: Maxime Ellerbach
Date: Fri, 27 Mar 2026 21:25:12 +0100
Subject: [PATCH 04/47] fix(deps): breaking change from transformers 5.4.0 (#3231)

* fix(deps): breaking change from transformers 5.4.0

* Update src/lerobot/policies/xvla/modeling_florence2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach

* Update src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach

* removing dataclass

* bumping transformers 5.4.0

---------

Signed-off-by: Maxime Ellerbach
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 pyproject.toml | 2 +-
 .../policies/groot/action_head/flow_matching_action_head.py | 3 +--
 src/lerobot/policies/groot/groot_n1.py | 3 +--
 src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py | 4 ++--
 src/lerobot/policies/xvla/modeling_florence2.py | 4 ++--
 5 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5f45626c0..7e4f24eb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,7 +99,7 @@ dependencies = [
 # Common
 pygame-dep = ["pygame>=2.5.1,<2.7.0"]
 placo-dep = ["placo>=0.9.6,<0.9.17"]
-transformers-dep = ["transformers>=5.3.0,<6.0.0"]
+transformers-dep = ["transformers>=5.4.0,<6.0.0"]
 grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
 can-dep = ["python-can>=4.2.0,<5.0.0"]
 peft-dep = ["peft>=0.18.0,<1.0.0"]
diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
index bfc456ba0..74d922988 100644
--- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
+++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass, field
+from dataclasses import field
 from typing import TYPE_CHECKING

 import torch
@@ -110,7 +110,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
         return x


-@dataclass
 class FlowmatchingActionHeadConfig(PretrainedConfig):
     """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""

diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py
index 06ff5a04d..38512b8a8 100644
--- a/src/lerobot/policies/groot/groot_n1.py
+++ b/src/lerobot/policies/groot/groot_n1.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass, field
+from dataclasses import field
 from pathlib import Path
 from typing import TYPE_CHECKING

@@ -173,7 +173,6 @@ N_COLOR_CHANNELS = 3


 # config
-@dataclass
 class GR00TN15Config(PretrainedConfig):
     model_type = "gr00t_n1_5"
     backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py
index ecf3eb371..a80096514 100644
--- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py
+++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py
@@ -22,7 +22,7 @@ from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
+    is_flash_attn_greater_or_equal,
     is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")

     def forward(
         self,
diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py
index e33efe5c3..81f9c8234 100644
--- a/src/lerobot/policies/xvla/modeling_florence2.py
+++ b/src/lerobot/policies/xvla/modeling_florence2.py
@@ -45,7 +45,7 @@ from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
+    is_flash_attn_greater_or_equal,
     logging,
     replace_return_docstrings,
 )
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")

     def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

From 975d89b38d1091589f4232c0449c32e75e27b2e5 Mon Sep 17 00:00:00 2001
From: Maxime Ellerbach
Date: Fri, 27 Mar 2026 21:25:37 +0100
Subject: [PATCH 05/47] chore(docs): add more guidance to bring your own policies tutorial (#3230)

* chore(docs): add more guidance to bring your own policies tutorial

* removing normalization to avoid confusion with processors

* trailing whitespace

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach

* adding get optim params and predict_action chunk

* removing extra quotes

---------

Signed-off-by: Maxime Ellerbach
---
 docs/source/bring_your_own_policies.mdx | 98 +++++++++++++++++++++----
 1 file changed, 85 insertions(+), 13 deletions(-)

diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx
index 9266c9e5b..38c32aa71 100644
--- a/docs/source/bring_your_own_policies.mdx
+++ b/docs/source/bring_your_own_policies.mdx
@@ -41,13 +41,15 @@ requires = # your-build-system

 ## Step 2: Define the Policy Configuration

-Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
+Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
+Here is a template to get you started; customize the parameters and methods as needed for your policy's architecture and training requirements.
 ```python
 # configuration_my_custom_policy.py
 from dataclasses import dataclass, field

 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
+from lerobot.optim.optimizers import AdamWConfig
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig

 @PreTrainedConfig.register_subclass("my_custom_policy")
 @dataclass
@@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig):
         hidden_dim: Hidden dimension for the policy network
         # Add your policy-specific parameters here
     """
-    # ...PreTrainedConfig fields...
-    pass
+
+    horizon: int = 50
+    n_action_steps: int = 50
+    hidden_dim: int = 256
+
+    optimizer_lr: float = 1e-4
+    optimizer_weight_decay: float = 1e-4

     def __post_init__(self):
         super().__post_init__()
-        # Add any validation logic here
+        if self.n_action_steps > self.horizon:
+            raise ValueError("n_action_steps cannot exceed horizon")

     def validate_features(self) -> None:
         """Validate input/output feature compatibility."""
-        # Implement validation logic for your policy's requirements
-        pass
+        if not self.image_features:
+            raise ValueError("MyCustomPolicy requires at least one image feature.")
+        if self.action_feature is None:
+            raise ValueError("MyCustomPolicy requires 'action' in output_features.")
+
+    def get_optimizer_preset(self) -> AdamWConfig:
+        return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
+
+    def get_scheduler_preset(self):
+        return None
+
+    @property
+    def observation_delta_indices(self) -> list[int] | None:
+        """Relative timestep offsets the dataset loader provides per observation.
+
+        Return `None` for single-frame policies. For temporal policies that consume
+        multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for
+        3 past frames at stride 10 and 1 future frame at stride 10.
+        """
+        return None
+
+    @property
+    def action_delta_indices(self) -> list[int]:
+        """Relative timestep offsets for the action chunk the dataset loader returns."""
+        return list(range(self.horizon))
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
 ```

 ## Step 3: Implement the Policy Class

-Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
+Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):

 ```python
 # modeling_my_custom_policy.py
 import torch
@@ -85,38 +121,74 @@ import torch.nn as nn
 from typing import Any

 from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.utils.constants import ACTION
 from .configuration_my_custom_policy import MyCustomPolicyConfig


 class MyCustomPolicy(PreTrainedPolicy):
-    config_class = MyCustomPolicyConfig
+    config_class = MyCustomPolicyConfig  # must match the string in @register_subclass
     name = "my_custom_policy"

     def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
         super().__init__(config, dataset_stats)
+        config.validate_features()  # not called automatically by the base class
+        self.config = config
+        self.model = ...  # your nn.Module here
+
+    def reset(self):
+        """Reset episode state."""
         ...
+
+    def get_optim_params(self) -> dict:
+        """Return parameters to pass to the optimizer (e.g. with per-group lr/wd)."""
+        return {"params": self.parameters()}
+
+    def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
+        """Return the full action chunk (B, chunk_size, action_dim) for the current observation."""
+        ...
+
+    def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
+        """Return a single action for the current timestep (called at inference)."""
+        ...
+
+    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """Compute the training loss.
+
+        `batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
+        timesteps padded because the episode ended before `horizon` steps; you
+        can exclude those from your loss.
+        """
+        actions = batch[ACTION]
+        action_is_pad = batch.get("action_is_pad")
+        ...
+        return {"loss": ...}
 ```

 ## Step 4: Add Data Processors

-Create processor functions:
+Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).

 ```python
 # processor_my_custom_policy.py
 from typing import Any

 import torch

+from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+

 def make_my_custom_policy_pre_post_processors(
     config,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
 ) -> tuple[
     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
     PolicyProcessorPipeline[PolicyAction, PolicyAction],
 ]:
-    """Create preprocessing and postprocessing functions for your policy."""
-    pass  # Define your preprocessing and postprocessing logic here
-
+    preprocessor = ...  # build your PolicyProcessorPipeline for inputs
+    postprocessor = ...  # build your PolicyProcessorPipeline for outputs
+    return preprocessor, postprocessor
 ```

+**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
+
 ## Step 5: Package Initialization

 Expose your classes in the package's `__init__.py`:

From 4e45acca52679745f9c7d4b80984ef4c59fe9a57 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Fri, 27 Mar 2026 22:21:55 +0100
Subject: [PATCH 06/47] fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)

* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine
Co-authored-by: mickaelChen
---
 src/lerobot/configs/default.py | 3 +-
 src/lerobot/datasets/dataset_metadata.py | 39 ++-
 src/lerobot/datasets/dataset_reader.py | 8 +-
 src/lerobot/datasets/lerobot_dataset.py | 81 ++++--
 src/lerobot/datasets/streaming_dataset.py | 14 +-
 src/lerobot/datasets/utils.py | 13 +
 src/lerobot/utils/constants.py | 4 +
 tests/datasets/test_lerobot_dataset.py | 318 ++++++++++++++++++++++
 8 files changed, 440 insertions(+), 40 deletions(-)

diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index 7f481b9ca..58ed64420 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -27,7 +27,8 @@ class DatasetConfig:
     # "dataset_index" into the returned item. The index mapping is made according to the order in which the
     # datasets are provided.
     repo_id: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
+    # Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
+    # looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
     root: str | None = None
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index a43ba07b4..65dbc9c4a 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -44,11 +44,12 @@ from lerobot.datasets.utils import (
     check_version_compatibility,
     flatten_dict,
     get_safe_version,
+    has_legacy_hub_download_metadata,
     is_valid_version,
     update_chunk_file_indices,
 )
 from lerobot.datasets.video_utils import get_video_info
-from lerobot.utils.constants import HF_LEROBOT_HOME
+from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE

 CODEBASE_VERSION = "v3.0"

@@ -77,8 +78,12 @@ class LeRobotDatasetMetadata:

         Args:
             repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
-            root: Local directory for the dataset. Defaults to
-                ``$HF_LEROBOT_HOME/{repo_id}``.
+            root: Local directory for the dataset. When provided, Hub downloads
+                are materialized directly into this directory. When omitted,
+                existing local datasets are still looked up under
+                ``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
+                revision-safe snapshot cache under
+                ``$HF_LEROBOT_HOME/hub``.
             revision: Git revision (branch, tag, or commit hash). Defaults
                 to the current codebase version.
             force_cache_sync: If ``True``, re-download metadata from the Hub
@@ -88,7 +93,8 @@ class LeRobotDatasetMetadata:
         """
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
-        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        self._requested_root = Path(root) if root is not None else None
+        self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
         self._pq_writer = None
         self.latest_episode = None
         self._metadata_buffer: list[dict] = []
@@ -96,14 +102,15 @@ class LeRobotDatasetMetadata:
         self._finalized = False

         try:
-            if force_cache_sync:
+            if force_cache_sync or (
+                self._requested_root is None and has_legacy_hub_download_metadata(self.root)
+            ):
                 raise FileNotFoundError
             self._load_metadata()
         except (FileNotFoundError, NotADirectoryError):
             if is_valid_version(self.revision):
                 self.revision = get_safe_version(self.repo_id, self.revision)

-            (self.root / "meta").mkdir(exist_ok=True, parents=True)
             self._pull_from_repo(allow_patterns="meta/")
             self._load_metadata()
@@ -178,14 +185,29 @@ class LeRobotDatasetMetadata:
         allow_patterns: list[str] | str | None = None,
         ignore_patterns: list[str] | str | None = None,
     ) -> None:
+        if self._requested_root is None:
+            self.root = Path(
+                snapshot_download(
+                    self.repo_id,
+                    repo_type="dataset",
+                    revision=self.revision,
+                    cache_dir=HF_LEROBOT_HUB_CACHE,
+                    allow_patterns=allow_patterns,
+                    ignore_patterns=ignore_patterns,
+                )
+            )
+            return
+
+        self._requested_root.mkdir(exist_ok=True, parents=True)
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
             revision=self.revision,
-            local_dir=self.root,
+            local_dir=self._requested_root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
         )
+        self.root = self._requested_root

     @property
     def url_root(self) -> str:
@@ -593,7 +615,8 @@ class LeRobotDatasetMetadata:
         """
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
+        obj._requested_root = Path(root) if root is not None else None
+        obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id

         obj.root.mkdir(parents=True, exist_ok=False)
diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py
index 0233a3cf6..3720a5084 100644
--- a/src/lerobot/datasets/dataset_reader.py
+++ b/src/lerobot/datasets/dataset_reader.py
@@ -68,7 +68,7 @@ class DatasetReader:
             visual features.
         """
         self._meta = meta
-        self._root = root
+        self.root = root
         self.episodes = episodes
         self._tolerance_s = tolerance_s
         self._video_backend = video_backend
@@ -125,7 +125,7 @@ class DatasetReader:
     def _load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
         features = get_hf_features_from_features(self._meta.features)
-        hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
+        hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
         hf_dataset.set_transform(hf_transform_to_torch)
         return hf_dataset
@@ -150,7 +150,7 @@ class DatasetReader:
         if len(self._meta.video_keys) > 0:
             for ep_idx in requested_episodes:
                 for vid_key in self._meta.video_keys:
-                    video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
+                    video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
                     if not video_path.exists():
                         return False
@@ -240,7 +240,7 @@ class DatasetReader:
             from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
             shifted_query_ts = [from_timestamp + ts for ts in query_ts]

-            video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
+            video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
             frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
             item[vid_key] = frames.squeeze(0)
diff --git a/src/lerobot/datasets/lerobot_dataset.py
b/src/lerobot/datasets/lerobot_dataset.py index cba0c1cba..f719222fd 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -37,7 +37,7 @@ from lerobot.datasets.video_utils import ( get_safe_default_codec, resolve_vcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE logger = logging.getLogger(__name__) @@ -144,10 +144,11 @@ class LeRobotDataset(torch.utils.data.Dataset): Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory where the dataset will be downloaded and - stored. If set, all dataset files will be stored directly under this path. If not set, the - dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the - HF_LEROBOT_HOME environment variable). + root (Path | None, optional): Local directory where the dataset will be read from or downloaded + into. If set, all dataset files are materialized directly under this path. If not set, + existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub + downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. 
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from @@ -190,7 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset): """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes @@ -201,12 +202,15 @@ class LeRobotDataset(torch.utils.data.Dataset): self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self._requested_root.mkdir(exist_ok=True, parents=True) - # Load metadata + # Load metadata (sets self.root once from the resolved metadata root) self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Create reader (hf_dataset loaded below) self.reader = DatasetReader( @@ -556,14 +560,33 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.episodes is not None: # Reader is guaranteed to exist here (created in __init__ before _download) files = self.reader.get_episodes_file_paths() - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - allow_patterns=files, - ignore_patterns=ignore_patterns, - ) + + if self._requested_root is None: + self.meta.root = Path( + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + cache_dir=HF_LEROBOT_HUB_CACHE, + allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + ) + else: + self._requested_root.mkdir(exist_ok=True, parents=True) + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self._requested_root, + 
allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + self.meta.root = self._requested_root + + # Propagate resolved root from metadata (single source of truth) + self.root = self.meta.root + self.reader.root = self.meta.root # ── Class constructors ──────────────────────────────────────────── @@ -635,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset): metadata_buffer_size=metadata_buffer_size, ) obj.repo_id = obj.meta.repo_id + obj._requested_root = obj.meta.root obj.root = obj.meta.root obj.revision = None obj.tolerance_s = tolerance_s @@ -695,8 +719,10 @@ class LeRobotDataset(torch.utils.data.Dataset): Args: repo_id: Repository identifier of the existing dataset. - root: Local directory of the dataset. Defaults to - ``$HF_LEROBOT_HOME/{repo_id}``. + root: Local directory of the dataset. When provided, Hub downloads + are materialized directly into this directory. When omitted, + Hub downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. tolerance_s: Timestamp synchronization tolerance in seconds. revision: Git revision (branch, tag, or commit hash). Defaults to current codebase version tag. @@ -716,11 +742,16 @@ class LeRobotDataset(torch.utils.data.Dataset): Returns: A :class:`LeRobotDataset` in write mode, ready to append episodes. """ + if not root: + raise ValueError( + "resume() requires an explicit 'root' directory because it creates a DatasetWriter. " + "Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt " + "the shared cache. Please provide a local directory path." 
+ ) vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id - obj.root.mkdir(exist_ok=True, parents=True) + obj._requested_root = Path(root) obj.revision = revision if revision else CODEBASE_VERSION obj.tolerance_s = tolerance_s obj.image_transforms = None @@ -731,10 +762,14 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._vcodec = vcodec obj._encoder_threads = encoder_threads - # Load metadata + if obj._requested_root is not None: + obj._requested_root.mkdir(exist_ok=True, parents=True) + + # Load metadata (revision-safe when root is not provided) obj.meta = LeRobotDatasetMetadata( - obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync + obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync ) + obj.root = obj.meta.root # Reader is lazily created on first access (write-only mode) obj.reader = None diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 62e00558a..1767cc79d 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory to use for downloading/writing files. + root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub + metadata is resolved through a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. image_transforms (Callable | None, optional): Transform to apply to image data. 
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self.streaming_from_local = root is not None self.image_transforms = image_transforms @@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self.root.mkdir(exist_ok=True, parents=True) # Load metadata self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 2e1d360f9..36e7934ed 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -18,6 +18,7 @@ import importlib.resources import json import logging from collections.abc import Iterator +from pathlib import Path from typing import Any import datasets @@ -101,6 +102,18 @@ DEFAULT_FEATURES = { } +def has_legacy_hub_download_metadata(root: Path) -> bool: + """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror. + + ``snapshot_download(local_dir=...)`` stores lightweight metadata under + ``/.cache/huggingface/download/``. 
The presence of this + directory is a reliable indicator that the dataset was downloaded with + the old non-revision-safe ``local_dir`` mode and should be re-fetched + through the snapshot cache instead. + """ + return (root / ".cache" / "huggingface" / "download").exists() + + def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index ecd54844c..fd10cab35 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ: # cache dir default_cache_path = Path(HF_HOME) / "lerobot" HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() +# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/). +# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different +# dataset revisions are stored in isolated snapshot directories. +HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub" # calibration dir default_calibration_path = HF_LEROBOT_HOME / "calibration" diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index d7ce54a15..a8aa47ed2 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards, property delegation, and the full create-record-finalize-read lifecycle. 
""" +from pathlib import Path +from unittest.mock import Mock + import pytest import torch +import lerobot.datasets.dataset_metadata as dataset_metadata_module +import lerobot.datasets.lerobot_dataset as lerobot_dataset_module +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.dataset_reader import DatasetReader from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID SIMPLE_FEATURES = { "state": {"dtype": "float32", "shape": (2,), "names": None}, } +SNAPSHOT_MAIN_FEATURES = { + **SIMPLE_FEATURES, + "test": {"dtype": "float32", "shape": (2,), "names": None}, +} def _make_frame(task: str = "Dummy task") -> dict: return {"task": task, "state": torch.randn(2)} +def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None: + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root) + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub") + monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub") + + +def _write_dataset_tree( + root: Path, + *, + motor_features: dict[str, dict], + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, +) -> None: + root.mkdir(parents=True, exist_ok=True) + info = info_factory( + total_episodes=1, + total_frames=3, + total_tasks=1, + use_videos=False, + motor_features=motor_features, + camera_features={}, + ) + tasks = tasks_factory(total_tasks=1) + episodes = episodes_factory( + features=info["features"], + fps=info["fps"], + total_episodes=1, + total_frames=3, + tasks=tasks, + ) + stats = stats_factory(features=info["features"]) + hf_dataset = hf_dataset_factory( + features=info["features"], + tasks=tasks, + episodes=episodes, 
+ fps=info["fps"], + ) + + create_info(root, info) + create_stats(root, stats) + create_tasks(root, tasks) + create_episodes(root, episodes) + create_hf_dataset(root, hf_dataset) + + # ── Read-only mode (via __init__) ──────────────────────────────────── @@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory): assert len(dataset) == dataset.num_frames +def test_metadata_without_root_uses_hub_cache_snapshot_download( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True) + + assert meta.root == snapshot_root + assert snapshot_download.call_count == 1 + assert snapshot_download.call_args.args == (repo_id,) + assert snapshot_download.call_args.kwargs == { + "repo_type": "dataset", + "revision": "main", + "cache_dir": cache_root / "hub", + "allow_patterns": "meta/", + "ignore_patterns": None, + } + + +def 
test_without_root_reads_different_revisions_from_distinct_snapshot_roots( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Different revisions resolve to different on-disk snapshot roots.""" + repo_id = DUMMY_REPO_ID + old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9" + cache_root = tmp_path / "lerobot_cache" + main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old" + + _write_dataset_tree( + main_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + _write_dataset_tree( + old_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_roots = { + "main": main_root, + old_revision: old_root, + } + meta_snapshot_download = Mock( + side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]]) + ) + data_snapshot_download = Mock( + side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]]) + ) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download) + monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", 
data_snapshot_download) + + main_dataset = LeRobotDataset( + repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True + ) + old_dataset = LeRobotDataset( + repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True + ) + + assert main_dataset.root == main_root + assert old_dataset.root == old_root + assert "test" in main_dataset.hf_dataset.column_names + assert "test" not in old_dataset.hf_dataset.column_names + + # Metadata downloads use cache_dir, not local_dir + assert meta_snapshot_download.call_count == 2 + for download_call in meta_snapshot_download.call_args_list: + assert download_call.kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in download_call.kwargs + + # Data downloads also use cache_dir, not local_dir + assert data_snapshot_download.call_count == 2 + for download_call in data_snapshot_download.call_args_list: + assert download_call.kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in download_call.kwargs + + +def test_metadata_without_root_ignores_legacy_local_dir_cache( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + legacy_root = cache_root / repo_id + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + + _write_dataset_tree( + legacy_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + (legacy_root 
/ ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True) + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main") + + assert meta.root == snapshot_root + assert "test" in meta.features + assert snapshot_download.call_count == 1 + + +def test_download_without_root_uses_hub_cache( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + + # Pre-populate snapshot directory so metadata loads succeed, but leave + # data absent so that _download() is triggered. 
+ _write_dataset_tree( + snapshot_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + meta_snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download) + + # Mock the data snapshot_download to return the same root (data already + # exists there from _write_dataset_tree). + data_snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download) + + LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True) + + # _download() should have called snapshot_download with cache_dir + assert data_snapshot_download.call_count == 1 + call_kwargs = data_snapshot_download.call_args.kwargs + assert call_kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in call_kwargs + + # ── Write-only mode (via create()) ────────────────────────────────── From 2e069b1c4769e40371d226763191d68123282d4d Mon Sep 17 00:00:00 2001 From: Bryson Jones <63133702+brysonjones@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:41:26 -0700 Subject: [PATCH 07/47] Feature/add multitask diffusion transformer policy implementation (#2545) * Add multitask diffusion transformer policy * expand the observation encoder to support different-size encoders for vision and text * add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs * update readme and citations for multitask dit policy * remove dino vision encoder and simplify text and vision encoders by removing
inheritance structure * adjust factory comment * update docstring for multitask dit policy processor file * simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives * add references to the modeling file comments * merge all modules files into the main modeling file * add torch.no_grad decorators * split up select action return statement * remove redundant asserts * add tutorial to training with multi_task_dit * fix bugs when testing on hardware * remove environment state conditioning * fix typo in test instruction comment * add processor tests to multitask dit tests * move policy to top of file * use constants for indexing into batches and remove env state references * remove the base classes since we don't need to be able to extend * fix nit formatting in generate actions fcn * reformat and clean up tutorial for multitask dit policy * add more descriptions and depth to multitask dit tutorial * note origins of each training objective * rename config param for multiple vision encoders * refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit * add multitask dit to toc for docs * add conditional transformers import to match all other policies that use transformers lib * add test handling for multitask dit when transformers isn't available * skip tests without transformers * remove cropping of images smaller than the crop size * add kwargs arg to multitask dit constructor * add wallx dep conflict management for multitask dit policy * use hyphens for cleanliness in pyproject.toml * add conflict management to pyproject toml for pi conflict for mtdp as well * update tests script to drop the unnecessary uv sync call, which resolved dependencies that do not need to run.
This drastically reduces CI run time * revert fast tests edits * update docs and readme files, fixing some typos and adding multitask dit to readme * chore(dependencies): upgrade transformers + huggingface-hub + peft + scipy * chore(dependencies): bump pi0 family to transformers v5 * chore(dependencies): bump wall x to transformers v5 * chore(dependencies): bump gr00t to transformers v5 * chore(style): fix pre-commit * fix(policy): xvla forced_bos_token missing * test(rl): skip ci tests for resnet10 * Fix: full pi models support for transformers v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformers v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * put back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun * chore: fix XVLA in transformers v5 (#3006) * test(policies): enable wall x CI testing * style(test): pre-commit check * style(test): pre-commit --------- Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: Jade Choghari Co-authored-by: Yufei Sun Co-authored-by: Steven Palma --- README.md | 10 +- docs/source/_toctree.yml | 2 + docs/source/multi_task_dit.mdx | 340 ++++++++ docs/source/policy_multi_task_dit_README.md | 37 + pyproject.toml | 1 + src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/factory.py | 24 +- src/lerobot/policies/multi_task_dit/README.md | 37 + .../policies/multi_task_dit/__init__.py | 21 + .../configuration_multi_task_dit.py | 256 ++++++ .../multi_task_dit/modeling_multi_task_dit.py |
803 ++++++++++++++++++ .../processor_multi_task_dit.py | 105 +++ .../multi_task_dit/test_multi_task_dit.py | 624 ++++++++++++++ 13 files changed, 2253 insertions(+), 9 deletions(-) create mode 100644 docs/source/multi_task_dit.mdx create mode 100644 docs/source/policy_multi_task_dit_README.md create mode 100644 src/lerobot/policies/multi_task_dit/README.md create mode 100644 src/lerobot/policies/multi_task_dit/__init__.py create mode 100644 src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py create mode 100644 src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py create mode 100644 src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py create mode 100644 tests/policies/multi_task_dit/test_multi_task_dit.py diff --git a/README.md b/README.md index f58b337b3..f67d9103c 100644 --- a/README.md +++ b/README.md @@ -100,11 +100,11 @@ lerobot-train \ --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -| Category | Models | -| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | -| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | -| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | +| Category | Models | +| -------------------------- | 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 09d94d28c..650a21184 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -49,6 +49,8 @@ title: NVIDIA GR00T N1.5 - local: xvla title: X-VLA + - local: multi_task_dit + title: Multitask DiT Policy - local: walloss title: WALL-OSS title: "Policies" diff --git a/docs/source/multi_task_dit.mdx b/docs/source/multi_task_dit.mdx new file mode 100644 index 000000000..c3cced708 --- /dev/null +++ b/docs/source/multi_task_dit.mdx @@ -0,0 +1,340 @@ +# Multitask DiT Policy + +Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions. 
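As a rough illustration of the flow-matching objective mentioned above, a training loss can be sketched as follows. This is a minimal sketch, not the actual LeRobot implementation: `model` and `cond` are hypothetical stand-ins for the DiT and its CLIP vision/text conditioning.

```python
import torch
import torch.nn.functional as F


def flow_matching_loss(model, actions, cond):
    """Illustrative flow-matching loss: the network learns the velocity
    field that transports Gaussian noise (t=0) to the action chunk (t=1)."""
    # Sample one interpolation time t in [0, 1] per batch element.
    t = torch.rand(actions.shape[0], 1, 1)
    noise = torch.randn_like(actions)
    # Linear interpolation between noise and data.
    x_t = (1 - t) * noise + t * actions
    # Target velocity is the straight-line direction from noise to data.
    target = actions - noise
    pred = model(x_t, t.flatten(), cond)
    return F.mse_loss(pred, target)
```

The diffusion objective differs only in the regression target: under a DDPM/DDIM schedule the network predicts the injected noise rather than the noise-to-data velocity, and inference integrates or denoises step by step instead of following the learned velocity field.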
+ +## Model Overview + +The model uses: + +- **CLIP Vision Encoder**: Processes RGB images from multiple camera views +- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection) +- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language +- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation + +This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter +VLAs, with only ~450M parameters and significantly less training. + +## Installation Requirements + +Multitask DiT Policy has additional dependencies. Install them with: + +```bash +pip install lerobot[multi_task_dit] +``` + +This will install all necessary dependencies, including the HuggingFace Transformers library for CLIP models. + +## Usage + +To use Multitask DiT in your LeRobot configuration, specify the policy type as: + +```bash +policy.type=multi_task_dit +``` + +## Training + +### Basic Training Command + +Here's a complete command for training Multitask DiT on your dataset: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/multitask_dit_training \ + --batch_size=32 \ + --steps=5000 \ + --save_freq=500 \ + --log_freq=100 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true +``` + +### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency) + +For reliable performance, start with these suggested default hyperparameters: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/multitask_dit_training \ + --batch_size=320 \ + --steps=30000 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.horizon=32 \ + --policy.n_action_steps=24 \ + --policy.objective=diffusion \ + --policy.noise_scheduler_type=DDPM \ + --policy.num_train_timesteps=100
\ + --policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true +``` + +**Key Parameters:** + +- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics +- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz +- **n_action_steps**: 24 - ~0.8 seconds at 30Hz +- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor +- **Training Steps**: >30k steps recommended for a single task + +### Training Configuration Parameters + +#### Objective Selection + +Choose between diffusion and flow matching: + +```bash +# Diffusion objective (default) +--policy.objective=diffusion \ +--policy.noise_scheduler_type=DDPM \ # or "DDIM" +--policy.num_train_timesteps=100 \ +--policy.num_inference_steps=10 \ # For faster inference +--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type +--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean) +--policy.clip_sample=true \ # Clip samples during denoising +--policy.clip_sample_range=1.0 # Clipping range [-x, x] + +# Flow matching objective +--policy.objective=flow_matching \ +--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta strategy tends to perform better in practice +--policy.num_integration_steps=100 \ +--policy.integration_method=euler \ # or "rk4" +--policy.sigma_min=0.0 # Minimum noise in flow interpolation path +``` + +#### Transformer Architecture + +Adjust model capacity based on dataset size: + +```bash +# Small datasets (< 100 examples) +--policy.num_layers=4 \ +--policy.hidden_dim=512 \ +--policy.num_heads=8 # should ideally be hidden_dim // 64 + +# Medium datasets (100-5k examples) - default +--policy.num_layers=6 \ +--policy.hidden_dim=512 \ +--policy.num_heads=8 # should ideally be hidden_dim // 64 + +# Large datasets (> 5k examples) +--policy.num_layers=8 \ +--policy.hidden_dim=512 \
+--policy.num_heads=8 # should ideally be hidden_dim // 64 +``` + +**Positional Encoding Options:** + +The model supports two positional encoding methods for action sequences: + +```bash +# Rotary Position Embedding (RoPE) - default, recommended +--policy.use_rope=true \ +--policy.rope_base=10000.0 # Base frequency for RoPE + +# Absolute positional encoding +--policy.use_positional_encoding=true # Disables RoPE when true +``` + +**Other Transformer Parameters:** + +```bash +--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0) +--policy.timestep_embed_dim=256 # Timestep embedding dimension +``` + +#### Vision Encoder Configuration + +```bash +# Use a different CLIP model for more expressivity at the cost of inference time; +# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset +--policy.vision_encoder_name=openai/clip-vit-large-patch14 + +# Use a separate vision encoder per camera +# This may be useful when cameras have significantly different characteristics, but +# be wary of increased VRAM footprint. +--policy.use_separate_rgb_encoder_per_camera=true + +# Image preprocessing +--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed-ups +--policy.image_crop_shape=[224,224] \ +--policy.image_crop_is_random=true # Random during training, center at inference +``` + +#### Text Encoder Configuration + +```bash +# Use a different CLIP text encoder model +# same as vision: experiment with larger or smaller models depending on the +# complexity of your tasks and size of dataset +--policy.text_encoder_name=openai/clip-vit-large-patch14 +``` + +#### Learning Rate Configuration + +The vision encoder uses a separate learning rate multiplier; a multiplier of 0.1 (one tenth of the base learning rate) is a good starting point: + +```bash +--policy.optimizer_lr=2e-5 \ +--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr +``` + +### Training Tuning Guidelines + +#### 1. 
Flow Matching with Beta Sampling + +The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331). + +Additionally, we have implemented a flow-matching objective, which is described at a high level in [this Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/). + +Consider testing the flow-matching objective and evaluating performance differences for your task: + +```bash +--policy.objective=flow_matching \ +--policy.timestep_sampling_strategy=beta \ +--policy.timestep_sampling_alpha=1.5 \ +--policy.timestep_sampling_beta=1.0 \ +--policy.timestep_sampling_s=0.999 +``` + +This hasn't been shown to be a silver bullet across every use case, but it occasionally results in smoother and more consistent actions. + +#### 2. Number of Transformer Layers + +Match model capacity to your dataset size: + +- **Small datasets** (< 100 examples): Reduce to 4 layers +- **Large datasets** (> 5k examples): Increase to 8 layers + +#### 3. `horizon` Tuning + +The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency: + +- **30 Hz frequency**: `horizon=30` +- **10 Hz frequency**: `horizon=10` + +Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions. + +#### 4. `n_action_steps` Sensitivity + +The model can also be very sensitive to `n_action_steps`. 
Start with a value corresponding to around 0.8 seconds at your control frequency and tune from there: + +- **Lower values**: More reactive but potentially less stable for long-horizon tasks +- **Higher values**: Better for long-horizon execution, but longer open-loop stretches leave fewer opportunities to recover from failures + +### Inference Tuning + +For faster inference, use DDIM with fewer sampling steps: + +```bash +--policy.noise_scheduler_type=DDIM \ +--policy.num_inference_steps=10 +``` + +### Resuming Training + +To resume training from a checkpoint: + +```bash +lerobot-train \ + --config_path=./outputs/multitask_dit_training/checkpoints/last/pretrained_model/train_config.json \ + --resume=true +``` + +The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters. + +## Common Failure Modes and Debugging + +Training these models can be finicky. Here are common failure modes and debugging approaches: + +### Idling / No Motion + +The model may "collapse" during inference, resulting in static or no motion. This can occur when: + +1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex. + +2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning, which might not be rich enough. + +**Debugging tips:** + +- Increase dataset size (double until you get to over 300 examples) +- Train for longer, up to 100k steps, even when the loss flatlines +- Check if the model is receiving proper language instructions, or increase the diversity of instructions + +### Executing the Wrong Task + +Sometimes the robot will completely ignore your instruction and perform some other task. 
This generally only happens if you have trained on multiple tasks. + +**Potential causes:** + +- Language instruction ambiguity +- Insufficient task-specific training data +- Model confusion between similar tasks in the multitask dataset + +**Debugging tips:** + +- Verify language instruction specificity, especially if descriptions are similar between multiple tasks +- Check task distribution in your training dataset and add weighting to the failing/ignored task +- Consider task-specific fine-tuning + +### Training Instability + +If training loss is unstable or diverging: + +- Try adjusting the learning rate between `1e-5` and `3e-4` +- Increase batch size if possible +- Check that your dataset normalization is correct +- Verify image preprocessing is working correctly + +## Performance Considerations + +### GPU Requirements + +- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable inference speed +- **Training**: A GPU with enough VRAM to fit batch sizes of >64 is ideal; exact requirements vary with the number and resolution of image observations + +### Batch Size Recommendations + +- **Minimum**: 64 (less than this may result in unstable training) +- **Recommended**: 256-320 (best performance, requires a larger GPU) + +## Example: Training on Custom Dataset + +Here's a complete example of training on a custom dataset: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/multitask_dit_training \ + --batch_size=320 \ + --steps=30000 \ + --save_freq=1000 \ + --log_freq=100 \ + --eval_freq=1000 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.horizon=32 \ + --policy.n_action_steps=24 \ + --policy.objective=diffusion \ + --policy.noise_scheduler_type=DDPM \ + --policy.num_layers=6 \ + --policy.hidden_dim=512 \ + --policy.vision_encoder_name=openai/clip-vit-base-patch16 \ + --policy.image_resize_shape=[320,240] \ + --policy.image_crop_shape=[224,224] \ + 
--policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true \ + --wandb.project=multitask_dit +``` + +## References + +For more details on the technical implementation and architecture, see: + +- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331) +- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/) +- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy) diff --git a/docs/source/policy_multi_task_dit_README.md b/docs/source/policy_multi_task_dit_README.md new file mode 100644 index 000000000..f24fa927e --- /dev/null +++ b/docs/source/policy_multi_task_dit_README.md @@ -0,0 +1,37 @@ +# Multitask DiT Policy + +## Citation + +If you use this work, please cite the following works: + +```bibtex +@misc{jones2025multitaskditpolicy, + author = {Bryson Jones}, + title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy}, + year = {2025}, + url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy}, + note = {Blog post} +} +``` + +```bibtex +@misc{trilbmteam2025carefulexaminationlargebehaviormodels, + author = {TRI LBM Team}, + title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation}, + year = {2025}, + eprint = {arXiv:2507.05331}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2507.05331} +} +``` + +```bibtex +@misc{bostondynamics2025largebehaviormodelsatlas, + author = {Boston Dynamics and TRI Research Team}, + title = {Large Behavior Models and Atlas Find New Footing}, + year = {2025}, + url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/}, + note = {Blog post} +} +``` diff --git a/pyproject.toml b/pyproject.toml 
index 7e4f24eb6..bed22a507 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ wallx = [ ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +multi_task_dit = ["lerobot[transformers-dep]"] groot = [ "lerobot[transformers-dep]", "lerobot[peft]", diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index c7951f028..55ce09cf9 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config @@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig __all__ = [ "ACTConfig", "DiffusionConfig", + "MultiTaskDiTConfig", "PI0Config", "PI05Config", "PI0FastConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 2320cd624..146924502 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.groot.configuration_groot import GrootConfig +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from 
lerobot.policies.pretrained import PreTrainedPolicy @@ -67,8 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". - + "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -87,6 +87,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.act.modeling_act import ACTPolicy return ACTPolicy + elif name == "multi_task_dit": + from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy + + return MultiTaskDiTPolicy elif name == "vqbet": from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy @@ -147,8 +151,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", - "reward_classifier", "wall_x". + "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", + "smolvla", "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. 
Returns: @@ -163,6 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return DiffusionConfig(**kwargs) elif policy_type == "act": return ACTConfig(**kwargs) + elif policy_type == "multi_task_dit": + return MultiTaskDiTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) elif policy_type == "pi0": @@ -309,6 +315,16 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, MultiTaskDiTConfig): + from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + make_multi_task_dit_pre_post_processors, + ) + + processors = make_multi_task_dit_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, VQBeTConfig): from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors diff --git a/src/lerobot/policies/multi_task_dit/README.md b/src/lerobot/policies/multi_task_dit/README.md new file mode 100644 index 000000000..f24fa927e --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/README.md @@ -0,0 +1,37 @@ +# Multitask DiT Policy + +## Citation + +If you use this work, please cite the following works: + +```bibtex +@misc{jones2025multitaskditpolicy, + author = {Bryson Jones}, + title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy}, + year = {2025}, + url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy}, + note = {Blog post} +} +``` + +```bibtex +@misc{trilbmteam2025carefulexaminationlargebehaviormodels, + author = {TRI LBM Team}, + title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation}, + year = {2025}, + eprint = {arXiv:2507.05331}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2507.05331} +} +``` + +```bibtex +@misc{bostondynamics2025largebehaviormodelsatlas, + author = {Boston Dynamics and TRI Research Team}, + title = 
{Large Behavior Models and Atlas Find New Footing}, + year = {2025}, + url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/}, + note = {Blog post} +} +``` diff --git a/src/lerobot/policies/multi_task_dit/__init__.py b/src/lerobot/policies/multi_task_dit/__init__.py new file mode 100644 index 000000000..52a209d47 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_multi_task_dit import MultiTaskDiTConfig +from .modeling_multi_task_dit import MultiTaskDiTPolicy +from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors + +__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"] diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py new file mode 100644 index 000000000..061230687 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamConfig +from lerobot.optim.schedulers import DiffuserSchedulerConfig + + +@PreTrainedConfig.register_subclass("multi_task_dit") +@dataclass +class MultiTaskDiTConfig(PreTrainedConfig): + """Configuration for the Multi-Task Diffusion Transformer (DiT) policy. + + A transformer-based policy that supports both diffusion and flow matching objectives + for multi-task robot learning with text and vision conditioning. 
+ """ + + n_obs_steps: int = 2 # Number of observation steps for temporal context + horizon: int = 32 # Number of action steps to predict + n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz) + + # Objective Selection + objective: str = "diffusion" # "diffusion" or "flow_matching" + + # --- Diffusion-specific (used when objective="diffusion") --- + noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM" + num_train_timesteps: int = 100 # Number of diffusion timesteps + beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type + beta_start: float = 0.0001 # Starting noise level + beta_end: float = 0.02 # Ending noise level + prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean) + clip_sample: bool = True # Clip samples during denoising + clip_sample_range: float = 1.0 # Clipping range [-x, x] + num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps) + + # --- Flow Matching-specific (used when objective="flow_matching") --- + sigma_min: float = 0.0 # Minimum noise in flow interpolation path + num_integration_steps: int = 100 # ODE integration steps at inference + integration_method: str = "euler" # ODE solver: "euler" or "rk4" + timestep_sampling_strategy: str = "beta" # "uniform" or "beta" + + timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold + timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha + timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta + + # Transformer Architecture + hidden_dim: int = 512 # Transformer hidden dimension + num_layers: int = 6 # Number of transformer layers + num_heads: int = 8 # Number of attention heads + dropout: float = 0.1 # Dropout rate + use_positional_encoding: bool = False # Use absolute positional encoding + timestep_embed_dim: int = 256 # Timestep embedding dimension + use_rope: bool = True # Use Rotary Position Embedding + rope_base: float = 10000.0 # 
RoPE base frequency + + # Vision Encoder (CLIP) + vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model + use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view + vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder + image_resize_shape: tuple[int, int] | None = None # Resize images before crop + image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default) + image_crop_is_random: bool = True # Random crop during training, center at inference + + # Text Encoder (CLIP) + text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model + tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77) + tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest" + tokenizer_padding_side: str = "right" # Padding side: "left" or "right" + tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length + + # Normalization + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Training/Optimizer + optimizer_lr: float = 2e-5 + optimizer_betas: tuple = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.0 + scheduler_name: str = "cosine" + scheduler_warmup_steps: int = 0 + do_mask_loss_for_padding: bool = False + + # Auto-calculated + drop_n_last_frames: int | None = None + + def __post_init__(self): + super().__post_init__() + + if self.drop_n_last_frames is None: + self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1 + + self._validate() + + def _validate(self): + """Validate configuration parameters.""" + # Objective validation + if self.objective not in ["diffusion", "flow_matching"]: + raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got 
'{self.objective}'") + + # Transformer validation + if self.hidden_dim <= 0: + raise ValueError("hidden_dim must be positive") + if self.num_layers <= 0: + raise ValueError("num_layers must be positive") + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + if self.hidden_dim % self.num_heads != 0: + raise ValueError("hidden_dim must be divisible by num_heads") + if not (0.0 <= self.dropout <= 1.0): + raise ValueError("dropout must be between 0.0 and 1.0") + + # Vision encoder validation + if "clip" not in self.vision_encoder_name.lower(): + raise ValueError( + f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'" + ) + if ( + self.image_resize_shape + and self.image_crop_shape + and ( + self.image_crop_shape[0] > self.image_resize_shape[0] + or self.image_crop_shape[1] > self.image_resize_shape[1] + ) + ): + logging.warning( + "image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.", + self.image_crop_shape, + self.image_resize_shape, + ) + self.image_crop_shape = None + + # Text encoder validation + if "clip" not in self.text_encoder_name.lower(): + raise ValueError( + f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'" + ) + + # Objective-specific validation + if self.objective == "diffusion": + if self.noise_scheduler_type not in ["DDPM", "DDIM"]: + raise ValueError( + f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}" + ) + if self.prediction_type not in ["epsilon", "sample"]: + raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}") + if self.num_train_timesteps <= 0: + raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}") + if not (0.0 <= self.beta_start <= self.beta_end <= 1.0): + raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}") + + elif self.objective == "flow_matching": + if not (0.0 <= self.sigma_min <= 
1.0): + raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}") + if self.num_integration_steps <= 0: + raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}") + if self.integration_method not in ["euler", "rk4"]: + raise ValueError( + f"integration_method must be 'euler' or 'rk4', got {self.integration_method}" + ) + if self.timestep_sampling_strategy not in ["uniform", "beta"]: + raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'") + if self.timestep_sampling_strategy == "beta": + if not (0.0 < self.timestep_sampling_s <= 1.0): + raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}") + if self.timestep_sampling_alpha <= 0: + raise ValueError("timestep_sampling_alpha must be positive") + if self.timestep_sampling_beta <= 0: + raise ValueError("timestep_sampling_beta must be positive") + + def get_optimizer_preset(self) -> AdamConfig: + return AdamConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> DiffuserSchedulerConfig: + return DiffuserSchedulerConfig( + name=self.scheduler_name, + num_warmup_steps=self.scheduler_warmup_steps, + ) + + def validate_features(self) -> None: + """Validate that required input features are present and properly configured.""" + # If the configured crop doesn't fit, disable cropping instead of erroring. + # Note: if image_resize_shape is set, cropping is applied *after* resizing. 
+ if self.image_crop_shape is not None: + for key, image_ft in self.image_features.items(): + # image_ft.shape is (C, H, W) + effective_h, effective_w = ( + self.image_resize_shape + if self.image_resize_shape is not None + else (image_ft.shape[1], image_ft.shape[2]) + ) + if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w: + logging.warning( + "image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.", + self.image_crop_shape, + effective_h, + effective_w, + key, + ) + self.image_crop_shape = None + break + + if len(self.image_features) > 0: + first_key, first_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_ft.shape: + raise ValueError( + f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}" + ) + + @property + def is_diffusion(self) -> bool: + return self.objective == "diffusion" + + @property + def is_flow_matching(self) -> bool: + return self.objective == "flow_matching" + + @property + def observation_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1)) + + @property + def action_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py new file mode 100644 index 000000000..4fee851e0 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -0,0 +1,803 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Task Diffusion Transformer (DiT) Policy + +Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning. +Supports both diffusion and flow matching objectives for action generation. + +References: +- https://arxiv.org/abs/2507.05331 +- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/ +- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy +""" + +import math +from collections import deque +from typing import TYPE_CHECKING + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from torch import Tensor + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers import CLIPTextModel, CLIPVisionModel +else: + CLIPTextModel = None + CLIPVisionModel = None +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import populate_queues +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) + +# -- Policy -- + + +class MultiTaskDiTPolicy(PreTrainedPolicy): + config_class = MultiTaskDiTConfig + name = "multi_task_dit" + + 
def __init__(self, config: MultiTaskDiTConfig, **kwargs): + super().__init__(config) + config.validate_features() + self.config = config + + self._queues = None + + self.observation_encoder = ObservationEncoder(config) + conditioning_dim = self.observation_encoder.conditioning_dim + self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim) + + action_dim = config.action_feature.shape[0] + horizon = config.horizon + + if config.is_diffusion: + self.objective = DiffusionObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + elif config.is_flow_matching: + self.objective = FlowMatchingObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + else: + raise ValueError(f"Unsupported objective: {config.objective}") + + self.reset() + + def get_optim_params(self) -> list: + """Returns parameter groups with different learning rates for vision vs non-vision parameters""" + non_vision_params = [] + vision_encoder_params = [] + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + + if "observation_encoder.vision_encoder" in name: + vision_encoder_params.append(param) + else: + non_vision_params.append(param) + + return [ + {"params": non_vision_params}, + { + "params": vision_encoder_params, + "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier, + }, + ] + + def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + conditioning_vec = self.observation_encoder.encode(batch) + actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) + + start = n_obs_steps - 1 + end = start + self.config.n_action_steps + actions = actions[:, start:end] + return actions + + def reset(self): + """Clear observation 
and action queues. Should be called on `env.reset()`""" + self._queues = { + OBS_STATE: deque(maxlen=self.config.n_obs_steps), + ACTION: deque(maxlen=self.config.n_action_steps), + } + + if self.config.image_features: + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations""" + self.eval() + + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + + actions = self._generate_actions(batch) + return actions + + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare batch by stacking image features if needed.""" + if self.config.image_features: + batch = dict(batch) # shallow copy to avoid modifying original + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + return batch + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations""" + if ACTION in batch: + batch = dict(batch) # shallow copy to avoid modifying original + batch.pop(ACTION) + + batch = self._prepare_batch(batch) + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) + + action = self._queues[ACTION].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Run the batch through the model and compute the loss for training""" + batch = self._prepare_batch(batch) + + conditioning_vec = self.observation_encoder.encode(batch) + loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) + + return loss, None + + +# -- Observation Encoders -- + + +class CLIPVisionEncoder(nn.Module): + """CLIP vision encoder using the CLS token for global 
image representation."""
+
+    def __init__(self, model_name: str):
+        super().__init__()
+        self.model_name = model_name
+        self.model = CLIPVisionModel.from_pretrained(self.model_name)
+        self.num_non_spatial_tokens = 1
+        self.embed_dim = self.model.config.hidden_size
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encode RGB images into a global CLS-token feature map of shape (B, embed_dim, 1, 1)."""
+        outputs = self.model(pixel_values=x, output_hidden_states=False)
+        cls_token = outputs.last_hidden_state[:, 0]
+        b, embed_dim = cls_token.shape
+        return cls_token.reshape(b, embed_dim, 1, 1)
+
+    def get_output_shape(self) -> tuple[int, int, int]:
+        return (self.embed_dim, 1, 1)
+
+
+class CLIPTextEncoder(nn.Module):
+    """CLIP text encoder with frozen weights and a learnable projection layer.
+
+    Accepts pre-tokenized inputs (input_ids and attention_mask); see the processor pipeline
+    for how tokenization is handled.
+    """
+
+    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
+        super().__init__()
+        self.model_name = model_name
+        self.projection_dim = projection_dim
+        self.text_encoder = CLIPTextModel.from_pretrained(model_name)
+
+        for param in self.text_encoder.parameters():
+            param.requires_grad = False
+
+        self.text_embed_dim = self.text_encoder.config.hidden_size
+        self.projection = nn.Linear(self.text_embed_dim, projection_dim)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Encode pre-tokenized text to feature vectors."""
+        # Ensure inputs are on the same device as the model
+        device = next(self.parameters()).device
+        input_ids = input_ids.to(device)
+        attention_mask = attention_mask.to(device)
+
+        with torch.no_grad():
+            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
+            clip_features = outputs.pooler_output
+
+        return self.projection(clip_features)
+
+
+class ObservationEncoder(nn.Module):
+    """Handles all observation processing for the conditioning vector."""
+
+    def __init__(self, 
config): + super().__init__() + self.config = config + self._setup_preprocessing(config) + + if config.image_features: + self.num_cameras = len(config.image_features) + self.camera_names = list(config.image_features.keys()) + + if config.use_separate_rgb_encoder_per_camera: + self.vision_encoders = nn.ModuleList( + [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names] + ) + self.vision_encoder = None + else: + self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name) + self.vision_encoders = None + else: + self.vision_encoder = None + self.vision_encoders = None + self.camera_names = [] + self.num_cameras = 0 + + if hasattr(config, "robot_state_feature") and config.robot_state_feature: + self.robot_state_dim = config.robot_state_feature.shape[0] + else: + self.robot_state_dim = 0 + + self.text_dim = config.hidden_dim + self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim) + + self._setup_vector_output() + + def _apply_preprocessing(self, images: Tensor) -> Tensor: + if self.do_resize: + images = self.resize(images) + if self.do_crop: + images = self.maybe_random_crop(images) if self.training else self.center_crop(images) + return images + + def _setup_preprocessing(self, config): + if config.image_resize_shape is not None: + self.do_resize = True + self.resize = torchvision.transforms.Resize( + size=config.image_resize_shape, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True, + ) + else: + self.do_resize = False + + if config.image_crop_shape is not None: + self.do_crop = True + self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape) + if config.image_crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + def _setup_vector_output(self): + total_dim = 0 + + if self.vision_encoder 
is not None or self.vision_encoders is not None: + encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders)) + feature_map_shape = encoder_to_check.get_output_shape() + c, h, w = feature_map_shape + spatial_feature_dim = c * h * w + total_dim += spatial_feature_dim * self.num_cameras + + total_dim += self.robot_state_dim + total_dim += self.text_dim + + self.conditioning_dim = total_dim * self.config.n_obs_steps + + def encode(self, batch: dict) -> Tensor: + """Encode observations to vector format.""" + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + conditioning_feats = [] + + conditioning_feats.append(batch[OBS_STATE]) + + if self.vision_encoder is not None or self.vision_encoders is not None: + images = batch[OBS_IMAGES] + + if len(images.shape) == 5: + images = images.unsqueeze(1) + + if self.config.use_separate_rgb_encoder_per_camera: + camera_features = [] + for cam_idx in range(self.num_cameras): + cam_images = images[:, :, cam_idx] + cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w") + cam_images_flat = self._apply_preprocessing(cam_images_flat) + cam_features = self.vision_encoders[cam_idx](cam_images_flat) + cam_visual_features = cam_features.flatten(start_dim=1) + cam_features_reshaped = einops.rearrange( + cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps + ) + camera_features.append(cam_features_reshaped) + img_features = torch.cat(camera_features, dim=-1) + conditioning_feats.append(img_features) + else: + images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w") + images_flat = self._apply_preprocessing(images_flat) + visual_features = self.vision_encoder(images_flat).flatten(start_dim=1) + img_features = einops.rearrange( + visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras + ) + conditioning_feats.append(img_features) + + if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch: + input_ids = batch[OBS_LANGUAGE_TOKENS] # 
[batch_size, seq_length] + attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length] + + text_features = self.text_encoder(input_ids, attention_mask) + + text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) + conditioning_feats.append(text_features) + + combined_features = torch.cat(conditioning_feats, dim=-1) + return combined_features.flatten(start_dim=1) + + +# -- Transformer Components -- + + +def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + """Modulate input with shift and scale for AdaLN-Zero.""" + return x * (1 + scale) + shift + + +class SinusoidalPosEmb(nn.Module): + """Sinusoidal positional embeddings for timesteps.""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class RotaryPositionalEmbedding(nn.Module): + """Rotary Position Embedding (RoPE) for transformers.""" + + def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0): + super().__init__() + assert head_dim % 2 == 0, "head_dim must be even for RoPE" + + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.base = base + + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._precompute_cache(max_seq_len) + + def _precompute_cache(self, seq_len: int): + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def _rotate_half(self, x: Tensor) -> 
Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]: + seq_len = q.shape[2] + if seq_len > self.max_seq_len: + raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.") + + cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype) + sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype) + + q_rotated = (q * cos) + (self._rotate_half(q) * sin) + k_rotated = (k * cos) + (self._rotate_half(k) * sin) + return q_rotated, k_rotated + + +class RoPEAttention(nn.Module): + """Multi-head self-attention with Rotary Position Embedding (RoPE).""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout: float = 0.0, + max_seq_len: int = 512, + rope_base: float = 10000.0, + ): + super().__init__() + assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads" + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True) + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + B, T, _ = x.shape # noqa: N806 + + qkv = self.qkv_proj(x) + qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q, k = self.rope(q, k) + + attn_out = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0, + ) + + attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size) + return self.out_proj(attn_out) + + +class TransformerBlock(nn.Module): + 
"""DiT-style transformer block with AdaLN-Zero.""" + + def __init__( + self, + hidden_size: int = 128, + num_heads: int = 4, + num_features: int = 128, + dropout: float = 0.0, + use_rope: bool = False, + max_seq_len: int = 512, + rope_base: float = 10000.0, + ): + super().__init__() + self.use_rope = use_rope + + if use_rope: + self.attn = RoPEAttention( + hidden_size=hidden_size, + num_heads=num_heads, + dropout=dropout, + max_seq_len=max_seq_len, + rope_base=rope_base, + ) + else: + self.multihead_attn = nn.MultiheadAttention( + hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout + ) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size * 4, hidden_size), + ) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True)) + + def forward(self, x: Tensor, features: Tensor) -> Tensor: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( + features + ).chunk(6, dim=1) + + attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1)) + + if self.use_rope: + attn_out = self.attn(attn_input) + else: + attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input) + + x = x + gate_msa.unsqueeze(1) * attn_out + + mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1)) + mlp_out = self.mlp(mlp_input) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +class DiffusionTransformer(nn.Module): + """Transformer-based diffusion noise prediction model.""" + + def __init__(self, config, conditioning_dim: int): + super().__init__() + self.config = config + self.conditioning_dim = conditioning_dim + + self.action_dim = config.action_feature.shape[0] + self.horizon = config.horizon + 
self.hidden_size = config.hidden_dim + self.num_layers = config.num_layers + self.num_heads = config.num_heads + self.dropout = config.dropout + self.use_rope = config.use_rope + + self.timestep_embed_dim = config.timestep_embed_dim + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(self.timestep_embed_dim), + nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim), + nn.GELU(), + nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim), + nn.GELU(), + ) + + self.cond_dim = self.timestep_embed_dim + conditioning_dim + self.input_proj = nn.Linear(self.action_dim, self.hidden_size) + + if config.use_positional_encoding: + self.pos_embedding = nn.Parameter( + torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02) + ) + else: + self.pos_embedding = None + + self.transformer_blocks = nn.ModuleList( + [ + TransformerBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + num_features=self.cond_dim, + dropout=self.dropout, + use_rope=self.use_rope, + max_seq_len=self.horizon, + rope_base=config.rope_base, + ) + for _ in range(self.num_layers) + ] + ) + + self.output_proj = nn.Linear(self.hidden_size, self.action_dim) + self._initialize_weights() + + def _initialize_weights(self): + for block in self.transformer_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor: + _, seq_len, _ = x.shape + + timestep_features = self.time_mlp(timestep) + cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) + + hidden_seq = self.input_proj(x) + + if self.pos_embedding is not None: + hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] + + for block in self.transformer_blocks: + hidden_seq = block(hidden_seq, cond_features) + + return self.output_proj(hidden_seq) + + +# -- Objectives -- + + +class DiffusionObjective(nn.Module): + """Standard diffusion (DDPM/DDIM) 
objective implementation.""" + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + scheduler_kwargs = { + "num_train_timesteps": config.num_train_timesteps, + "beta_start": config.beta_start, + "beta_end": config.beta_end, + "beta_schedule": config.beta_schedule, + "clip_sample": config.clip_sample, + "clip_sample_range": config.clip_sample_range, + "prediction_type": config.prediction_type, + } + + if config.noise_scheduler_type == "DDPM": + self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs) + elif config.noise_scheduler_type == "DDIM": + self.noise_scheduler = DDIMScheduler(**scheduler_kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}") + + self.num_inference_steps = ( + config.num_inference_steps + if config.num_inference_steps is not None + else self.noise_scheduler.config.num_train_timesteps + ) + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + clean_actions = batch[ACTION] + noise = torch.randn_like(clean_actions) + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(clean_actions.shape[0],), + device=clean_actions.device, + ).long() + noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps) + + prediction_type = self.noise_scheduler.config.prediction_type + if prediction_type == "epsilon": + target = noise + elif prediction_type == "sample": + target = clean_actions + else: + raise ValueError(f"Unsupported prediction type: {prediction_type}") + + predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted, target, reduction="none") + + if self.do_mask_loss_for_padding and "action_is_pad" 
in batch: + valid_actions = ~batch["action_is_pad"] + loss = loss * valid_actions.unsqueeze(-1) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + sample = torch.randn( + size=(batch_size, self.horizon, self.action_dim), + dtype=dtype, + device=device, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + for t in self.noise_scheduler.timesteps: + model_output = model( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + conditioning_vec=conditioning_vec, + ) + sample = self.noise_scheduler.step(model_output, t, sample).prev_sample + + return sample + + +class FlowMatchingObjective(nn.Module): + """Flow matching objective: trains a model to predict velocity fields.""" + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: + if self.config.timestep_sampling_strategy == "uniform": + return torch.rand(batch_size, device=device) + elif self.config.timestep_sampling_strategy == "beta": + beta_dist = torch.distributions.Beta( + self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta + ) + u = beta_dist.sample((batch_size,)).to(device) + return self.config.timestep_sampling_s * (1.0 - u) + else: + raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}") + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + data = batch[ACTION] + batch_size = data.shape[0] + device = data.device + + noise = torch.randn_like(data) + t = self._sample_timesteps(batch_size, device) + t_expanded = 
t.view(-1, 1, 1) + x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise + + target_velocity = data - (1 - self.config.sigma_min) * noise + predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") + + if self.do_mask_loss_for_padding and "action_is_pad" in batch: + valid_mask = ~batch["action_is_pad"] + loss = loss * valid_mask.unsqueeze(-1) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device) + + num_steps = self.config.num_integration_steps + time_grid = torch.linspace(0, 1, num_steps + 1, device=device) + + if self.config.integration_method == "euler": + x = self._euler_integrate(model, x, time_grid, conditioning_vec) + elif self.config.integration_method == "rk4": + x = self._rk4_integrate(model, x, time_grid, conditioning_vec) + else: + raise ValueError(f"Unknown integration method: {self.config.integration_method}") + + return x + + def _euler_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + x = x_init + for i in range(len(time_grid) - 1): + t_scalar = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device) + with torch.no_grad(): + velocity = model(x, t_batch, conditioning_vec=conditioning_vec) + x = x + dt * velocity + return x + + def _rk4_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + x = x_init + + def dynamics(x_val: Tensor, t_scalar: float) -> Tensor: + t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device) + with torch.no_grad(): + 
return model(x_val, t_batch, conditioning_vec=conditioning_vec) + + for i in range(len(time_grid) - 1): + t = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + + k1 = dynamics(x, t) + k2 = dynamics(x + dt * k1 / 2, t + dt / 2) + k3 = dynamics(x + dt * k2 / 2, t + dt / 2) + k4 = dynamics(x + dt * k3, t + dt) + + x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4) + + return x diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py new file mode 100644 index 000000000..fc94599c2 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
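The `_rk4_integrate` helper above applies the classic fourth-order Runge-Kutta update (k1..k4 combined with weights dt/6 · (1, 2, 2, 1)). As a self-contained sanity check of that update rule — using a toy scalar velocity field `dx/dt = -x` purely for illustration, not the policy's learned model — the same scheme recovers the exact exponential solution to high accuracy:

```python
import math


def rk4_integrate(dynamics, x0: float, t0: float, t1: float, num_steps: int) -> float:
    """Integrate dx/dt = dynamics(x, t) from t0 to t1 with the classic RK4 scheme."""
    x, t = x0, t0
    dt = (t1 - t0) / num_steps
    for _ in range(num_steps):
        k1 = dynamics(x, t)
        k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
        k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
        k4 = dynamics(x + dt * k3, t + dt)
        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
    return x


# dx/dt = -x with x(0) = 1 has the exact solution x(t) = exp(-t),
# so integrating to t = 1 should land very close to exp(-1).
approx = rk4_integrate(lambda x, t: -x, x0=1.0, t0=0.0, t1=1.0, num_steps=10)
print(abs(approx - math.exp(-1.0)))  # O(dt^4) global error: well below 1e-6 even at 10 steps
```

The structure mirrors `_rk4_integrate` above (same k1..k4 stages and weighted sum), minus the tensor batching and conditioning vector, which is why a coarse 10-step grid already matches the closed form to several digits.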
+ +from typing import Any + +import torch + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_multi_task_dit_pre_post_processors( + config: MultiTaskDiTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy. + + The pre-processing pipeline prepares the input data for the model by: + 1. Renaming features. + 2. Adding a batch dimension. + 3. Tokenizing the language task description (if present). + 4. Moving the data to the specified device. + 5. Normalizing the input and output features based on dataset statistics. + + The post-processing pipeline handles the model's output by: + 1. Unnormalizing the output features to their original scale. + 2. Moving the data to the CPU. + + Args: + config: The configuration object for the Multi-Task DiT policy, + containing feature definitions, normalization mappings, and device information. + dataset_stats: A dictionary of statistics used for normalization. + Defaults to None. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. 
+ """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + TokenizerProcessorStep( + tokenizer_name=config.text_encoder_name, + padding=config.tokenizer_padding, + padding_side=config.tokenizer_padding_side, + max_length=config.tokenizer_max_length, + truncation=config.tokenizer_truncation, + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + device=config.device, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py new file mode 100644 index 000000000..5b70422d4 --- /dev/null +++ b/tests/policies/multi_task_dit/test_multi_task_dit.py @@ -0,0 +1,624 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: E402 + +"""Test script for Multi-Task DiT policy. + +To run tests locally: + python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v +""" + +import os + +import pytest +import torch +from torch import Tensor + +pytest.importorskip("transformers") + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local transformers installation and is not meant for CI", +) + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy +from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + make_multi_task_dit_pre_post_processors, +) +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) +from lerobot.utils.random_utils import seeded_context, set_seed + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 17 + set_seed(seed) + + +def create_train_batch( + batch_size: int = 2, + n_obs_steps: int = 2, + horizon: int = 16, + state_dim: int = 10, + action_dim: int = 10, + height: int = 224, + width: int = 224, +) -> dict[str, Tensor]: + """Create a training batch with visual input and text.""" + return { + "observation.state": torch.randn(batch_size, n_obs_steps, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width), + ACTION: torch.randn(batch_size, horizon, action_dim), + "task": ["pick up the cube"] * batch_size, + } + + +def create_observation_batch( + batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224 +) -> dict: + """Create observation batch for inference for a single timestep.""" + 
return { + "observation.state": torch.randn(batch_size, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width), + "task": ["pick up the red cube"] * batch_size, + } + + +def create_config( + state_dim: int = 10, + action_dim: int = 10, + n_obs_steps: int = 2, + horizon: int = 16, + n_action_steps: int = 8, + with_visual: bool = True, + height: int = 224, + width: int = 224, +) -> MultiTaskDiTConfig: + """Create a MultiTaskDiT config for testing. + + Args: + state_dim: Dimension of state observations + action_dim: Dimension of actions + n_obs_steps: Number of observation steps + horizon: Action prediction horizon + n_action_steps: Number of action steps to execute + with_visual: Whether to include visual input (default: True) + height: Image height (only used if with_visual=True) + width: Image width (only used if with_visual=True) + """ + input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))} + + if with_visual: + input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(3, height, width) + ) + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use smaller model for faster tests + hidden_dim=128, + num_layers=2, + num_heads=4, + ) + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int): + """Test forward pass (training mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to 
handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + assert loss.shape == () + + # Test backward pass + loss.backward() + + +def test_multi_task_dit_pre_post_processors(): + """Test pre and post processors for Multi-Task DiT policy.""" + state_dim = 10 + action_dim = 8 + n_obs_steps = 2 + horizon = 16 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Create dataset stats for normalization + dataset_stats = { + "observation.state": { + "mean": torch.zeros(state_dim), + "std": torch.ones(state_dim), + }, + "action": { + "min": torch.full((action_dim,), -1.0), + "max": torch.ones(action_dim), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Test preprocessor with sample data + batch = { + "observation.state": torch.randn(state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: torch.randn(action_dim), + "task": "pick up the cube", + } + + processed_batch = 
preprocessor(batch) + + # Check that data is batched + assert processed_batch["observation.state"].shape == (1, state_dim) + assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224) + assert processed_batch[ACTION].shape == (1, action_dim) + # Check that task text was tokenized + assert OBS_LANGUAGE_TOKENS in processed_batch + assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch + assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension + assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension + + # Check that data is on correct device + assert processed_batch["observation.state"].device.type == "cpu" + assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu" + assert processed_batch[ACTION].device.type == "cpu" + + # Test postprocessor with sample action (PolicyAction is just a torch.Tensor) + action = torch.randn(1, action_dim) + processed_action = postprocessor(action) + + # Check that action is unnormalized and on CPU + assert processed_action.shape == (1, action_dim) + assert processed_action.device.type == "cpu" + + +def test_multi_task_dit_pre_post_processors_normalization(): + """Test that normalization and unnormalization work correctly with simple sanity check numbers.""" + state_dim = 3 + action_dim = 2 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Use simple stats that will actually transform the values + dataset_stats = { + "observation.state": { + "mean": torch.full((state_dim,), 5.0), + "std": torch.full((state_dim,), 2.0), + }, + "action": { + "min": torch.zeros(action_dim), + "max": 
torch.full((action_dim,), 2.0), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Use simple input values + input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0] + input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0] + + batch = { + "observation.state": input_state, + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: input_action, + "task": "test task", + } + + # Process through preprocessor + processed_batch = preprocessor(batch) + + # State normalization: (x - mean) / std + expected_normalized_state = torch.tensor([1.0, 0.0, -1.0]) + assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5) + + # Action normalization: (x - min) / (max - min) * 2 - 1 + expected_normalized_action = torch.tensor([0.0, 1.0]) + assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5) + + # Test unnormalization: should recover original values + normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension + unnormalized_action = postprocessor(normalized_action_tensor) + + # Should recover original action values + assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + """Test select_action (inference mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + policy.reset() # Reset queues before inference + + # Create processors - use IDENTITY normalization when no stats provided + 
config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_diffusion_objective(): + """Test policy with diffusion objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use diffusion objective + objective="diffusion", + noise_scheduler_type="DDPM", + num_train_timesteps=100, + num_inference_steps=10, + # Smaller model for tests + hidden_dim=128, + num_layers=2, + num_heads=4, + ) + config.validate_features() + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + 
batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_flow_matching_objective(): + """Test policy with flow matching objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use flow matching objective + objective="flow_matching", + sigma_min=0.0, + num_integration_steps=10, # Fewer steps for faster tests + integration_method="euler", + # Smaller model for tests + 
hidden_dim=128, + num_layers=2, + num_heads=4, + ) + config.validate_features() + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded correctly.""" + root = tmp_path / "test_multi_task_dit_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + 
action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + + # Get device before saving + device = next(policy.parameters()).device + + policy.save_pretrained(root) + loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config) + + # Explicitly move loaded_policy to the same device + loaded_policy.to(device) + loaded_policy.eval() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + with torch.no_grad(): + with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) + # Move batch to the same device as the policy + for key in processed_batch: + if isinstance(processed_batch[key], torch.Tensor): + processed_batch[key] = processed_batch[key].to(device) + # Collect policy values before saving + loss, _ = policy.forward(processed_batch) + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + actions = policy.select_action(processed_obs) + + with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) + # Collect policy values after loading + loaded_loss, _ = loaded_policy.forward(processed_batch) + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + processed_obs = preprocessor(loaded_observation_batch) + loaded_actions = loaded_policy.select_action(processed_obs) + + # 
Compare state dicts + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + assert torch.allclose(loss, loaded_loss) + assert torch.allclose(actions, loaded_actions) + + +def test_multi_task_dit_policy_get_optim_params(): + """Test that the policy returns correct optimizer parameter groups.""" + config = create_config( + state_dim=10, + action_dim=10, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + + policy = MultiTaskDiTPolicy(config=config) + param_groups = policy.get_optim_params() + + # Should have 2 parameter groups: non-vision and vision encoder + assert len(param_groups) == 2 + + # First group is non-vision params (no lr specified, will use default) + assert "params" in param_groups[0] + assert len(param_groups[0]["params"]) > 0 + + # Second group is vision encoder params with different lr + assert "params" in param_groups[1] + assert "lr" in param_groups[1] + expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier + assert param_groups[1]["lr"] == expected_lr From 3b185f7f9d8faf16bbdf3a833e670feac08f881e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9B=9B=E4=B8=83?= <41624527+SevenFo@users.noreply.github.com> Date: Sat, 28 Mar 2026 18:37:57 +0800 Subject: [PATCH 08/47] fix(datasets): remove unreachable timestamp branch in add_frame (#3163) * fix(datasets): remove unreachable timestamp branch in add_frame and document caller contract - Remove dead code: frame.pop("timestamp") branch in add_frame() could never execute because validate_frame() raises ValueError for any DEFAULT_FEATURES key (including timestamp) before we reach that line. - Expand add_frame() docstring: explicitly document that timestamp and frame_index must NOT be passed by the caller. 
- Add explanatory comment in validate_frame(): clarifies why DEFAULT_FEATURES are excluded from expected_features, preventing future re-introduction of the dead branch. The dead branch originated in #1200, which fixed a shape-(1,) mismatch for a code path that was subsequently made unreachable by a refactor of validate_frame. * chore(datasets): narrow PR scope * fix(datasets): move add_frame timestamp cleanup to dataset_writer --- src/lerobot/datasets/dataset_writer.py | 13 +++++++++++-- src/lerobot/datasets/feature_utils.py | 4 ++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index b74b18e0c..787ecd337 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -155,7 +155,16 @@ class DatasetWriter: self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) def add_frame(self, frame: dict) -> None: - """Add a frame to the episode_buffer. Images are written to a temporary directory.""" + """ + Add a single frame to the current episode buffer. + + Apart from images written to a temporary directory, nothing is written to disk + until ``save_episode()`` is called. + + The caller must provide all user-defined features plus ``"task"``, and must + not provide ``"timestamp"`` or ``"frame_index"``; those are computed + automatically. 
+ """ # Convert torch to numpy if needed for name in frame: if isinstance(frame[name], torch.Tensor): @@ -168,7 +177,7 @@ class DatasetWriter: # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self._meta.fps + timestamp = frame_index / self._meta.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["task"].append(frame.pop("task")) diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index d9a3c6301..46154d92a 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -365,6 +365,10 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic def validate_frame(frame: dict, features: dict) -> None: + # DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are + # auto-populated by the recording pipeline (add_frame / save_episode) and must not + # be supplied by the caller. Excluding them here means any frame dict that contains + # these keys will be rejected as extra features. 
expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) From 5d4fdf5088ed86aa6d0d85c426525a4b1d2e213d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 30 Mar 2026 16:33:17 +0200 Subject: [PATCH 09/47] feat(scripts): add transformers version (#3248) * feat(scripts): add transformers and torch version * chore(scripts): remove pytorch --- src/lerobot/scripts/lerobot_info.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lerobot/scripts/lerobot_info.py b/src/lerobot/scripts/lerobot_info.py index 879d392be..2092db48b 100644 --- a/src/lerobot/scripts/lerobot_info.py +++ b/src/lerobot/scripts/lerobot_info.py @@ -65,6 +65,7 @@ def get_sys_info() -> dict[str, str]: "Platform": platform.platform(), "Python version": platform.python_version(), "Huggingface Hub version": get_package_version("huggingface_hub"), + "Transformers version": get_package_version("transformers"), "Datasets version": get_package_version("datasets"), "Numpy version": get_package_version("numpy"), "FFmpeg version": get_ffmpeg_version(), From 720cf8e3a09f62fa95260cc49a7a30e5d0f7473a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 30 Mar 2026 19:11:41 +0200 Subject: [PATCH 10/47] Revert "fix(deps): breaking change from transformers 5.4.0" (#3249) * Revert "fix(deps): breaking change from transformers 5.4.0 (#3231)" This reverts commit 07502868e58095b437e5dd5a480fecc65a6f29dc. 
* chore(dependencies): pin transformers to 5.3.0 temporarily

---
 pyproject.toml                                              | 2 +-
 .../policies/groot/action_head/flow_matching_action_head.py | 3 ++-
 src/lerobot/policies/groot/groot_n1.py                      | 3 ++-
 src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py    | 4 ++--
 src/lerobot/policies/xvla/modeling_florence2.py             | 4 ++--
 5 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index bed22a507..4a1efab30 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,7 +99,7 @@ dependencies = [
 # Common
 pygame-dep = ["pygame>=2.5.1,<2.7.0"]
 placo-dep = ["placo>=0.9.6,<0.9.17"]
-transformers-dep = ["transformers>=5.4.0,<6.0.0"]
+transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
 grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
 can-dep = ["python-can>=4.2.0,<5.0.0"]
 peft-dep = ["peft>=0.18.0,<1.0.0"]

diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
index 74d922988..bfc456ba0 100644
--- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
+++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import field +from dataclasses import dataclass, field from typing import TYPE_CHECKING import torch @@ -110,6 +110,7 @@ class MultiEmbodimentActionEncoder(nn.Module): return x +@dataclass class FlowmatchingActionHeadConfig(PretrainedConfig): """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head""" diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 38512b8a8..06ff5a04d 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import field +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING @@ -173,6 +173,7 @@ N_COLOR_CHANNELS = 3 # config +@dataclass class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."}) diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index a80096514..ecf3eb371 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -22,7 +22,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index 81f9c8234..e33efe5c3 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -45,7 +45,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) From 9300352876f68ed7c12726d6fb8ff45773023b7c Mon Sep 17 00:00:00 2001 From: Jai Kumaar Ratadia Date: Tue, 31 Mar 2026 11:16:34 +0100 Subject: [PATCH 11/47] Fix SO-101 assembly instruction order to match videos (#3242) * Fix SO-101 assembly instruction order to match videos Motor horn installation steps were listed after placing motors into the housing, but the assembly videos show installing horns first. Reordered steps to match the videos, which is also the easier approach since horns are harder to attach once the motor is seated. Added missing detail that bottom horns don't require screws. * Update docs/source/so101.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jai Kumaar Ratadia --------- Signed-off-by: Jai Kumaar Ratadia Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/so101.mdx | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index 7c9df588a..1274b8282 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -236,10 +236,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 1 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Place the first motor into the base. - Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom. - Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side). -- Install both motor horns, securing the top horn with a M3x6mm screw. 
- Attach the shoulder part. - Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom - Add the shoulder motor holder. @@ -255,9 +255,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 2 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Slide the second motor in from the top. - Fasten the second motor with 4 M2x6mm screws. -- Attach both motor horns to motor 2, again use the M3x6mm horn screw. - Attach the upper arm with 4 M3x6mm screws on each side.
@@ -271,8 +271,8 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 3 -- Insert motor 3 and fasten using 4 M2x6mm screws -- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw. +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. +- Insert motor 3 and fasten using 4 M2x6mm screws. - Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
@@ -286,9 +286,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 4 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Slide over motor holder 4. - Slide in motor 4. -- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw. +- Fasten motor 4 with 4 M2x6mm screws.