Add Quantile stats to LeRobotDataset (#1985)

* - Add RunningQuantileStats class for efficient histogram-based quantile computation
- Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
- Support quantile computation during episode collection and aggregation
- Add comprehensive function-based test suite (24 tests) for quantile functionality
- Maintain full backward compatibility with existing stats computation
- Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization

* style fixes, make quantiles computation by default to new datasets

* fix tests

* - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user
- Fortified tests.

* - add helper functions to reshape stats
- add missing test for quantiles

* - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles.
- Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles.

* style fixes

* Added missing lisence

* Simplify compute_stats

* - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles
- modified quantile computation instead of using the edge for the value, interpolate the values in the bin
This commit is contained in:
Michel Aractingi
2025-09-22 17:57:32 +02:00
committed by GitHub
parent 5d9acf9d51
commit d691d1e4fe
7 changed files with 1689 additions and 34 deletions
+524
View File
@@ -19,6 +19,7 @@ import numpy as np
import pytest
from lerobot.datasets.compute_stats import (
RunningQuantileStats,
_assert_type_and_shape,
aggregate_feature_stats,
aggregate_stats,
@@ -101,6 +102,9 @@ def test_get_feature_stats_axis_1(sample_array):
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
# Check that basic stats are correct (quantiles are also included now)
assert set(expected.keys()).issubset(set(result.keys()))
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
@@ -114,6 +118,9 @@ def test_get_feature_stats_no_axis(sample_array):
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=None, keepdims=False)
# Check that basic stats are correct (quantiles are also included now)
assert set(expected.keys()).issubset(set(result.keys()))
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
@@ -307,3 +314,520 @@ def test_aggregate_stats():
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
def test_running_quantile_stats_initialization():
"""Test proper initialization of RunningQuantileStats."""
running_stats = RunningQuantileStats()
assert running_stats._count == 0
assert running_stats._mean is None
assert running_stats._num_quantile_bins == 5000
# Test custom bin size
running_stats_custom = RunningQuantileStats(num_quantile_bins=1000)
assert running_stats_custom._num_quantile_bins == 1000
def test_running_quantile_stats_single_batch_update():
"""Test updating with a single batch."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 3))
running_stats = RunningQuantileStats()
running_stats.update(data)
assert running_stats._count == 100
assert running_stats._mean.shape == (3,)
assert len(running_stats._histograms) == 3
assert len(running_stats._bin_edges) == 3
# Verify basic statistics are reasonable
np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10)
def test_running_quantile_stats_multiple_batch_updates():
"""Test updating with multiple batches."""
np.random.seed(42)
data1 = np.random.normal(0, 1, (100, 2))
data2 = np.random.normal(1, 1, (150, 2))
running_stats = RunningQuantileStats()
running_stats.update(data1)
running_stats.update(data2)
assert running_stats._count == 250
# Verify running mean is correct
combined_data = np.vstack([data1, data2])
expected_mean = np.mean(combined_data, axis=0)
np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10)
def test_running_quantile_stats_get_statistics_basic():
"""Test getting basic statistics without quantiles."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 2))
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# Should have basic stats
expected_keys = {"min", "max", "mean", "std", "count"}
assert expected_keys.issubset(set(stats.keys()))
# Verify values
np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10)
np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6)
np.testing.assert_equal(stats["count"], np.array([100]))
def test_running_quantile_stats_get_statistics_with_quantiles():
"""Test getting statistics with quantiles."""
np.random.seed(42)
data = np.random.normal(0, 1, (1000, 2))
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# Should have basic stats plus quantiles
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert expected_keys.issubset(set(stats.keys()))
# Verify quantile values are reasonable
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES
for i, q in enumerate(DEFAULT_QUANTILES):
q_key = f"q{int(q * 100):02d}"
assert q_key in stats
assert stats[q_key].shape == (2,)
# Check that quantiles are in reasonable order
if i > 0:
prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}"
assert np.all(stats[prev_q_key] <= stats[q_key])
def test_running_quantile_stats_histogram_adjustment():
"""Test that histograms adjust when min/max change."""
running_stats = RunningQuantileStats()
# Initial data with small range
data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]])
running_stats.update(data1)
initial_edges_0 = running_stats._bin_edges[0].copy()
initial_edges_1 = running_stats._bin_edges[1].copy()
# Add data with much larger range
data2 = np.array([[10.0, -10.0], [11.0, -11.0]])
running_stats.update(data2)
# Bin edges should have changed
assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0])
assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1])
# New edges should cover the expanded range
# First dimension: min should still be ~0.0, max should be ~11.0
assert running_stats._bin_edges[0][0] <= 0.0
assert running_stats._bin_edges[0][-1] >= 11.0
# Second dimension: min should be ~-11.0, max should be ~1.2
assert running_stats._bin_edges[1][0] <= -11.0
assert running_stats._bin_edges[1][-1] >= 1.2
def test_running_quantile_stats_insufficient_data_error():
"""Test error when trying to get stats with insufficient data."""
running_stats = RunningQuantileStats()
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
running_stats.get_statistics()
# Single vector should also fail
running_stats.update(np.array([[1.0]]))
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
running_stats.get_statistics()
def test_running_quantile_stats_vector_length_consistency():
"""Test error when vector lengths don't match."""
running_stats = RunningQuantileStats()
running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]]))
with pytest.raises(ValueError, match="The length of new vectors does not match"):
running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length
def test_running_quantile_stats_reshape_handling():
"""Test that various input shapes are handled correctly."""
running_stats = RunningQuantileStats()
# Test 3D input (e.g., images)
data_3d = np.random.normal(0, 1, (10, 32, 32))
running_stats.update(data_3d)
assert running_stats._count == 10 * 32
assert running_stats._mean.shape == (32,)
# Test 1D input
running_stats_1d = RunningQuantileStats()
data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
running_stats_1d.update(data_1d)
assert running_stats_1d._count == 5
assert running_stats_1d._mean.shape == (1,)
def test_get_feature_stats_quantiles_enabled_by_default():
"""Test that quantiles are computed by default."""
data = np.random.normal(0, 1, (100, 5))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
def test_get_feature_stats_quantiles_with_vector_data():
"""Test quantile computation with vector data."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 5))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
# Verify shapes
assert stats["q01"].shape == (5,)
assert stats["q99"].shape == (5,)
# Verify quantiles are reasonable
assert np.all(stats["q01"] < stats["q99"])
def test_get_feature_stats_quantiles_with_image_data():
"""Test quantile computation with image data."""
np.random.seed(42)
data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
# Verify shapes for images (should be (1, channels, 1, 1))
assert stats["q01"].shape == (1, 3, 1, 1)
assert stats["q50"].shape == (1, 3, 1, 1)
assert stats["q99"].shape == (1, 3, 1, 1)
def test_get_feature_stats_fixed_quantiles():
"""Test that fixed quantiles are always computed."""
data = np.random.normal(0, 1, (200, 3))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"}
assert expected_quantile_keys.issubset(set(stats.keys()))
def test_get_feature_stats_unsupported_axis_error():
"""Test error for unsupported axis configuration."""
data = np.random.normal(0, 1, (10, 5))
with pytest.raises(ValueError, match="Unsupported axis configuration"):
get_feature_stats(
data,
axis=(1, 2), # Unsupported axis
keepdims=False,
)
def test_compute_episode_stats_backward_compatibility():
"""Test that existing functionality is preserved."""
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(0, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
stats = compute_episode_stats(episode_data, features)
for key in ["action", "observation.state"]:
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats[key].keys()) == expected_keys
def test_compute_episode_stats_with_custom_quantiles():
"""Test quantile computation with custom quantile values."""
np.random.seed(42)
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(2, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
stats = compute_episode_stats(episode_data, features)
# Should have quantiles
for key in ["action", "observation.state"]:
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats[key].keys()) == expected_keys
# Verify shapes
assert stats[key]["q01"].shape == (features[key]["shape"][0],)
assert stats[key]["q99"].shape == (features[key]["shape"][0],)
def test_compute_episode_stats_with_image_data():
"""Test quantile computation with image features."""
image_paths = [f"image_{i}.jpg" for i in range(50)]
episode_data = {
"observation.image": image_paths,
"action": np.random.normal(0, 1, (50, 5)),
}
features = {
"observation.image": {"dtype": "image"},
"action": {"dtype": "float32", "shape": (5,)},
}
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features)
# Image quantiles should be normalized and have correct shape
assert "q01" in stats["observation.image"]
assert "q50" in stats["observation.image"]
assert "q99" in stats["observation.image"]
assert stats["observation.image"]["q01"].shape == (3, 1, 1)
assert stats["observation.image"]["q50"].shape == (3, 1, 1)
assert stats["observation.image"]["q99"].shape == (3, 1, 1)
# Action quantiles should have correct shape
assert stats["action"]["q01"].shape == (5,)
assert stats["action"]["q50"].shape == (5,)
assert stats["action"]["q99"].shape == (5,)
def test_compute_episode_stats_string_features_skipped():
"""Test that string features are properly skipped."""
episode_data = {
"task": ["pick_apple"] * 100, # String feature
"action": np.random.normal(0, 1, (100, 5)),
}
features = {
"task": {"dtype": "string"},
"action": {"dtype": "float32", "shape": (5,)},
}
stats = compute_episode_stats(
episode_data,
features,
)
# String features should be skipped
assert "task" not in stats
assert "action" in stats
assert "q01" in stats["action"]
def test_aggregate_feature_stats_with_quantiles():
"""Test aggregating feature stats that include quantiles."""
stats_ft_list = [
{
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([100]),
"q01": np.array([1.5]),
"q99": np.array([9.5]),
},
{
"min": np.array([2.0]),
"max": np.array([12.0]),
"mean": np.array([6.0]),
"std": np.array([2.5]),
"count": np.array([150]),
"q01": np.array([2.5]),
"q99": np.array([11.5]),
},
]
result = aggregate_feature_stats(stats_ft_list)
# Should preserve quantiles
assert "q01" in result
assert "q99" in result
# Verify quantile aggregation (weighted average)
expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1
expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7
np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)
def test_aggregate_stats_mixed_quantiles():
"""Test aggregating stats where some have quantiles and some don't."""
stats_with_quantiles = {
"feature1": {
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([100]),
"q01": np.array([1.5]),
"q99": np.array([9.5]),
}
}
stats_without_quantiles = {
"feature2": {
"min": np.array([0.0]),
"max": np.array([5.0]),
"mean": np.array([2.5]),
"std": np.array([1.5]),
"count": np.array([50]),
}
}
all_stats = [stats_with_quantiles, stats_without_quantiles]
result = aggregate_stats(all_stats)
# Feature1 should keep its quantiles
assert "q01" in result["feature1"]
assert "q99" in result["feature1"]
# Feature2 should not have quantiles
assert "q01" not in result["feature2"]
assert "q99" not in result["feature2"]
def test_assert_type_and_shape_with_quantiles():
"""Test validation works correctly with quantile keys."""
# Valid stats with quantiles
valid_stats = [
{
"observation.image": {
"min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1),
"max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1),
"mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1),
"std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1),
"count": np.array([100]),
"q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1),
"q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1),
}
}
]
# Should not raise error
_assert_type_and_shape(valid_stats)
# Invalid shape for quantile
invalid_stats = [
{
"observation.image": {
"count": np.array([100]),
"q01": np.array([0.1, 0.2]), # Wrong shape for image quantile
}
}
]
with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
_assert_type_and_shape(invalid_stats)
def test_quantile_integration_single_value_quantiles():
"""Test quantile computation with single repeated value."""
data = np.ones((100, 3)) # All ones
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# All quantiles should be approximately 1.0
np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
def test_quantile_integration_fixed_quantiles():
"""Test that fixed quantiles are computed."""
np.random.seed(42)
data = np.random.normal(0, 1, (1000, 2))
stats = get_feature_stats(data, axis=0, keepdims=False)
# Check all fixed quantiles are present
assert "q01" in stats
assert "q10" in stats
assert "q50" in stats
assert "q90" in stats
assert "q99" in stats
def test_quantile_integration_large_dataset_quantiles():
"""Test quantile computation efficiency with large datasets."""
np.random.seed(42)
large_data = np.random.normal(0, 1, (10000, 5))
running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed
running_stats.update(large_data)
stats = running_stats.get_statistics()
# Should complete without issues and produce reasonable results
assert stats["count"][0] == 10000
assert len(stats["q01"]) == 5
def test_fixed_quantiles_always_computed():
"""Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed."""
np.random.seed(42)
# Test with vector data
vector_data = np.random.normal(0, 1, (100, 5))
vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False)
# Check all fixed quantiles are present
expected_quantiles = ["q01", "q10", "q50", "q90", "q99"]
for q_key in expected_quantiles:
assert q_key in vector_stats
assert vector_stats[q_key].shape == (5,)
# Test with image data
image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True)
# Check all fixed quantiles are present for images
for q_key in expected_quantiles:
assert q_key in image_stats
assert image_stats[q_key].shape == (1, 3, 1, 1)
# Test with episode data
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(0, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
episode_stats = compute_episode_stats(episode_data, features)
# Check all fixed quantiles are present in episode stats
for key in ["action", "observation.state"]:
for q_key in expected_quantiles:
assert q_key in episode_stats[key]
assert episode_stats[key][q_key].shape == (features[key]["shape"][0],)