From 7c2ec31793da193e03853699cb5040db6ce1caa5 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Sun, 15 Mar 2026 20:42:15 -0700
Subject: [PATCH] refactor(datasets): module cleanup (#3169)

---
 .../datasets/backward_compatibility.py     |   2 +-
 src/lerobot/datasets/lerobot_dataset.py    |   3 +-
 src/lerobot/datasets/online_buffer.py      | 382 ------
 .../datasets/push_dataset_to_hub/utils.py  |  73 ----
 .../augment_dataset_quantile_stats.py      |   2 +-
 .../convert_dataset_v21_to_v30.py          |   4 +-
 tests/datasets/test_dataset_utils.py       |  17 +-
 tests/datasets/test_online_buffer.py       | 282 ------
 tests/datasets/test_sampler.py             |  18 +-
 9 files changed, 38 insertions(+), 745 deletions(-)
 delete mode 100644 src/lerobot/datasets/online_buffer.py
 delete mode 100644 src/lerobot/datasets/push_dataset_to_hub/utils.py
 rename src/lerobot/{datasets/v30 => scripts}/augment_dataset_quantile_stats.py (99%)
 rename src/lerobot/{datasets/v30 => scripts}/convert_dataset_v21_to_v30.py (99%)
 delete mode 100644 tests/datasets/test_online_buffer.py

diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py
index ae95c5f7b..aefbfd55b 100644
--- a/src/lerobot/datasets/backward_compatibility.py
+++ b/src/lerobot/datasets/backward_compatibility.py
@@ -20,7 +20,7 @@ The dataset you requested ({repo_id}) is in {version} format.
 We introduced a new format since v3.0 which is not backward compatible with v2.1.
 Please, update your dataset to the new format using this command:
 ```
-python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
+python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id}
 ```
 
 If you already have a converted version uploaded to the hub, then this error might be because of
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 26f0c769c..11c10f493 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -596,7 +596,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         the dataset from that address and load it, provided your dataset is compliant with
         codebase_version v3.0. If your dataset has been created before this new format, you will be
         prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at
-        lerobot/datasets/v30/convert_dataset_v21_to_v30.py.
+        lerobot/scripts/convert_dataset_v21_to_v30.py.
 
         2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
@@ -1683,7 +1683,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if image_writer_processes or image_writer_threads:
             obj.start_image_writer(image_writer_processes, image_writer_threads)
 
-        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
        obj.episode_buffer = obj.create_episode_buffer()
        obj.episodes = None
diff --git a/src/lerobot/datasets/online_buffer.py b/src/lerobot/datasets/online_buffer.py
deleted file mode 100644
index 563d800b9..000000000
--- a/src/lerobot/datasets/online_buffer.py
+++ /dev/null
@@ -1,382 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""An online buffer for the online training loop in train.py
-
-Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
-consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
-faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
-supports in-place slicing and mutation which is very handy for a dynamic buffer.
-"""
-
-import os
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import torch
-
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
-
-def _make_memmap_safe(**kwargs) -> np.memmap:
-    """Make a numpy memmap with checks on available disk space first.
-
-    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape"
-
-    For information on dtypes:
-    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
-    """
-    if kwargs["mode"].startswith("w"):
-        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
-        stats = os.statvfs(Path(kwargs["filename"]).parent)
-        available_space = stats.f_bavail * stats.f_frsize  # bytes
-        if required_space >= available_space * 0.8:
-            raise RuntimeError(
-                f"You're about to take up {required_space} of {available_space} bytes available."
-            )
-    return np.memmap(**kwargs)
-
-
-class OnlineBuffer(torch.utils.data.Dataset):
-    """FIFO data buffer for the online training loop in train.py.
-
-    Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
-    loop in the same way that a LeRobotDataset would be used.
-
-    The underlying data structure will have data inserted in a circular fashion. Always insert after the
-    last index, and when you reach the end, wrap around to the start.
-
-    The data is stored in a numpy memmap.
-    """
-
-    NEXT_INDEX_KEY = "_next_index"
-    OCCUPANCY_MASK_KEY = "_occupancy_mask"
-    INDEX_KEY = "index"
-    FRAME_INDEX_KEY = "frame_index"
-    EPISODE_INDEX_KEY = "episode_index"
-    TIMESTAMP_KEY = "timestamp"
-    IS_PAD_POSTFIX = "_is_pad"
-
-    def __init__(
-        self,
-        write_dir: str | Path,
-        data_spec: dict[str, Any] | None,
-        buffer_capacity: int | None,
-        fps: float | None = None,
-        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
-    ):
-        """
-        The online buffer can be created from scratch, or you can load an existing online buffer by passing
-        a `write_dir` associated with an existing buffer.
-
-        Args:
-            write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
-                Note that if the files already exist, they are opened in read-write mode (used for training
-                resumption).
-            data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
-                "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
-                but note that "index", "frame_index" and "episode_index" are already accounted for by this
-                class, so you don't need to include them.
-            buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
-                system's available disk space when choosing this.
-            fps: Same as the fps concept in LeRobotDataset. Here it needs to be provided for the
-                delta_timestamps logic. You can pass None if you are not using delta_timestamps.
-            delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
-                converted to dict[str, np.ndarray] for optimization purposes.
-
-        """
-        self.set_delta_timestamps(delta_timestamps)
-        self._fps = fps
-        # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough to
-        # the requested frames. It is only used when `delta_timestamps` is provided.
-        # minus 1e-4 to account for possible numerical error
-        self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
-        self._buffer_capacity = buffer_capacity
-        data_spec = self._make_data_spec(data_spec, buffer_capacity)
-        Path(write_dir).mkdir(parents=True, exist_ok=True)
-        self._data = {}
-        for k, v in data_spec.items():
-            self._data[k] = _make_memmap_safe(
-                filename=Path(write_dir) / k,
-                dtype=v["dtype"] if v is not None else None,
-                mode="r+" if (Path(write_dir) / k).exists() else "w+",
-                shape=tuple(v["shape"]) if v is not None else None,
-            )
-
-    @property
-    def delta_timestamps(self) -> dict[str, np.ndarray] | None:
-        return self._delta_timestamps
-
-    def set_delta_timestamps(self, value: dict[str, list[float]] | None):
-        """Set delta_timestamps, converting the values to numpy arrays.
-
-        The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
-        need to be converted into numpy arrays.
-        """
-        if value is not None:
-            self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
-        else:
-            self._delta_timestamps = None
-
-    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
-        """Makes the data spec for np.memmap."""
-        if any(k.startswith("_") for k in data_spec):
-            raise ValueError(
-                "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
-            )
-        preset_keys = {
-            OnlineBuffer.INDEX_KEY,
-            OnlineBuffer.FRAME_INDEX_KEY,
-            OnlineBuffer.EPISODE_INDEX_KEY,
-            OnlineBuffer.TIMESTAMP_KEY,
-        }
-        if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
-            raise ValueError(
-                f"data_spec should not contain any of {preset_keys} as these are handled internally. "
-                f"The provided data_spec has {intersection}."
-            )
-        complete_data_spec = {
-            # _next_index will be a pointer to the next index that we should start filling from when we add
-            # more data.
-            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
-            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
-            # with real data rather than the dummy initialization.
- OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, - OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, - } - for k, v in data_spec.items(): - complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} - return complete_data_spec - - def add_data(self, data: dict[str, np.ndarray]): - """Add new data to the buffer, which could potentially mean shifting old data out. - - The new data should contain all the frames (in order) of any number of episodes. The indices should - start from 0 (note to the developer: this can easily be generalized). See the `rollout` and - `eval_policy` functions in `eval.py` for more information on how the data is constructed. - - Shift the incoming data index and episode_index to continue on from the last frame. Note that this - will be done in place! - """ - if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: - raise ValueError(f"Missing data keys: {missing_keys}") - new_data_length = len(data[self.data_keys[0]]) - if not all(len(data[k]) == new_data_length for k in self.data_keys): - raise ValueError("All data items should have the same length") - - next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] - - # Sanity check to make sure that the new data indices start from 0. - assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 - assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 - - # Shift the incoming indices if necessary. - if self.num_frames > 0: - last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] - last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] - data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 - data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 - - # Insert the new data starting from next_index. It may be necessary to wrap around to the start. 
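The comment above marks where the ring-buffer logic lives. As a minimal standalone sketch of the same wrap-around arithmetic (a toy capacity and a plain numpy array standing in for the memmap; the names mirror the deleted code, the values are assumed):

```python
import numpy as np

buffer_capacity = 5
buf = np.zeros(buffer_capacity, dtype=np.int64)
next_index = 3
new_data = np.array([10, 11, 12, 13])  # 4 frames, but only 2 free slots before the end

# How many frames spill past the end of the buffer?
n_surplus = max(0, len(new_data) - (buffer_capacity - next_index))  # -> 2
if n_surplus == 0:
    buf[next_index : next_index + len(new_data)] = new_data
    next_index += len(new_data)
else:
    buf[next_index:] = new_data[:-n_surplus]  # fill up to the end...
    buf[:n_surplus] = new_data[-n_surplus:]   # ...then wrap to the start
    next_index = n_surplus

print(buf)         # [12 13  0 10 11]
print(next_index)  # 2
```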
- n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) - for k in self.data_keys: - if n_surplus == 0: - slc = slice(next_index, next_index + new_data_length) - self._data[k][slc] = data[k] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True - else: - self._data[k][next_index:] = data[k][:-n_surplus] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True - self._data[k][:n_surplus] = data[k][-n_surplus:] - if n_surplus == 0: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length - else: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus - - @property - def data_keys(self) -> list[str]: - keys = set(self._data) - keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY) - keys.remove(OnlineBuffer.NEXT_INDEX_KEY) - return sorted(keys) - - @property - def fps(self) -> float | None: - return self._fps - - @property - def num_episodes(self) -> int: - return len( - np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) - ) - - @property - def num_frames(self) -> int: - return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]) - - def __len__(self): - return self.num_frames - - def _item_to_tensors(self, item: dict) -> dict: - item_ = {} - for k, v in item.items(): - if isinstance(v, torch.Tensor): - item_[k] = v - elif isinstance(v, np.ndarray): - item_[k] = torch.from_numpy(v) - else: - item_[k] = torch.tensor(v) - return item_ - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self) or idx < -len(self): - raise IndexError - - item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")} - - if self.delta_timestamps is None: - return self._item_to_tensors(item) - - episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY] - current_ts = item[OnlineBuffer.TIMESTAMP_KEY] - episode_data_indices = np.where( - np.bitwise_and( - self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index, - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], - ) - )[0] - episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] - - for data_key in self.delta_timestamps: - # Note: The logic in this loop is copied from `load_previous_and_future_frames`. - # Get timestamps used as query to retrieve data of previous/future frames. - query_ts = current_ts + self.delta_timestamps[data_key] - - # Compute distances between each query timestamp and all timestamps of all the frames belonging to - # the episode. - dist = np.abs(query_ts[:, None] - episode_timestamps[None, :]) - argmin_ = np.argmin(dist, axis=1) - min_ = dist[np.arange(dist.shape[0]), argmin_] - - is_pad = min_ > self.tolerance_s - - # Check violated query timestamps are all outside the episode range. - assert ( - (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) - ).all(), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" - ") inside the episode range." - ) - - # Load frames for this data key. 
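The argmin-based matching above is easier to follow with concrete numbers. A standalone sketch, assuming fps=10, a single 5-frame episode, and the 0.04 s tolerance that the tests for this module use (the numbers mirror `test_delta_timestamps_within_tolerance` further down):

```python
import numpy as np

fps = 10
episode_timestamps = np.arange(5) / fps           # [0.0, 0.1, 0.2, 0.3, 0.4]
current_ts = 0.2                                  # timestamp of the queried frame
query_ts = current_ts + np.array([-0.2, 0.0, 0.139])

# Distance from every query timestamp to every frame timestamp in the episode.
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
argmin_ = np.argmin(dist, axis=1)                 # nearest frames -> [0, 2, 3]
min_ = dist[np.arange(dist.shape[0]), argmin_]    # residuals -> [0.0, 0.0, 0.039]

tolerance_s = 0.04                                # the tests override the default tolerance
is_pad = min_ > tolerance_s                       # -> [False, False, False]
```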
-            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
-
-            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
-
-        return self._item_to_tensors(item)
-
-    def get_data_by_key(self, key: str) -> torch.Tensor:
-        """Returns all data for a given data key as a Tensor."""
-        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
-
-
-def compute_sampler_weights(
-    offline_dataset: LeRobotDataset,
-    offline_drop_n_last_frames: int = 0,
-    online_dataset: OnlineBuffer | None = None,
-    online_sampling_ratio: float | None = None,
-    online_drop_n_last_frames: int = 0,
-) -> torch.Tensor:
-    """Compute the sampling weights for the online training dataloader in train.py.
-
-    Args:
-        offline_dataset: The LeRobotDataset used for offline pre-training.
-        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
-        online_dataset: The OnlineBuffer used in online training.
-        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
-            online dataset is provided, this value must also be provided.
-        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
-            dataset.
-    Returns:
-        Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
-
-    Notes to maintainers:
-    - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
-    - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
-      `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
-      is the ability to turn shuffling off.
-    - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
-      included here to avoid adding complexity.
-    """
-    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
-        raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
-    if (online_dataset is None) ^ (online_sampling_ratio is None):
-        raise ValueError(
-            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
- ) - offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio - - weights = [] - - if len(offline_dataset) > 0: - offline_data_mask_indices = [] - for start_index, end_index in zip( - offline_dataset.meta.episodes["dataset_from_index"], - offline_dataset.meta.episodes["dataset_to_index"], - strict=True, - ): - offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames)) - offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) - offline_data_mask[torch.tensor(offline_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(offline_dataset),), - fill_value=offline_sampling_ratio / offline_data_mask.sum(), - ) - * offline_data_mask - ) - - if online_dataset is not None and len(online_dataset) > 0: - online_data_mask_indices = [] - episode_indices = online_dataset.get_data_by_key("episode_index") - for episode_idx in torch.unique(episode_indices): - where_episode = torch.where(episode_indices == episode_idx) - start_index = where_episode[0][0] - end_index = where_episode[0][-1] + 1 - online_data_mask_indices.extend( - range(start_index.item(), end_index.item() - online_drop_n_last_frames) - ) - online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool) - online_data_mask[torch.tensor(online_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(online_dataset),), - fill_value=online_sampling_ratio / online_data_mask.sum(), - ) - * online_data_mask - ) - - weights = torch.cat(weights) - - if weights.sum() == 0: - weights += 1 / len(weights) - else: - weights /= weights.sum() - - return weights diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py deleted file mode 100644 index 48214e1bf..000000000 --- a/src/lerobot/datasets/push_dataset_to_hub/utils.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datasets -import torch - - -# TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: - """ - Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. - - Parameters: - - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. - - Returns: - - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: - - "from": A tensor containing the starting index of each episode. - - "to": A tensor containing the ending index of each episode. - """ - episode_data_index = {"from": [], "to": []} - - current_episode = None - """ - The episode_index is a list of integers, each representing the episode index of the corresponding example. 
- For instance, the following is a valid episode_index: - [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] - - Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and - ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: - { - "from": [0, 3, 7], - "to": [3, 7, 12] - } - """ - if len(hf_dataset) == 0: - episode_data_index = { - "from": torch.tensor([]), - "to": torch.tensor([]), - } - return episode_data_index - for idx, episode_idx in enumerate(hf_dataset["episode_index"]): - if episode_idx != current_episode: - # We encountered a new episode, so we append its starting location to the "from" list - episode_data_index["from"].append(idx) - # If this is not the first episode, we append the ending location of the previous episode to the "to" list - if current_episode is not None: - episode_data_index["to"].append(idx) - # Let's keep track of the current episode index - current_episode = episode_idx - else: - # We are still in the same episode, so there is nothing for us to do here - pass - # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list - episode_data_index["to"].append(idx + 1) - - for k in ["from", "to"]: - episode_data_index[k] = torch.tensor(episode_data_index[k]) - - return episode_data_index diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py similarity index 99% rename from src/lerobot/datasets/v30/augment_dataset_quantile_stats.py rename to src/lerobot/scripts/augment_dataset_quantile_stats.py index 900a43a4f..e6ab6867e 100644 --- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py +++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py @@ -28,7 +28,7 @@ quantile statistics (q01, q10, q50, q90, q99) in their metadata. 
This script: Usage: ```bash -python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ +python src/lerobot/scripts/augment_dataset_quantile_stats.py \ --repo-id=lerobot/pusht \ ``` """ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py similarity index 99% rename from src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py rename to src/lerobot/scripts/convert_dataset_v21_to_v30.py index 81de05686..dc81cc51c 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py @@ -28,13 +28,13 @@ Usage: Convert a dataset from the hub: ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht ``` Convert a local dataset (works in place): ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht \ --root=/path/to/local/dataset/directory \ --push-to-hub=false diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index 99b832e55..d40ee238f 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -19,11 +19,26 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch from lerobot.utils.constants import ACTION, OBS_IMAGES +def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: + """Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}.""" + episode_data_index: dict[str, list[int]] = {"from": [], "to": []} + current_episode = None + if len(hf_dataset) == 0: + return {"from": torch.tensor([]), "to": torch.tensor([])} + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + episode_data_index["from"].append(idx) + if current_episode is not None: + episode_data_index["to"].append(idx) + current_episode = episode_idx + episode_data_index["to"].append(idx + 1) + return {k: torch.tensor(v) for k, v in episode_data_index.items()} + + def test_default_parameters(): card = create_lerobot_dataset_card() assert isinstance(card, DatasetCard) diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py deleted file mode 100644 index 887da6041..000000000 --- a/tests/datasets/test_online_buffer.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from copy import deepcopy
-from uuid import uuid4
-
-import numpy as np
-import pytest
-import torch
-
-from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
-
-# Some constants for OnlineBuffer tests.
-data_key = "data"
-data_shape = (2, 3)  # just some arbitrary > 1D shape
-buffer_capacity = 100
-fps = 10
-
-
-def make_new_buffer(
-    write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
-) -> tuple[OnlineBuffer, str]:
-    if write_dir is None:
-        write_dir = f"/tmp/online_buffer_{uuid4().hex}"
-    buffer = OnlineBuffer(
-        write_dir,
-        data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
-        buffer_capacity=buffer_capacity,
-        fps=fps,
-        delta_timestamps=delta_timestamps,
-    )
-    return buffer, write_dir
-
-
-def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
-    new_data = {
-        data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
-        OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
-        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
-        OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
-        OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
-    }
-    return new_data
-
-
-def test_non_mutate():
-    """Checks that the data provided to the add_data method is copied rather than passed by reference.
-
-    This means that mutating the data in the buffer does not mutate the original data.
-
-    NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
-    a success case for `test_write_read`.
-    """
-    buffer, _ = make_new_buffer()
-    new_data = make_spoof_data_frames(2, buffer_capacity // 4)
-    new_data_copy = deepcopy(new_data)
-    buffer.add_data(new_data)
-    buffer._data[data_key][:] += 1
-    assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)
-
-
-def test_index_error_no_data():
-    buffer, _ = make_new_buffer()
-    with pytest.raises(IndexError):
-        buffer[0]
-
-
-def test_index_error_with_data():
-    buffer, _ = make_new_buffer()
-    n_frames = buffer_capacity // 2
-    new_data = make_spoof_data_frames(1, n_frames)
-    buffer.add_data(new_data)
-    with pytest.raises(IndexError):
-        buffer[n_frames]
-    with pytest.raises(IndexError):
-        buffer[-n_frames - 1]
-
-
-@pytest.mark.parametrize("do_reload", [False, True])
-def test_write_read(do_reload: bool):
-    """Checks that data can be added to the buffer and read back.
-
-    If do_reload, we delete the buffer object and load the buffer back from disk before reading.
-    """
-    buffer, write_dir = make_new_buffer()
-    n_episodes = 2
-    n_frames_per_episode = buffer_capacity // 4
-    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
-    buffer.add_data(new_data)
-
-    if do_reload:
-        del buffer
-        buffer, _ = make_new_buffer(write_dir)
-
-    assert len(buffer) == n_frames_per_episode * n_episodes
-    for i, item in enumerate(buffer):
-        assert all(isinstance(item[k], torch.Tensor) for k in item)
-        assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])
-
-
-def test_read_data_key():
-    """Tests that data can be added to a buffer and all data for a
-    specific key can be read back."""
-    buffer, _ = make_new_buffer()
-    n_episodes = 2
-    n_frames_per_episode = buffer_capacity // 4
-    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
-    buffer.add_data(new_data)
-
-    data_from_buffer = buffer.get_data_by_key(data_key)
-    assert isinstance(data_from_buffer, torch.Tensor)
-    assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])
-
-
-def test_fifo():
-    """Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
-    buffer, _ = make_new_buffer()
-    n_frames_per_episode = buffer_capacity // 4
-    n_episodes = 3
-    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
-    buffer.add_data(new_data)
-    n_more_episodes = 2
-    # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
-        "Something went wrong with the test code."
-    )
-    more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
-    buffer.add_data(more_new_data)
-    assert len(buffer) == buffer_capacity, "The buffer should be full."
-
-    expected_data = {}
-    for k in new_data:
-        # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
-        expected_data[k] = np.roll(
-            np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
-            shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
-            axis=0,
-        )
-
-    for i, item in enumerate(buffer):
-        assert all(isinstance(item[k], torch.Tensor) for k in item)
-        assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])
-
-
-def test_delta_timestamps_within_tolerance():
-    """Check that getting an item with delta_timestamps within tolerance succeeds.
-
-    Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
-    """
-    # Sanity check on global fps as we are assuming it is 10 here.
-    assert fps == 10, "This test assumes fps==10"
-    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
-    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
-    buffer.add_data(new_data)
-    buffer.tolerance_s = 0.04
-    item = buffer[2]
-    data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
-    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
-    assert not is_pad.any(), "Unexpected padding detected"
-
-
-def test_delta_timestamps_outside_tolerance_inside_episode_range():
-    """Check that getting an item with delta_timestamps outside of tolerance fails.
-
-    We expect it to fail if and only if the requested timestamps are within the episode range.
-
-    Note: Copied from
-    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
-    """
-    # Sanity check on global fps as we are assuming it is 10 here.
-    assert fps == 10, "This test assumes fps==10"
-    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
-    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
-    buffer.add_data(new_data)
-    buffer.tolerance_s = 0.04
-    with pytest.raises(AssertionError):
-        buffer[2]
-
-
-def test_delta_timestamps_outside_tolerance_outside_episode_range():
-    """Check that copy-padding of timestamps outside of the episode range works.
-
-    Note: Copied from
-    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
-    """
-    # Sanity check on global fps as we are assuming it is 10 here.
- assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( - "Padding does not match expected values" - ) - - -# Arbitrarily set small dataset sizes, making sure to have uneven sizes. -@pytest.mark.parametrize("offline_dataset_size", [1, 6]) -@pytest.mark.parametrize("online_dataset_size", [0, 4]) -@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) -def test_compute_sampler_weights_trivial( - lerobot_dataset_factory, - tmp_path, - offline_dataset_size: int, - online_dataset_size: int, - online_sampling_ratio: float, -): - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) - online_dataset, _ = make_new_buffer() - if online_dataset_size > 0: - online_dataset.add_data( - make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) - ) - - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - if offline_dataset_size == 0 or online_dataset_size == 0: - expected_weights = torch.ones(offline_dataset_size + online_dataset_size) - elif online_sampling_ratio == 0: - expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) - elif online_sampling_ratio == 1: - expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) - expected_weights /= expected_weights.sum() - torch.testing.assert_close(weights, expected_weights) - - -def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - online_sampling_ratio = 0.8 - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) - ) - - -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. 
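The expected tensor in the nontrivial-ratio test above follows directly from the weighting rule in `compute_sampler_weights`: the offline share 1 - 0.8 = 0.2 is spread uniformly over the 4 offline frames (0.05 each), and the online share 0.8 over the 8 online frames (0.1 each). A short verification sketch with those assumed sizes:

```python
offline_n, online_n, online_ratio = 4, 8, 0.8

offline_weight = (1 - online_ratio) / offline_n  # 0.05 per offline frame
online_weight = online_ratio / online_n          # 0.10 per online frame

weights = [offline_weight] * offline_n + [online_weight] * online_n
assert abs(sum(weights) - 1.0) < 1e-9            # the weights normalize to 1
```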
- offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) - ) - - -def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): - """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - - weights = compute_sampler_weights( - offline_dataset, - offline_drop_n_last_frames=1, - online_dataset=online_dataset, - online_sampling_ratio=0.5, - online_drop_n_last_frames=1, - ) - torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index fd7a6e380..e5b35e426 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -13,15 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch from datasets import Dataset -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import ( hf_transform_to_torch, ) +def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: + """Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}.""" + episode_data_index: dict[str, list[int]] = {"from": [], "to": []} + current_episode = None + if len(hf_dataset) == 0: + return {"from": torch.tensor([]), "to": torch.tensor([])} + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + episode_data_index["from"].append(idx) + if current_episode is not None: + episode_data_index["to"].append(idx) + current_episode = episode_idx + episode_data_index["to"].append(idx + 1) + return {k: torch.tensor(v) for k, v in episode_data_index.items()} + + def test_drop_n_first_frames(): dataset = Dataset.from_dict( {