mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
Add Quantile stats to LeRobotDataset (#1985)
* - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, compute quantiles by default for new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each feature instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. * style fixes * Added missing license * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that don't have quantiles - modified quantile computation: instead of using the bin edge as the value, interpolate the value within the bin
This commit is contained in:
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for quantile functionality in LeRobotDataset."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
    """Return an all-ones 32x32, 3-channel array in place of a loaded image.

    Args:
        path: Unused; the mock returns the same data for any path.
        dtype: NumPy dtype of the returned array.
        channel_first: If truthy the result has shape ``(3, 32, 32)``,
            otherwise ``(32, 32, 3)``.

    Returns:
        A NumPy array of ones with the requested layout and dtype.
    """
    shape = (3, 32, 32) if channel_first else (32, 32, 3)
    return np.ones(shape, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.fixture
def simple_features():
    """Feature spec with a 4-dim ``action`` and a 10-dim ``observation.state``.

    Both features are declared as one-dimensional ``float32`` vectors with one
    name per element.
    """
    return {
        "action": {
            "dtype": "float32",
            "shape": (4,),
            "names": ["arm_x", "arm_y", "arm_z", "gripper"],
        },
        "observation.state": {
            "dtype": "float32",
            "shape": (10,),
            "names": [f"joint_{i}" for i in range(10)],
        },
    }
|
||||
|
||||
|
||||
def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features):
    """Check that ``LeRobotDataset.create`` returns a dataset object.

    No quantile-related arguments are passed to ``create``; this is a smoke
    test of dataset creation only.
    """
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_fixed_quantiles",
        fps=30,
        features=simple_features,
        root=tmp_path / "create_fixed_quantiles",
    )

    # Dataset should be created successfully
    assert dataset is not None
|
||||
|
||||
|
||||
def test_save_episode_computes_all_quantiles(tmp_path, simple_features):
    """Check that saving an episode yields q01, q10, q50, q90 and q99 for each feature."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_save_episode",
        fps=30,
        features=simple_features,
        root=tmp_path / "save_episode_quantiles",
    )

    # Seed the generator so the frames are reproducible, as the other tests
    # in this module do.
    np.random.seed(42)
    for _ in range(10):
        dataset.add_frame(
            {
                # Vector lengths match the shapes declared in simple_features.
                "action": np.random.randn(4).astype(np.float32),
                "observation.state": np.random.randn(10).astype(np.float32),
                "task": "test_task",
            }
        )

    dataset.save_episode()

    # Check that all fixed quantiles were computed
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        for quantile_key in ("q01", "q10", "q50", "q90", "q99"):
            assert quantile_key in stats[key], f"{quantile_key} missing for {key}"
|
||||
|
||||
|
||||
def test_quantile_values_ordering(tmp_path, simple_features):
    """Check that q01 <= q10 <= q50 <= q90 <= q99 holds element-wise for each feature."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_quantile_ordering",
        fps=30,
        features=simple_features,
        root=tmp_path / "quantile_ordering",
    )

    # Seeded standard-normal samples keep the episode reproducible.
    np.random.seed(42)
    for _ in range(100):
        dataset.add_frame(
            {
                "action": np.random.randn(4).astype(np.float32),
                "observation.state": np.random.randn(10).astype(np.float32),
                "task": "test_task",
            }
        )

    dataset.save_episode()
    stats = dataset.meta.stats

    # Verify quantile ordering by comparing each quantile with the next one up.
    ordered_keys = ["q01", "q10", "q50", "q90", "q99"]
    for key in ["action", "observation.state"]:
        for lower, upper in zip(ordered_keys, ordered_keys[1:]):
            assert np.all(stats[key][lower] <= stats[key][upper]), (
                f"{lower} exceeds {upper} for {key}"
            )
|
||||
|
||||
|
||||
def test_save_episode_with_fixed_quantiles(tmp_path, simple_features):
    """Check that each feature's stats hold exactly the basic stats plus the five quantiles."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_save_fixed",
        fps=30,
        features=simple_features,
        root=tmp_path / "save_fixed_quantiles",
    )

    # Add frames to episode
    np.random.seed(42)
    for _ in range(50):
        frame = {
            "action": np.random.normal(0, 1, (4,)).astype(np.float32),
            "observation.state": np.random.normal(0, 1, (10,)).astype(np.float32),
            "task": "test_task",
        }
        dataset.add_frame(frame)

    dataset.save_episode()

    # The expected key set is the same for every feature, so build it once.
    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}

    # Check that all fixed quantiles are included (and nothing else is)
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        assert set(stats[key].keys()) == expected_keys
|
||||
|
||||
|
||||
def test_quantile_aggregation_across_episodes(tmp_path, simple_features):
    """Check keys, shapes and ordering of quantile stats after saving two episodes.

    The two episodes are drawn from normal distributions with different means,
    so the stats read back at the end cover more than one episode's data.
    """
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_aggregation",
        fps=30,
        features=simple_features,
        root=tmp_path / "quantile_aggregation",
    )

    num_episodes = 2
    frames_per_episode = 50

    # Save several episodes; a single episode would not exercise aggregation.
    np.random.seed(42)
    for episode_idx in range(num_episodes):
        for _ in range(frames_per_episode):
            frame = {
                "action": np.random.normal(episode_idx, 1, (4,)).astype(np.float32),
                "observation.state": np.random.normal(2 + episode_idx, 1, (10,)).astype(np.float32),
                "task": "test_task",
            }
            dataset.add_frame(frame)

        dataset.save_episode()

    expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}

    # Check stats include all fixed quantiles
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        feature_stats = stats[key]
        assert set(feature_stats.keys()) == expected_keys

        # The count check confirms the stats reflect every saved episode,
        # not only the last one.
        assert feature_stats["count"][0] == num_episodes * frames_per_episode

        # Each quantile is a vector with one entry per feature dimension.
        expected_shape = (simple_features[key]["shape"][0],)
        assert feature_stats["q01"].shape == expected_shape
        assert feature_stats["q50"].shape == expected_shape
        assert feature_stats["q99"].shape == expected_shape

        assert np.all(feature_stats["q01"] <= feature_stats["q50"])
        assert np.all(feature_stats["q50"] <= feature_stats["q99"])
|
||||
|
||||
|
||||
def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features):
    """Check that stats after three saved episodes keep q01/q99 and count every frame."""
    dataset = LeRobotDataset.create(
        repo_id="test_dataset_multiple_episodes",
        fps=30,
        features=simple_features,
        root=tmp_path / "multiple_episodes",
    )

    num_episodes = 3
    frames_per_episode = 50

    # Save multiple episodes; the sampling mean shifts with the episode index
    # so the episodes differ from one another.
    np.random.seed(42)
    for episode_idx in range(num_episodes):
        for _ in range(frames_per_episode):
            frame = {
                "action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32),
                "observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32),
                "task": f"task_{episode_idx}",
            }
            dataset.add_frame(frame)

        dataset.save_episode()

    # Verify final stats include properly aggregated quantiles
    stats = dataset.meta.stats
    for key in ["action", "observation.state"]:
        feature_stats = stats[key]
        # Separate asserts so a failure names the missing quantile.
        assert "q01" in feature_stats
        assert "q99" in feature_stats
        assert feature_stats["count"][0] == num_episodes * frames_per_episode  # 3 episodes * 50 frames
|
||||
Reference in New Issue
Block a user