From d691d1e4feb62c3dd0ddf41397b219d8dee79390 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 22 Sep 2025 17:57:32 +0200 Subject: [PATCH] Add Quantile stats to LeRobotDataset (#1985) * - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, compute quantiles by default for new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each feature instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILES normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. 
* style fixes * Added missing license * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that don't have quantiles - modified quantile computation: instead of using the bin edge as the value, interpolate within the bin --- src/lerobot/configs/types.py | 2 + src/lerobot/datasets/compute_stats.py | 518 +++++++++++++++-- .../v30/augment_dataset_quantile_stats.py | 205 +++++++ src/lerobot/processor/normalize_processor.py | 39 +- tests/datasets/test_compute_stats.py | 524 ++++++++++++++++++ .../test_quantiles_dataset_integration.py | 212 +++++++ tests/processor/test_normalize_processor.py | 223 ++++++++ 7 files changed, 1689 insertions(+), 34 deletions(-) create mode 100644 src/lerobot/datasets/v30/augment_dataset_quantile_stats.py create mode 100644 tests/datasets/test_quantiles_dataset_integration.py diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index e02527840..ea8aa039d 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -36,6 +36,8 @@ class NormalizationMode(str, Enum): MIN_MAX = "MIN_MAX" MEAN_STD = "MEAN_STD" IDENTITY = "IDENTITY" + QUANTILES = "QUANTILES" + QUANTILE10 = "QUANTILE10" class DictLike(Protocol): diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index bfe7b18b4..9fee84487 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -17,6 +17,171 @@ import numpy as np from lerobot.datasets.utils import load_image_as_numpy +DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] + + +class RunningQuantileStats: + """Compute running statistics including quantiles for a batch of vectors.""" + + def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000): + self._count = 0 + self._mean = None + self._mean_of_squares = None + self._min = None + self._max = None + self._histograms = None + self._bin_edges = None + 
self._num_quantile_bins = num_quantile_bins + + self._quantile_list = quantile_list + if self._quantile_list is None: + self._quantile_list = DEFAULT_QUANTILES + self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list] + + def update(self, batch: np.ndarray) -> None: + """Update the running statistics with a batch of vectors. + + Args: + batch: An array where all dimensions except the last are batch dimensions. + """ + batch = batch.reshape(-1, batch.shape[-1]) + num_elements, vector_length = batch.shape + + if self._count == 0: + self._mean = np.mean(batch, axis=0) + self._mean_of_squares = np.mean(batch**2, axis=0) + self._min = np.min(batch, axis=0) + self._max = np.max(batch, axis=0) + self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] + self._bin_edges = [ + np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) + for i in range(vector_length) + ] + else: + if vector_length != self._mean.size: + raise ValueError("The length of new vectors does not match the initialized vector length.") + + new_max = np.max(batch, axis=0) + new_min = np.min(batch, axis=0) + max_changed = np.any(new_max > self._max) + min_changed = np.any(new_min < self._min) + self._max = np.maximum(self._max, new_max) + self._min = np.minimum(self._min, new_min) + + if max_changed or min_changed: + self._adjust_histograms() + + self._count += num_elements + + batch_mean = np.mean(batch, axis=0) + batch_mean_of_squares = np.mean(batch**2, axis=0) + + # Update running mean and mean of squares + self._mean += (batch_mean - self._mean) * (num_elements / self._count) + self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * ( + num_elements / self._count + ) + + self._update_histograms(batch) + + def get_statistics(self) -> dict[str, np.ndarray]: + """Compute and return the statistics of the vectors processed so far. 
+ + Args: + quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed. + + Returns: + Dictionary containing the computed statistics. + """ + if self._count < 2: + raise ValueError("Cannot compute statistics for less than 2 vectors.") + + variance = self._mean_of_squares - self._mean**2 + stddev = np.sqrt(np.maximum(0, variance)) + + stats = { + "min": self._min.copy(), + "max": self._max.copy(), + "mean": self._mean.copy(), + "std": stddev, + "count": np.array([self._count]), + } + + quantile_results = self._compute_quantiles() + for i, q in enumerate(self._quantile_keys): + stats[q] = quantile_results[i] + + return stats + + def _adjust_histograms(self): + """Adjust histograms when min or max changes.""" + for i in range(len(self._histograms)): + old_edges = self._bin_edges[i] + old_hist = self._histograms[i] + + # Create new edges with small padding to ensure range coverage + padding = (self._max[i] - self._min[i]) * 1e-10 + new_edges = np.linspace( + self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1 + ) + + # Redistribute existing histogram counts to new bins + # We need to map each old bin center to the new bins + old_centers = (old_edges[:-1] + old_edges[1:]) / 2 + new_hist = np.zeros(self._num_quantile_bins) + + for old_center, count in zip(old_centers, old_hist, strict=False): + if count > 0: + # Find which new bin this old center belongs to + bin_idx = np.searchsorted(new_edges, old_center) - 1 + bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1)) + new_hist[bin_idx] += count + + self._histograms[i] = new_hist + self._bin_edges[i] = new_edges + + def _update_histograms(self, batch: np.ndarray) -> None: + """Update histograms with new vectors.""" + for i in range(batch.shape[1]): + hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) + self._histograms[i] += hist + + def _compute_quantiles(self) -> list[np.ndarray]: + """Compute quantiles based on histograms.""" + 
results = [] + for q in self._quantile_list: + target_count = q * self._count + q_values = [] + + for hist, edges in zip(self._histograms, self._bin_edges, strict=True): + q_value = self._compute_single_quantile(hist, edges, target_count) + q_values.append(q_value) + + results.append(np.array(q_values)) + return results + + def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float: + """Compute a single quantile value from histogram and bin edges.""" + cumsum = np.cumsum(hist) + idx = np.searchsorted(cumsum, target_count) + + if idx == 0: + return edges[0] + if idx >= len(cumsum): + return edges[-1] + + # If not edge case, interpolate within the bin + count_before = cumsum[idx - 1] + count_in_bin = cumsum[idx] - count_before + + # If no samples in this bin, use the bin edge + if count_in_bin == 0: + return edges[idx] + + # Linear interpolation within the bin + fraction = (target_count - count_before) / count_in_bin + return edges[idx] + fraction * (edges[idx + 1] - edges[idx]) + def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -72,33 +237,296 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: - return { - "min": np.min(array, axis=axis, keepdims=keepdims), - "max": np.max(array, axis=axis, keepdims=keepdims), - "mean": np.mean(array, axis=axis, keepdims=keepdims), - "std": np.std(array, axis=axis, keepdims=keepdims), - "count": np.array([len(array)]), +def _reshape_stats_by_axis( + stats: dict[str, np.ndarray], + axis: int | tuple[int, ...] | None, + keepdims: bool, + original_shape: tuple[int, ...], +) -> dict[str, np.ndarray]: + """Reshape all statistics to match NumPy's output conventions. + + Applies consistent reshaping to all statistics (except 'count') based on the + axis and keepdims parameters. 
This ensures statistics have the correct shape + for broadcasting with the original data. + + Args: + stats: Dictionary of computed statistics + axis: Axis or axes along which statistics were computed + keepdims: Whether to keep reduced dimensions as size-1 dimensions + original_shape: Shape of the original array + + Returns: + Dictionary with reshaped statistics + + Note: + The 'count' statistic is never reshaped as it represents metadata + rather than per-feature statistics. + """ + if axis == (1,) and not keepdims: + return stats + + result = {} + for key, value in stats.items(): + if key == "count": + result[key] = value + else: + result[key] = _reshape_single_stat(value, axis, keepdims, original_shape) + + return result + + +def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: + """Reshape statistics for image data (axis=(0,2,3)).""" + if keepdims and value.ndim == 1: + return value.reshape(1, -1, 1, 1) + return value + + +def _reshape_for_vector_stats( + value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] +) -> np.ndarray: + """Reshape statistics for vector data (axis=0 or axis=(0,)).""" + if not keepdims: + return value + + if len(original_shape) == 1 and value.ndim > 0: + return value.reshape(1) + elif len(original_shape) >= 2 and value.ndim == 1: + return value.reshape(1, -1) + return value + + +def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: + """Reshape statistics for feature-wise computation (axis=(1,)).""" + if not keepdims: + return value + + if value.ndim == 0: + return value.reshape(1, 1) + elif value.ndim == 1: + return value.reshape(-1, 1) + return value + + +def _reshape_for_global_stats( + value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] 
+) -> np.ndarray | float: + """Reshape statistics for global reduction (axis=None).""" + if keepdims: + target_shape = tuple(1 for _ in original_shape) + return value.reshape(target_shape) + elif not keepdims and value.ndim > 0 and value.size == 1: + return value.item() + return value + + +def _reshape_single_stat( + value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...] +) -> np.ndarray | float: + """Apply appropriate reshaping to a single statistic array. + + This function transforms statistic arrays to match expected output shapes + based on the axis configuration and keepdims parameter. + + Args: + value: The statistic array to reshape + axis: Axis or axes that were reduced during computation + keepdims: Whether to maintain reduced dimensions as size-1 dimensions + original_shape: Shape of the original data before reduction + + Returns: + Reshaped array following NumPy broadcasting conventions + + """ + if axis == (0, 2, 3): + return _reshape_for_image_stats(value, keepdims) + + if axis in [0, (0,)]: + return _reshape_for_vector_stats(value, keepdims, original_shape) + + if axis == (1,): + return _reshape_for_feature_stats(value, keepdims) + + if axis is None: + return _reshape_for_global_stats(value, keepdims, original_shape) + + return value + + +def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]: + """Prepare array for statistics computation by reshaping according to axis. 
+ + Args: + array: Input data array + axis: Axis or axes along which to compute statistics + + Returns: + Tuple of (reshaped_array, sample_count) + """ + if axis == (0, 2, 3): # Image data + batch_size, channels, height, width = array.shape + reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels) + return reshaped, batch_size + + if axis == 0 or axis == (0,): # Vector data + if array.ndim == 1: + reshaped = array.reshape(-1, 1) + else: + reshaped = array + return reshaped, array.shape[0] + + if axis == (1,): # Feature-wise statistics + return array.T, array.shape[1] + + if axis is None: # Global statistics + reshaped = array.reshape(-1, 1) + # For backward compatibility, count represents the first dimension size + return reshaped, array.shape[0] if array.ndim > 0 else 1 + + raise ValueError(f"Unsupported axis configuration: {axis}") + + +def _compute_basic_stats( + array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None +) -> dict[str, np.ndarray]: + """Compute basic statistics for arrays with insufficient samples for quantiles. + + Args: + array: Reshaped array ready for statistics computation + sample_count: Number of samples represented in the data + + Returns: + Dictionary with basic statistics and quantiles set to mean values + """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES + quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list] + + stats = { + "min": np.min(array, axis=0), + "max": np.max(array, axis=0), + "mean": np.mean(array, axis=0), + "std": np.std(array, axis=0), + "count": np.array([sample_count]), } + # For single-element arrays with shape (1,1), convert to scalar arrays + if array.shape == (1, 1): + for key in stats: + if key != "count" and stats[key].size == 1: + stats[key] = np.array(stats[key].item()) + + for q in quantile_list_keys: + stats[q] = stats["mean"].copy() + + return stats + + +def get_feature_stats( + array: np.ndarray, + axis: int | tuple[int, ...] 
| None, + keepdims: bool, + quantile_list: list[float] | None = None, +) -> dict[str, np.ndarray]: + """Compute comprehensive statistics for array features along specified axes. + + This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%) + for the input array along the specified axes. It handles different data layouts: + - Image data: axis=(0,2,3) computes per-channel statistics + - Vector data: axis=0 computes per-feature statistics + - Feature-wise: axis=1 computes statistics across features + - Global: axis=None computes statistics over entire array + + Args: + array: Input data array with shape appropriate for the specified axis + axis: Axis or axes along which to compute statistics + - (0, 2, 3): For image data (batch, channels, height, width) + - 0 or (0,): For vector/tabular data (samples, features) + - (1,): For computing across features + - None: For global statistics over entire array + keepdims: If True, reduced axes are kept as dimensions with size 1 + + Returns: + Dictionary containing: + - 'min': Minimum values + - 'max': Maximum values + - 'mean': Mean values + - 'std': Standard deviation + - 'count': Number of samples (always shape (1,)) + - 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values + + """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES + + original_shape = array.shape + reshaped, sample_count = _prepare_array_for_stats(array, axis) + + if reshaped.shape[0] < 2: + stats = _compute_basic_stats(reshaped, sample_count, quantile_list) + else: + running_stats = RunningQuantileStats() + running_stats.update(reshaped) + stats = running_stats.get_statistics() + stats["count"] = np.array([sample_count]) + + # For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays + if axis is None and reshaped.shape[1] == 1: + for key in stats: + if key != "count" and stats[key].size == 1: + stats[key] = np.array(stats[key].item()) + + stats = _reshape_stats_by_axis(stats, axis, keepdims, 
original_shape) + return stats + + +def compute_episode_stats( + episode_data: dict[str, list[str] | np.ndarray], + features: dict, + quantile_list: list[float] | None = None, +) -> dict: + """Compute comprehensive statistics for all features in an episode. + + Processes different data types appropriately: + - Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1] + - Numerical arrays: Computes per-feature statistics + - Strings: Skipped (no statistics computed) + + Args: + episode_data: Dictionary mapping feature names to data + - For images/videos: list of file paths + - For numerical data: numpy arrays + features: Dictionary describing each feature's dtype and shape + + Returns: + Dictionary mapping feature names to their statistics dictionaries. + Each statistics dictionary contains min, max, mean, std, count, and quantiles. + + Note: + Image statistics are normalized to [0,1] range and have shape (3,1,1) for + per-channel values when dtype is 'image' or 'video'. 
+ """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES -def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": - continue # HACK: we should receive np.arrays of strings - elif features[key]["dtype"] in ["image", "video"]: - ep_ft_array = sample_images(data) # data is a list of image paths - axes_to_reduce = (0, 2, 3) # keep channel dim + continue + + if features[key]["dtype"] in ["image", "video"]: + ep_ft_array = sample_images(data) + axes_to_reduce = (0, 2, 3) keepdims = True else: - ep_ft_array = data # data is already a np.ndarray - axes_to_reduce = 0 # compute stats over the first axis - keepdims = data.ndim == 1 # keep as np.array + ep_ft_array = data + axes_to_reduce = 0 + keepdims = data.ndim == 1 - ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + ep_stats[key] = get_feature_stats( + ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list + ) - # finally, we normalize and remove batch dim for images if features[key]["dtype"] in ["image", "video"]: ep_stats[key] = { k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() @@ -107,20 +535,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu return ep_stats +def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None: + """Validate a single statistic value.""" + if not isinstance(value, np.ndarray): + raise ValueError( + f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' " + f"is of type '{type(value)}' instead." 
+ ) + + if value.ndim == 0: + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + + if key == "count" and value.shape != (1,): + raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.") + + if "image" in feature_key and key != "count" and value.shape != (3, 1, 1): + raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.") + + def _assert_type_and_shape(stats_list: list[dict[str, dict]]): - for i in range(len(stats_list)): - for fkey in stats_list[i]: - for k, v in stats_list[i][fkey].items(): - if not isinstance(v, np.ndarray): - raise ValueError( - f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." - ) - if v.ndim == 0: - raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") - if k == "count" and v.shape != (1,): - raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") - if "image" in fkey and k != "count" and v.shape != (3, 1, 1): - raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + """Validate that all statistics have correct types and shapes. 
+ + Args: + stats_list: List of statistics dictionaries to validate + + Raises: + ValueError: If any statistic has incorrect type or shape + """ + for stats in stats_list: + for feature_key, feature_stats in stats.items(): + for stat_key, stat_value in feature_stats.items(): + _validate_stat_value(stat_value, stat_key, feature_key) def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: @@ -143,7 +588,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d weighted_variances = (variances + delta_means**2) * counts total_variance = weighted_variances.sum(axis=0) / total_count - return { + aggregated = { "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), "mean": total_mean, @@ -151,6 +596,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d "count": total_count, } + if stats_ft_list: + quantile_keys = [k for k in stats_ft_list[0].keys() if k.startswith("q") and k[1:].isdigit()] + + for q_key in quantile_keys: + if all(q_key in s for s in stats_ft_list): + quantile_values = np.stack([s[q_key] for s in stats_ft_list]) + weighted_quantiles = quantile_values * counts + aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count + + return aggregated + def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: """Aggregate stats from multiple compute_stats outputs into a single set of stats. diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py new file mode 100644 index 000000000..a5247a728 --- /dev/null +++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script augments existing LeRobot datasets with quantile statistics. + +Most datasets created before the quantile feature was added do not contain +quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script: + +1. Loads an existing LeRobot dataset in v3.0 format +2. Checks if it already contains quantile statistics +3. If missing, computes quantile statistics for all features +4. Updates the dataset metadata with the new quantile statistics + +Usage: + +```bash +python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ + --repo-id=lerobot/pusht \ +``` +""" + +import argparse +import logging +from pathlib import Path + +import numpy as np + +from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import write_stats +from lerobot.utils.utils import init_logging + + +def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool: + """Check if dataset statistics already contain quantile information. 
+ + Args: + stats: Dataset statistics dictionary + + Returns: + True if quantile statistics are present, False otherwise + """ + if quantile_list_keys is None: + quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES] + + if stats is None: + return False + + for feature_stats in stats.values(): + if any(q_key in feature_stats for q_key in quantile_list_keys): + return True + + return False + + +def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict: + """Load episode data by accessing the underlying HuggingFace dataset. + + Args: + dataset: The LeRobot dataset + episode_idx: Index of the episode to load + + Returns: + Dictionary containing episode data for each feature + """ + + episode_info = dataset.meta.episodes[episode_idx] + episode_length = episode_info["length"] + + start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx)) + end_idx = start_idx + episode_length + + episode_data = {} + + episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx)) + + for key, feature_info in dataset.features.items(): + if feature_info["dtype"] == "string": + continue + + if feature_info["dtype"] in ["image", "video"]: + image_paths = [] + for row in episode_slice: + if key in row: + relative_path = row[key] + if isinstance(relative_path, str): + absolute_path = str(dataset.meta.root / relative_path) + image_paths.append(absolute_path) + + if image_paths: + episode_data[key] = image_paths + else: + arrays = [] + for row in episode_slice: + if key in row: + arrays.append(np.array(row[key])) + + if arrays: + episode_data[key] = np.stack(arrays) + + return episode_data + + +def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]: + """Compute quantile statistics for all episodes in the dataset. 
+ + Args: + dataset: The LeRobot dataset to compute statistics for + + Returns: + Dictionary containing aggregated statistics with quantiles + """ + logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes") + + episode_stats_list = [] + + for episode_idx in range(dataset.num_episodes): + episode_data = load_episode_data(dataset, episode_idx) + ep_stats = compute_episode_stats(episode_data, dataset.features) + episode_stats_list.append(ep_stats) + + if not episode_stats_list: + raise ValueError("No episode data found for computing statistics") + + logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes") + return aggregate_stats(episode_stats_list) + + +def augment_dataset_with_quantile_stats( + repo_id: str, + root: str | Path | None = None, +) -> None: + """Augment a dataset with quantile statistics if they are missing. + + Args: + repo_id: Repository ID of the dataset + root: Local root directory for the dataset + """ + logging.info(f"Loading dataset: {repo_id}") + dataset = LeRobotDataset( + repo_id=repo_id, + root=root, + ) + + if has_quantile_stats(dataset.meta.stats): + logging.info("Dataset already contains quantile statistics. No action needed.") + return + + logging.info("Dataset does not contain quantile statistics. 
Computing them now...") + + new_stats = compute_quantile_stats_for_dataset(dataset) + + logging.info("Updating dataset metadata with new quantile statistics") + dataset.meta.stats = new_stats + + write_stats(new_stats, dataset.meta.root) + + logging.info("Successfully updated dataset with quantile statistics") + dataset.push_to_hub() + + +def main(): + """Main function to run the augmentation script.""" + parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics") + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository ID of the dataset (e.g., 'lerobot/pusht')", + ) + + parser.add_argument( + "--root", + type=str, + help="Local root directory for the dataset", + ) + + args = parser.parse_args() + root = Path(args.root) if args.root else None + + init_logging() + + augment_dataset_with_quantile_stats( + repo_id=args.repo_id, + root=root, + ) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index bece54f0b..18a71a431 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -281,8 +281,14 @@ class _NormalizationMixin: """ Core logic to apply a normalization or unnormalization transformation to a tensor. - This method selects the appropriate normalization mode (e.g., mean/std, min/max) - based on the feature type and applies the corresponding mathematical operation. + This method selects the appropriate normalization mode based on the feature type + and applies the corresponding mathematical operation. + + Normalization Modes: + - MEAN_STD: Centers data around zero with unit variance. + - MIN_MAX: Scales data to [-1, 1] range using actual min/max values. + - QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99). + - QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90). 
Args: tensor: The input tensor to transform. @@ -300,7 +306,12 @@ class _NormalizationMixin: if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: return tensor - if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + if norm_mode not in ( + NormalizationMode.MEAN_STD, + NormalizationMode.MIN_MAX, + NormalizationMode.QUANTILES, + NormalizationMode.QUANTILE10, + ): raise ValueError(f"Unsupported normalization mode: {norm_mode}") # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor @@ -334,6 +345,28 @@ class _NormalizationMixin: # Map from [min, max] to [-1, 1] return 2 * (tensor - min_val) / denom - 1 + if norm_mode == NormalizationMode.QUANTILES and "q01" in stats and "q99" in stats: + q01, q99 = stats["q01"], stats["q99"] + denom = q99 - q01 + # Avoid division by zero by adding epsilon when quantiles are identical + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + return tensor * denom + q01 + return (tensor - q01) / denom + + if norm_mode == NormalizationMode.QUANTILE10 and "q10" in stats and "q90" in stats: + q10, q90 = stats["q10"], stats["q90"] + denom = q90 - q10 + # Avoid division by zero by adding epsilon when quantiles are identical + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + return tensor * denom + q10 + return (tensor - q10) / denom + # If necessary stats are missing, return input unchanged. 
return tensor diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 8f8179c29..915855158 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -19,6 +19,7 @@ import numpy as np import pytest from lerobot.datasets.compute_stats import ( + RunningQuantileStats, _assert_type_and_shape, aggregate_feature_stats, aggregate_stats, @@ -101,6 +102,9 @@ def test_get_feature_stats_axis_1(sample_array): "count": np.array([3]), } result = get_feature_stats(sample_array, axis=(1,), keepdims=False) + + # Check that basic stats are correct (quantiles are also included now) + assert set(expected.keys()).issubset(set(result.keys())) for key in expected: np.testing.assert_allclose(result[key], expected[key]) @@ -114,6 +118,9 @@ def test_get_feature_stats_no_axis(sample_array): "count": np.array([3]), } result = get_feature_stats(sample_array, axis=None, keepdims=False) + + # Check that basic stats are correct (quantiles are also included now) + assert set(expected.keys()).issubset(set(result.keys())) for key in expected: np.testing.assert_allclose(result[key], expected[key]) @@ -307,3 +314,520 @@ def test_aggregate_stats(): results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 ) np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) + + +def test_running_quantile_stats_initialization(): + """Test proper initialization of RunningQuantileStats.""" + running_stats = RunningQuantileStats() + assert running_stats._count == 0 + assert running_stats._mean is None + assert running_stats._num_quantile_bins == 5000 + + # Test custom bin size + running_stats_custom = RunningQuantileStats(num_quantile_bins=1000) + assert running_stats_custom._num_quantile_bins == 1000 + + +def test_running_quantile_stats_single_batch_update(): + """Test updating with a single batch.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 3)) + + running_stats = 
RunningQuantileStats() + running_stats.update(data) + + assert running_stats._count == 100 + assert running_stats._mean.shape == (3,) + assert len(running_stats._histograms) == 3 + assert len(running_stats._bin_edges) == 3 + + # Verify basic statistics are reasonable + np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10) + + +def test_running_quantile_stats_multiple_batch_updates(): + """Test updating with multiple batches.""" + np.random.seed(42) + data1 = np.random.normal(0, 1, (100, 2)) + data2 = np.random.normal(1, 1, (150, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data1) + running_stats.update(data2) + + assert running_stats._count == 250 + + # Verify running mean is correct + combined_data = np.vstack([data1, data2]) + expected_mean = np.mean(combined_data, axis=0) + np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10) + + +def test_running_quantile_stats_get_statistics_basic(): + """Test getting basic statistics without quantiles.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # Should have basic stats + expected_keys = {"min", "max", "mean", "std", "count"} + assert expected_keys.issubset(set(stats.keys())) + + # Verify values + np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10) + np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6) + np.testing.assert_equal(stats["count"], np.array([100])) + + +def test_running_quantile_stats_get_statistics_with_quantiles(): + """Test getting statistics with quantiles.""" + np.random.seed(42) + data = np.random.normal(0, 1, (1000, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # Should have basic stats plus quantiles + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", 
"q50", "q90", "q99"} + assert expected_keys.issubset(set(stats.keys())) + + # Verify quantile values are reasonable + from lerobot.datasets.compute_stats import DEFAULT_QUANTILES + + for i, q in enumerate(DEFAULT_QUANTILES): + q_key = f"q{int(q * 100):02d}" + assert q_key in stats + assert stats[q_key].shape == (2,) + + # Check that quantiles are in reasonable order + if i > 0: + prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}" + assert np.all(stats[prev_q_key] <= stats[q_key]) + + +def test_running_quantile_stats_histogram_adjustment(): + """Test that histograms adjust when min/max change.""" + running_stats = RunningQuantileStats() + + # Initial data with small range + data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]]) + running_stats.update(data1) + + initial_edges_0 = running_stats._bin_edges[0].copy() + initial_edges_1 = running_stats._bin_edges[1].copy() + + # Add data with much larger range + data2 = np.array([[10.0, -10.0], [11.0, -11.0]]) + running_stats.update(data2) + + # Bin edges should have changed + assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0]) + assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1]) + + # New edges should cover the expanded range + # First dimension: min should still be ~0.0, max should be ~11.0 + assert running_stats._bin_edges[0][0] <= 0.0 + assert running_stats._bin_edges[0][-1] >= 11.0 + + # Second dimension: min should be ~-11.0, max should be ~1.2 + assert running_stats._bin_edges[1][0] <= -11.0 + assert running_stats._bin_edges[1][-1] >= 1.2 + + +def test_running_quantile_stats_insufficient_data_error(): + """Test error when trying to get stats with insufficient data.""" + running_stats = RunningQuantileStats() + + with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"): + running_stats.get_statistics() + + # Single vector should also fail + running_stats.update(np.array([[1.0]])) + with pytest.raises(ValueError, match="Cannot compute 
statistics for less than 2 vectors"): + running_stats.get_statistics() + + +def test_running_quantile_stats_vector_length_consistency(): + """Test error when vector lengths don't match.""" + running_stats = RunningQuantileStats() + running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]])) + + with pytest.raises(ValueError, match="The length of new vectors does not match"): + running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length + + +def test_running_quantile_stats_reshape_handling(): + """Test that various input shapes are handled correctly.""" + running_stats = RunningQuantileStats() + + # Test 3D input (e.g., images) + data_3d = np.random.normal(0, 1, (10, 32, 32)) + running_stats.update(data_3d) + + assert running_stats._count == 10 * 32 + assert running_stats._mean.shape == (32,) + + # Test single-feature input given as an (N, 1) column + running_stats_1d = RunningQuantileStats() + data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) + running_stats_1d.update(data_1d) + + assert running_stats_1d._count == 5 + assert running_stats_1d._mean.shape == (1,) + + +def test_get_feature_stats_quantiles_enabled_by_default(): + """Test that quantiles are computed by default.""" + data = np.random.normal(0, 1, (100, 5)) + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + +def test_get_feature_stats_quantiles_with_vector_data(): + """Test quantile computation with vector data.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 5)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + # Verify shapes + assert stats["q01"].shape == (5,) + assert stats["q99"].shape == (5,) + + # Verify quantiles are reasonable + assert np.all(stats["q01"] < stats["q99"]) + + +def
test_get_feature_stats_quantiles_with_image_data(): + """Test quantile computation with image data.""" + np.random.seed(42) + data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width + + stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + # Verify shapes for images (should be (1, channels, 1, 1)) + assert stats["q01"].shape == (1, 3, 1, 1) + assert stats["q50"].shape == (1, 3, 1, 1) + assert stats["q99"].shape == (1, 3, 1, 1) + + +def test_get_feature_stats_fixed_quantiles(): + """Test that fixed quantiles are always computed.""" + data = np.random.normal(0, 1, (200, 3)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"} + assert expected_quantile_keys.issubset(set(stats.keys())) + + +def test_get_feature_stats_unsupported_axis_error(): + """Test error for unsupported axis configuration.""" + data = np.random.normal(0, 1, (10, 5)) + + with pytest.raises(ValueError, match="Unsupported axis configuration"): + get_feature_stats( + data, + axis=(1, 2), # Unsupported axis + keepdims=False, + ) + + +def test_compute_episode_stats_backward_compatibility(): + """Test that existing functionality is preserved.""" + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(0, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + stats = compute_episode_stats(episode_data, features) + + for key in ["action", "observation.state"]: + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats[key].keys()) == expected_keys + + +def test_compute_episode_stats_with_custom_quantiles(): + """Test that the default quantiles are computed for vector features (no custom quantile values are passed).""" +
np.random.seed(42) + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(2, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + stats = compute_episode_stats(episode_data, features) + + # Should have quantiles + for key in ["action", "observation.state"]: + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats[key].keys()) == expected_keys + + # Verify shapes + assert stats[key]["q01"].shape == (features[key]["shape"][0],) + assert stats[key]["q99"].shape == (features[key]["shape"][0],) + + +def test_compute_episode_stats_with_image_data(): + """Test quantile computation with image features.""" + image_paths = [f"image_{i}.jpg" for i in range(50)] + episode_data = { + "observation.image": image_paths, + "action": np.random.normal(0, 1, (50, 5)), + } + features = { + "observation.image": {"dtype": "image"}, + "action": {"dtype": "float32", "shape": (5,)}, + } + + with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): + stats = compute_episode_stats(episode_data, features) + + # Image quantiles should be normalized and have correct shape + assert "q01" in stats["observation.image"] + assert "q50" in stats["observation.image"] + assert "q99" in stats["observation.image"] + assert stats["observation.image"]["q01"].shape == (3, 1, 1) + assert stats["observation.image"]["q50"].shape == (3, 1, 1) + assert stats["observation.image"]["q99"].shape == (3, 1, 1) + + # Action quantiles should have correct shape + assert stats["action"]["q01"].shape == (5,) + assert stats["action"]["q50"].shape == (5,) + assert stats["action"]["q99"].shape == (5,) + + +def test_compute_episode_stats_string_features_skipped(): + """Test that string features are properly skipped.""" + episode_data = { + "task": ["pick_apple"] * 100, # String 
feature + "action": np.random.normal(0, 1, (100, 5)), + } + features = { + "task": {"dtype": "string"}, + "action": {"dtype": "float32", "shape": (5,)}, + } + + stats = compute_episode_stats( + episode_data, + features, + ) + + # String features should be skipped + assert "task" not in stats + assert "action" in stats + assert "q01" in stats["action"] + + +def test_aggregate_feature_stats_with_quantiles(): + """Test aggregating feature stats that include quantiles.""" + stats_ft_list = [ + { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([100]), + "q01": np.array([1.5]), + "q99": np.array([9.5]), + }, + { + "min": np.array([2.0]), + "max": np.array([12.0]), + "mean": np.array([6.0]), + "std": np.array([2.5]), + "count": np.array([150]), + "q01": np.array([2.5]), + "q99": np.array([11.5]), + }, + ] + + result = aggregate_feature_stats(stats_ft_list) + + # Should preserve quantiles + assert "q01" in result + assert "q99" in result + + # Verify quantile aggregation (weighted average) + expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1 + expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7 + + np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6) + np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6) + + +def test_aggregate_stats_mixed_quantiles(): + """Test aggregating stats where some have quantiles and some don't.""" + stats_with_quantiles = { + "feature1": { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([100]), + "q01": np.array([1.5]), + "q99": np.array([9.5]), + } + } + + stats_without_quantiles = { + "feature2": { + "min": np.array([0.0]), + "max": np.array([5.0]), + "mean": np.array([2.5]), + "std": np.array([1.5]), + "count": np.array([50]), + } + } + + all_stats = [stats_with_quantiles, stats_without_quantiles] + result = aggregate_stats(all_stats) 
+ + # Feature1 should keep its quantiles + assert "q01" in result["feature1"] + assert "q99" in result["feature1"] + + # Feature2 should not have quantiles + assert "q01" not in result["feature2"] + assert "q99" not in result["feature2"] + + +def test_assert_type_and_shape_with_quantiles(): + """Test validation works correctly with quantile keys.""" + # Valid stats with quantiles + valid_stats = [ + { + "observation.image": { + "min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1), + "max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1), + "mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1), + "std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1), + "count": np.array([100]), + "q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1), + "q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1), + } + } + ] + + # Should not raise error + _assert_type_and_shape(valid_stats) + + # Invalid shape for quantile + invalid_stats = [ + { + "observation.image": { + "count": np.array([100]), + "q01": np.array([0.1, 0.2]), # Wrong shape for image quantile + } + } + ] + + with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"): + _assert_type_and_shape(invalid_stats) + + +def test_quantile_integration_single_value_quantiles(): + """Test quantile computation with single repeated value.""" + data = np.ones((100, 3)) # All ones + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # All quantiles should be approximately 1.0 + np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + + +def test_quantile_integration_fixed_quantiles(): + """Test that fixed quantiles are computed.""" + np.random.seed(42) + data = np.random.normal(0, 1, (1000, 2)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + # Check all fixed quantiles are 
present + assert "q01" in stats + assert "q10" in stats + assert "q50" in stats + assert "q90" in stats + assert "q99" in stats + + +def test_quantile_integration_large_dataset_quantiles(): + """Test quantile computation efficiency with large datasets.""" + np.random.seed(42) + large_data = np.random.normal(0, 1, (10000, 5)) + + running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed + running_stats.update(large_data) + + stats = running_stats.get_statistics() + + # Should complete without issues and produce reasonable results + assert stats["count"][0] == 10000 + assert len(stats["q01"]) == 5 + + +def test_fixed_quantiles_always_computed(): + """Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed.""" + np.random.seed(42) + # Test with vector data + vector_data = np.random.normal(0, 1, (100, 5)) + vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False) + + # Check all fixed quantiles are present + expected_quantiles = ["q01", "q10", "q50", "q90", "q99"] + for q_key in expected_quantiles: + assert q_key in vector_stats + assert vector_stats[q_key].shape == (5,) + + # Test with image data + image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8) + image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True) + + # Check all fixed quantiles are present for images + for q_key in expected_quantiles: + assert q_key in image_stats + assert image_stats[q_key].shape == (1, 3, 1, 1) + + # Test with episode data + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(0, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + episode_stats = compute_episode_stats(episode_data, features) + + # Check all fixed quantiles are present in episode stats + for key in ["action", "observation.state"]: + for q_key in expected_quantiles: + assert 
q_key in episode_stats[key] + assert episode_stats[key][q_key].shape == (features[key]["shape"][0],) diff --git a/tests/datasets/test_quantiles_dataset_integration.py b/tests/datasets/test_quantiles_dataset_integration.py new file mode 100644 index 000000000..4df7fab06 --- /dev/null +++ b/tests/datasets/test_quantiles_dataset_integration.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration tests for quantile functionality in LeRobotDataset.""" + +import numpy as np +import pytest + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def mock_load_image_as_numpy(path, dtype, channel_first): + """Mock image loading for consistent test results.""" + return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + + +@pytest.fixture +def simple_features(): + """Simple feature configuration for testing.""" + return { + "action": { + "dtype": "float32", + "shape": (4,), + "names": ["arm_x", "arm_y", "arm_z", "gripper"], + }, + "observation.state": { + "dtype": "float32", + "shape": (10,), + "names": [f"joint_{i}" for i in range(10)], + }, + } + + +def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features): + """Test creating dataset with fixed quantiles.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_fixed_quantiles", + fps=30, + features=simple_features, + root=tmp_path / "create_fixed_quantiles", + ) + + # Dataset should be created successfully + assert dataset is not None + + +def test_save_episode_computes_all_quantiles(tmp_path, simple_features): + """Test that all fixed quantiles are computed when saving an episode.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_save_episode", + fps=30, + features=simple_features, + root=tmp_path / "save_episode_quantiles", + ) + + # Add some frames + for _ in range(10): + dataset.add_frame( + { + "action": np.random.randn(4).astype(np.float32), # Correct shape for action + "observation.state": np.random.randn(10).astype(np.float32), + "task": "test_task", + } + ) + + dataset.save_episode() + + # Check that all fixed quantiles were computed + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + assert "q01" in stats[key] + assert "q10" in stats[key] + assert "q50" in stats[key] + assert "q90" in stats[key] + assert "q99" in stats[key] + + +def test_quantile_values_ordering(tmp_path, 
simple_features): + """Test that quantile values are properly ordered.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_quantile_ordering", + fps=30, + features=simple_features, + root=tmp_path / "quantile_ordering", + ) + + # Add data with known distribution + np.random.seed(42) + for _ in range(100): + dataset.add_frame( + { + "action": np.random.randn(4).astype(np.float32), # Correct shape for action + "observation.state": np.random.randn(10).astype(np.float32), + "task": "test_task", + } + ) + + dataset.save_episode() + stats = dataset.meta.stats + + # Verify quantile ordering + for key in ["action", "observation.state"]: + assert np.all(stats[key]["q01"] <= stats[key]["q10"]) + assert np.all(stats[key]["q10"] <= stats[key]["q50"]) + assert np.all(stats[key]["q50"] <= stats[key]["q90"]) + assert np.all(stats[key]["q90"] <= stats[key]["q99"]) + + +def test_save_episode_with_fixed_quantiles(tmp_path, simple_features): + """Test saving episode always computes fixed quantiles.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_save_fixed", + fps=30, + features=simple_features, + root=tmp_path / "save_fixed_quantiles", + ) + + # Add frames to episode + np.random.seed(42) + for _ in range(50): + frame = { + "action": np.random.normal(0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(0, 1, (10,)).astype(np.float32), + "task": "test_task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Check that all fixed quantiles are included + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(feature_stats.keys()) == expected_keys + + +def test_quantile_aggregation_across_episodes(tmp_path, simple_features): + """Test quantile keys, shapes, and ordering in the stats of a single saved episode.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_aggregation", + fps=30, +
features=simple_features, + root=tmp_path / "quantile_aggregation", + ) + + # Add frames to episode + np.random.seed(42) + for _ in range(100): + frame = { + "action": np.random.normal(0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(2, 1, (10,)).astype(np.float32), + "task": "test_task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Check stats include all fixed quantiles + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(feature_stats.keys()) == expected_keys + assert feature_stats["q01"].shape == (simple_features[key]["shape"][0],) + assert feature_stats["q50"].shape == (simple_features[key]["shape"][0],) + assert feature_stats["q99"].shape == (simple_features[key]["shape"][0],) + assert np.all(feature_stats["q01"] <= feature_stats["q50"]) + assert np.all(feature_stats["q50"] <= feature_stats["q99"]) + + +def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features): + """Test quantile aggregation across multiple episodes.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_multiple_episodes", + fps=30, + features=simple_features, + root=tmp_path / "multiple_episodes", + ) + + # Save multiple episodes + np.random.seed(42) + for episode_idx in range(3): + for _ in range(50): + frame = { + "action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32), + "task": f"task_{episode_idx}", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Verify final stats include properly aggregated quantiles + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + assert "q01" in feature_stats and "q99" in feature_stats + assert feature_stats["count"][0] == 150 # 3 episodes * 50 frames diff --git 
a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5d7791919..9669b4ea9 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -165,6 +165,229 @@ def test_min_max_normalization(observation_normalizer): assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) +def test_quantile_normalization(): + """Test QUANTILES mode using 1st-99th percentiles.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = { + "observation.state": { + "q01": np.array([0.1, -0.8]), # 1st percentile + "q99": np.array([0.9, 0.8]), # 99th percentile + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check quantile normalization to [0, 1] + # For state[0]: (0.5 - 0.1) / (0.9 - 0.1) = 0.4 / 0.8 = 0.5 + # For state[1]: (0.0 - (-0.8)) / (0.8 - (-0.8)) = 0.8 / 1.6 = 0.5 + expected_state = torch.tensor([0.5, 0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_quantile10_normalization(): + """Test QUANTILE10 mode using 10th-90th percentiles.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILE10, + } + stats = { + "observation.state": { + "q10": np.array([0.2, -0.6]), # 10th percentile + "q90": np.array([0.8, 0.6]), # 90th percentile + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + 
transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check quantile normalization to [0, 1] + # For state[0]: (0.5 - 0.2) / (0.8 - 0.2) = 0.3 / 0.6 = 0.5 + # For state[1]: (0.0 - (-0.6)) / (0.6 - (-0.6)) = 0.6 / 1.2 = 0.5 + expected_state = torch.tensor([0.5, 0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_quantile_unnormalization(): + """Test that quantile normalization can be reversed properly.""" + features = { + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.ACTION: NormalizationMode.QUANTILES, + } + stats = { + "action": { + "q01": np.array([0.1, -0.8]), + "q99": np.array([0.9, 0.8]), + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Test round-trip normalization + original_action = torch.tensor([0.5, 0.0]) + transition = create_transition(action=original_action) + + # Normalize then unnormalize + normalized = normalizer(transition) + unnormalized = unnormalizer(normalized) + + # Should recover original values + recovered_action = unnormalized[TransitionKey.ACTION] + assert torch.allclose(recovered_action, original_action, atol=1e-6) + + +def test_quantile_division_by_zero(): + """Test quantile normalization handles edge case where q01 == q99.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = { + "observation.state": { + "q01": np.array([0.5]), # Same value + "q99": np.array([0.5]), # Same value -> division by zero case + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": 
torch.tensor([0.5]), + } + transition = create_transition(observation=observation) + + # Should not crash and should handle gracefully + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # When quantiles are identical, should normalize to 0 (due to epsilon handling) + assert torch.isfinite(normalized_obs["observation.state"]).all() + + +def test_quantile_partial_stats(): + """Test that quantile normalization handles missing quantile stats gracefully.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + + # Missing q99 - should pass through unchanged + stats_partial = { + "observation.state": { + "q01": np.array([0.1, -0.8]), # Only q01, missing q99 + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Should pass through unchanged when stats are incomplete + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + +def test_quantile_mixed_with_other_modes(): + """Test quantile normalization mixed with other normalization modes.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, # Standard normalization + FeatureType.STATE: NormalizationMode.QUANTILES, # Quantile normalization + FeatureType.ACTION: NormalizationMode.QUANTILE10, # Different quantile mode + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, 
+ "observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]}, + "action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), # Should use QUANTILES + } + action = torch.tensor([0.5, 0.0]) # Should use QUANTILE10 + transition = create_transition(observation=observation, action=action) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + normalized_action = normalized_transition[TransitionKey.ACTION] + + # Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc. + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(normalized_obs["observation.image"], expected_image) + + # State should be quantile normalized: (0.5 - 0.1) / (0.9 - 0.1) = 0.5, etc. + expected_state = torch.tensor([0.5, 0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + # Action should be quantile10 normalized: (0.5 - 0.2) / (0.8 - 0.2) = 0.5, etc. 
+ expected_action = torch.tensor([0.5, 0.5]) + assert torch.allclose(normalized_action, expected_action, atol=1e-6) + + +def test_quantile_with_missing_stats(): + """Test that quantile normalization handles completely missing stats gracefully.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = {} # No stats provided + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Should pass through unchanged when no stats available + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map()