mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 15:17:05 +00:00
3dd19d043e
* feat(depth): add depth quantization helpers and tests
* feat(video): add ffv1 to supported codecs
* feat(depth): persist depth metadata
* feat(depth): extend quantization tools to better fit the encoding/decoding pipeline
* feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter
* feat(depth): wire StreamingVideoEncoder + writer to depth encoder
* feat(depth): wire DatasetReader to decode_depth_frames
* feat(cameras/realsense): expose async depth in metric meters
* feat(features): route 2D camera shapes to observation.depth.<key>
* feat(robots/so_follower): emit + populate depth keys when use_depth
* feat(record): plumb DepthEncoderConfig through lerobot-record
* feat(viz): render depth observations as rr.DepthImage in Viridis
* feat(depth maps writer): adding support for raw depth maps recording with image writer
* chore(format): format code
* feat(depth shape): ensuring depth maps shape is always including the channel
* feat(is_depth): simplifying is_depth nested name + legacy support
* fix(stop_event): fixing stop_event race condition in camera classes
* fix(plumbing): fixing missing parts in the depth maps pipeline
* chore(typos): fixing typos
* test(fix): fixing exisiting tests to still work with latest features
* tests(depth): adding new tests for depth integration validation
* feat(pix_fmt channels): use PyAv to check get pixel formats number of channels
* feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file.
* fix(pre-commit): fixing mutable defautl value
* fix(info): fixing info metadata update when is_depth_map was set
* tests(typos): fixing typos in tests
* fix(realsense): fixing typo in realsense serial number
* fix(normalization): restricting 255 normalization to non depth/uint8 images only
* fix(typo): fixing typo
* fix(TIFF): add missing quantization and cleanup for TIFF files
* feat(batched dequantization): optimizing dequantize_depth for torch based batched dequantization
* feat(tools): adding depth support in LeRobotDataset edition tools
* test(aggregate): extending aggregation tests to depth frames
* test(cleaning): cleaning up tests
* fix(from_video_info): fixing early validation issue in from_video_info
* fix(typo): fixing typo
* fix(is_depth): adding missing doctrings and is_depth arguments in video decoding functions
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
* fix(depth units): fixing depth units output for the realsense cameras
* feat(output unit): adding support for output unit specification at dataset reading/training time
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
* test(depth): cleaning up depth tests
* test(depth encoding): updating and cleaning video/depth encoding tests
* chore(format): formatting code
* docs(depth): improving depth maps docs
* test(fix): fixing depth tests
* test(dataset tools): adding missing tests for new dataset edition tools features
* chore(format): formatting code
* fix(pyav check): fixing PyAV option validation for integer codec options by normalizing
numeric values before calling `is_integer()`
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
* docs(mermaid): fixing mermaid diagram
* fix(rebase): rebase follow up corrections
* feat(dataset tools): adding missing docstrings and features for depth fill support in dataset edition tools
* docs(docstring): updating docstrings
* docs(dataset tools): updating docs
* fix(save images): fixing image saving in dataset tools
* fix(update video info): fixing update video info logic to match the recording and editing use cases
* test(reencode): fixing reencoding monkeypatch
* fix(review): add Claude review
* chore(format): format code
* fix(update video info): ditching the differentiated approahces for video info update - video info are always updated unless for preserved keys.
* chore(rebase): fixing rebase merge conflicts
* test(visualization): fixing visualization tests
* feat(docstrings): adding explicit docstring for encoding parameters. Docstrigns will now show up as description in the CLI --help.
* feat(mm as default): adding a global DEFAULT_DEPTH_UNIT variable setting mm as default depth unit
* fix(RGB <-> camera): renaming camera_encoder to rgb_encoder for clarity
* chore(TODO): removing deprecated TODO
* doc(write_u16_plane): improving docstrings for write_u16_plane
* feat(units): adding constants for depth frames units (m and mm)
* fix(spam): replacing spamming warning but a debug log
* feat(leagcy metadata): adding automatic metadata update for legacy 'video.is_depth_map' feature
* fix(copy&reindex): fixing metadat reshaping for single channel frames
* fix(ImageNet): excluding dpeth frames from ImageNet stats
* fix(PyAV container seek): fixing initial PyAV container seek to be robust againsy codec choice
* feat(lerobot-dataset-viz): adding support for depth in lerobot-dataset-viz
* fix(compress): removing rerun compression for DepthImages
* fix(signle channel squeeze): fixing single channel squeezing
* chore(format): format code
* fix(streaming): adding support for dequantization in streaming_dataset.py
* refactor(read depth): factorizing depth reading methods for realsense camera and adding support for depth-only usage
* chore(renaming): fixing missed RGBEncoderConfig renamings
* docs(renaming): reflecting renamings in a clearer way in the docs
* chore(annotation): excluding depth from the annotation pipeline
* feat(robots): adding depth support in compatible follower robots
* feat(LeSadKiwi): excluding LeKiwi from depth support (for now)
* chore(fail): removing misplaced file
* chore(fail): removing misplaced file
* fix(remove ffv1): removing ffv1 as it does not support MP4
* docs(cheat sheet): adding depth and video encoding to the cheat sheet
* fix(lossless): tuning depth encoding parameters for lossless depth storage
* test(fix): fixing failing tests
* depth(ZMQ): excluding ZMQ from depth support
* Revert "depth(ZMQ): excluding ZMQ from depth support"
This reverts commit b95cf4e4c2.
* fix(image transforms): excluding depth frames from images transforms
* fix(typo): typo
* fix(stats): fixing stats computation for depth frames
* fix(TIFF vs. pytorch): adding an extra uint16 to float32 conversion for depth maps stored as raw TIFF images
* fix(typos): fixing typos
* test(dtype): fixing stats computation typing tests
---------
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Wensi Ai <wsai@stanford.edu>
881 lines
30 KiB
Python
881 lines
30 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from unittest.mock import patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
|
|
|
from lerobot.datasets.compute_stats import (
|
|
RunningQuantileStats,
|
|
_assert_type_and_shape,
|
|
aggregate_feature_stats,
|
|
aggregate_stats,
|
|
compute_episode_stats,
|
|
estimate_num_samples,
|
|
get_feature_stats,
|
|
sample_images,
|
|
sample_indices,
|
|
)
|
|
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
|
|
|
|
|
def mock_load_image_as_numpy(path, dtype, channel_first):
|
|
is_depth = "depth" in str(path)
|
|
channels = 1 if is_depth else 3
|
|
out_dtype = np.uint16 if is_depth else dtype
|
|
arr = np.arange(channels * 32 * 32, dtype=out_dtype).reshape(channels, 32, 32)
|
|
return arr if channel_first else arr.transpose(1, 2, 0)
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_array():
|
|
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
|
|
|
|
def test_estimate_num_samples():
|
|
assert estimate_num_samples(1) == 1
|
|
assert estimate_num_samples(10) == 10
|
|
assert estimate_num_samples(100) == 100
|
|
assert estimate_num_samples(200) == 100
|
|
assert estimate_num_samples(1000) == 177
|
|
assert estimate_num_samples(2000) == 299
|
|
assert estimate_num_samples(5000) == 594
|
|
assert estimate_num_samples(10_000) == 1000
|
|
assert estimate_num_samples(20_000) == 1681
|
|
assert estimate_num_samples(50_000) == 3343
|
|
assert estimate_num_samples(500_000) == 10_000
|
|
|
|
|
|
def test_sample_indices():
|
|
indices = sample_indices(10)
|
|
assert len(indices) > 0
|
|
assert indices[0] == 0
|
|
assert indices[-1] == 9
|
|
assert len(indices) == estimate_num_samples(10)
|
|
|
|
|
|
@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
|
def test_sample_images(mock_load):
|
|
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
|
images = sample_images(image_paths)
|
|
assert isinstance(images, np.ndarray)
|
|
assert images.shape[1:] == (3, 32, 32)
|
|
assert images.dtype == np.uint8
|
|
assert len(images) == estimate_num_samples(100)
|
|
|
|
|
|
def test_get_feature_stats_images():
|
|
data = np.random.rand(100, 3, 32, 32)
|
|
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
|
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
|
np.testing.assert_equal(stats["count"], np.array([100]))
|
|
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
|
|
|
|
|
def test_get_feature_stats_uint8_images_preserves_std():
|
|
data = np.array(
|
|
[
|
|
[
|
|
[[0, 64], [128, 255]],
|
|
[[255, 128], [64, 0]],
|
|
[[32, 96], [160, 224]],
|
|
],
|
|
[
|
|
[[16, 80], [144, 240]],
|
|
[[240, 144], [80, 16]],
|
|
[[48, 112], [176, 208]],
|
|
],
|
|
],
|
|
dtype=np.uint8,
|
|
)
|
|
|
|
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
|
|
|
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
|
|
np.testing.assert_allclose(stats["std"], expected_std)
|
|
|
|
|
|
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
|
expected = {
|
|
"min": np.array([[1, 2, 3]]),
|
|
"max": np.array([[7, 8, 9]]),
|
|
"mean": np.array([[4.0, 5.0, 6.0]]),
|
|
"std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
|
|
"count": np.array([3]),
|
|
}
|
|
result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
|
|
for key in expected:
|
|
np.testing.assert_allclose(result[key], expected[key])
|
|
|
|
|
|
def test_get_feature_stats_axis_1(sample_array):
|
|
expected = {
|
|
"min": np.array([1, 4, 7]),
|
|
"max": np.array([3, 6, 9]),
|
|
"mean": np.array([2.0, 5.0, 8.0]),
|
|
"std": np.array([0.81649658, 0.81649658, 0.81649658]),
|
|
"count": np.array([3]),
|
|
}
|
|
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
|
|
|
|
# Check that basic stats are correct (quantiles are also included now)
|
|
assert set(expected.keys()).issubset(set(result.keys()))
|
|
for key in expected:
|
|
np.testing.assert_allclose(result[key], expected[key])
|
|
|
|
|
|
def test_get_feature_stats_no_axis(sample_array):
|
|
expected = {
|
|
"min": np.array(1),
|
|
"max": np.array(9),
|
|
"mean": np.array(5.0),
|
|
"std": np.array(2.5819889),
|
|
"count": np.array([3]),
|
|
}
|
|
result = get_feature_stats(sample_array, axis=None, keepdims=False)
|
|
|
|
# Check that basic stats are correct (quantiles are also included now)
|
|
assert set(expected.keys()).issubset(set(result.keys()))
|
|
for key in expected:
|
|
np.testing.assert_allclose(result[key], expected[key])
|
|
|
|
|
|
def test_get_feature_stats_empty_array():
|
|
array = np.array([])
|
|
with pytest.raises(ValueError):
|
|
get_feature_stats(array, axis=(0,), keepdims=True)
|
|
|
|
|
|
def test_get_feature_stats_single_value():
|
|
array = np.array([[1337]])
|
|
result = get_feature_stats(array, axis=None, keepdims=True)
|
|
np.testing.assert_equal(result["min"], np.array(1337))
|
|
np.testing.assert_equal(result["max"], np.array(1337))
|
|
np.testing.assert_equal(result["mean"], np.array(1337.0))
|
|
np.testing.assert_equal(result["std"], np.array(0.0))
|
|
np.testing.assert_equal(result["count"], np.array([1]))
|
|
|
|
|
|
def test_compute_episode_stats():
|
|
depth_key = "observation.images.depth"
|
|
episode_data = {
|
|
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
|
depth_key: [f"depth_{i}.tiff" for i in range(100)],
|
|
OBS_STATE: np.random.rand(100, 10),
|
|
}
|
|
features = {
|
|
OBS_IMAGE: {"dtype": "image"},
|
|
depth_key: {"dtype": "image", "info": {"is_depth_map": True}},
|
|
OBS_STATE: {"dtype": "numeric"},
|
|
}
|
|
|
|
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
|
stats = compute_episode_stats(episode_data, features)
|
|
|
|
assert OBS_IMAGE in stats and depth_key in stats and OBS_STATE in stats
|
|
assert stats[OBS_IMAGE]["count"].item() == 100
|
|
assert stats[depth_key]["count"].item() == 100
|
|
assert stats[OBS_STATE]["count"].item() == 100
|
|
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
|
assert stats[depth_key]["mean"].shape == (1, 1, 1)
|
|
# Depth keeps raw values: max far exceeds 255, proving no /255 and no uint8 downcast.
|
|
assert stats[depth_key]["min"].item() == 0.0
|
|
assert stats[depth_key]["max"].item() == 1023.0
|
|
# RGB is normalized to [0, 1].
|
|
np.testing.assert_allclose(stats[OBS_IMAGE]["min"], 0.0)
|
|
np.testing.assert_allclose(stats[OBS_IMAGE]["max"], 1.0)
|
|
|
|
|
|
def test_assert_type_and_shape_valid():
|
|
valid_stats = [
|
|
{
|
|
"feature1": {
|
|
"min": np.array([1.0]),
|
|
"max": np.array([10.0]),
|
|
"mean": np.array([5.0]),
|
|
"std": np.array([2.0]),
|
|
"count": np.array([1]),
|
|
}
|
|
}
|
|
]
|
|
_assert_type_and_shape(valid_stats)
|
|
|
|
|
|
def test_assert_type_and_shape_invalid_type():
|
|
invalid_stats = [
|
|
{
|
|
"feature1": {
|
|
"min": [1.0], # Not a numpy array
|
|
"max": np.array([10.0]),
|
|
"mean": np.array([5.0]),
|
|
"std": np.array([2.0]),
|
|
"count": np.array([1]),
|
|
}
|
|
}
|
|
]
|
|
with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
|
|
_assert_type_and_shape(invalid_stats)
|
|
|
|
|
|
def test_assert_type_and_shape_invalid_shape():
|
|
invalid_stats = [
|
|
{
|
|
"feature1": {
|
|
"count": np.array([1, 2]), # Wrong shape
|
|
}
|
|
}
|
|
]
|
|
with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
|
|
_assert_type_and_shape(invalid_stats)
|
|
|
|
|
|
def test_aggregate_feature_stats():
|
|
stats_ft_list = [
|
|
{
|
|
"min": np.array([1.0]),
|
|
"max": np.array([10.0]),
|
|
"mean": np.array([5.0]),
|
|
"std": np.array([2.0]),
|
|
"count": np.array([1]),
|
|
},
|
|
{
|
|
"min": np.array([2.0]),
|
|
"max": np.array([12.0]),
|
|
"mean": np.array([6.0]),
|
|
"std": np.array([2.5]),
|
|
"count": np.array([1]),
|
|
},
|
|
]
|
|
result = aggregate_feature_stats(stats_ft_list)
|
|
np.testing.assert_allclose(result["min"], np.array([1.0]))
|
|
np.testing.assert_allclose(result["max"], np.array([12.0]))
|
|
np.testing.assert_allclose(result["mean"], np.array([5.5]))
|
|
np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
|
|
np.testing.assert_allclose(result["count"], np.array([2]))
|
|
|
|
|
|
def test_aggregate_stats():
|
|
all_stats = [
|
|
{
|
|
OBS_IMAGE: {
|
|
"min": [1, 2, 3],
|
|
"max": [10, 20, 30],
|
|
"mean": [5.5, 10.5, 15.5],
|
|
"std": [2.87, 5.87, 8.87],
|
|
"count": 10,
|
|
},
|
|
OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
|
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
|
|
},
|
|
{
|
|
OBS_IMAGE: {
|
|
"min": [2, 1, 0],
|
|
"max": [15, 10, 5],
|
|
"mean": [8.5, 5.5, 2.5],
|
|
"std": [3.42, 2.42, 1.42],
|
|
"count": 15,
|
|
},
|
|
OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
|
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
|
|
},
|
|
]
|
|
|
|
expected_agg_stats = {
|
|
OBS_IMAGE: {
|
|
"min": [1, 1, 0],
|
|
"max": [15, 20, 30],
|
|
"mean": [7.3, 7.5, 7.7],
|
|
"std": [3.5317, 4.8267, 8.5581],
|
|
"count": 25,
|
|
},
|
|
OBS_STATE: {
|
|
"min": 1,
|
|
"max": 15,
|
|
"mean": 7.3,
|
|
"std": 3.5317,
|
|
"count": 25,
|
|
},
|
|
"extra_key_0": {
|
|
"min": 5,
|
|
"max": 25,
|
|
"mean": 15.0,
|
|
"std": 6.0,
|
|
"count": 6,
|
|
},
|
|
"extra_key_1": {
|
|
"min": 0,
|
|
"max": 20,
|
|
"mean": 10.0,
|
|
"std": 5.0,
|
|
"count": 5,
|
|
},
|
|
}
|
|
|
|
# cast to numpy
|
|
for ep_stats in all_stats:
|
|
for fkey, stats in ep_stats.items():
|
|
for k in stats:
|
|
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
|
if fkey == OBS_IMAGE and k != "count":
|
|
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
|
else:
|
|
stats[k] = stats[k].reshape(1)
|
|
|
|
# cast to numpy
|
|
for fkey, stats in expected_agg_stats.items():
|
|
for k in stats:
|
|
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
|
if fkey == OBS_IMAGE and k != "count":
|
|
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
|
else:
|
|
stats[k] = stats[k].reshape(1)
|
|
|
|
results = aggregate_stats(all_stats)
|
|
|
|
for fkey in expected_agg_stats:
|
|
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
|
|
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
|
|
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
|
|
np.testing.assert_allclose(
|
|
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
|
|
)
|
|
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
|
|
|
|
|
def test_running_quantile_stats_initialization():
|
|
"""Test proper initialization of RunningQuantileStats."""
|
|
running_stats = RunningQuantileStats()
|
|
assert running_stats._count == 0
|
|
assert running_stats._mean is None
|
|
assert running_stats._num_quantile_bins == 5000
|
|
|
|
# Test custom bin size
|
|
running_stats_custom = RunningQuantileStats(num_quantile_bins=1000)
|
|
assert running_stats_custom._num_quantile_bins == 1000
|
|
|
|
|
|
def test_running_quantile_stats_single_batch_update():
|
|
"""Test updating with a single batch."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (100, 3))
|
|
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(data)
|
|
|
|
assert running_stats._count == 100
|
|
assert running_stats._mean.shape == (3,)
|
|
assert len(running_stats._histograms) == 3
|
|
assert len(running_stats._bin_edges) == 3
|
|
|
|
# Verify basic statistics are reasonable
|
|
np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10)
|
|
|
|
|
|
def test_running_quantile_stats_multiple_batch_updates():
|
|
"""Test updating with multiple batches."""
|
|
np.random.seed(42)
|
|
data1 = np.random.normal(0, 1, (100, 2))
|
|
data2 = np.random.normal(1, 1, (150, 2))
|
|
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(data1)
|
|
running_stats.update(data2)
|
|
|
|
assert running_stats._count == 250
|
|
|
|
# Verify running mean is correct
|
|
combined_data = np.vstack([data1, data2])
|
|
expected_mean = np.mean(combined_data, axis=0)
|
|
np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10)
|
|
|
|
|
|
def test_running_quantile_stats_get_statistics_basic():
|
|
"""Test getting basic statistics without quantiles."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (100, 2))
|
|
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(data)
|
|
|
|
stats = running_stats.get_statistics()
|
|
|
|
# Should have basic stats
|
|
expected_keys = {"min", "max", "mean", "std", "count"}
|
|
assert expected_keys.issubset(set(stats.keys()))
|
|
|
|
# Verify values
|
|
np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10)
|
|
np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6)
|
|
np.testing.assert_equal(stats["count"], np.array([100]))
|
|
|
|
|
|
def test_running_quantile_stats_get_statistics_with_quantiles():
|
|
"""Test getting statistics with quantiles."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (1000, 2))
|
|
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(data)
|
|
|
|
stats = running_stats.get_statistics()
|
|
|
|
# Should have basic stats plus quantiles
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert expected_keys.issubset(set(stats.keys()))
|
|
|
|
# Verify quantile values are reasonable
|
|
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES
|
|
|
|
for i, q in enumerate(DEFAULT_QUANTILES):
|
|
q_key = f"q{int(q * 100):02d}"
|
|
assert q_key in stats
|
|
assert stats[q_key].shape == (2,)
|
|
|
|
# Check that quantiles are in reasonable order
|
|
if i > 0:
|
|
prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}"
|
|
assert np.all(stats[prev_q_key] <= stats[q_key])
|
|
|
|
|
|
def test_running_quantile_stats_histogram_adjustment():
|
|
"""Test that histograms adjust when min/max change."""
|
|
running_stats = RunningQuantileStats()
|
|
|
|
# Initial data with small range
|
|
data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]])
|
|
running_stats.update(data1)
|
|
|
|
initial_edges_0 = running_stats._bin_edges[0].copy()
|
|
initial_edges_1 = running_stats._bin_edges[1].copy()
|
|
|
|
# Add data with much larger range
|
|
data2 = np.array([[10.0, -10.0], [11.0, -11.0]])
|
|
running_stats.update(data2)
|
|
|
|
# Bin edges should have changed
|
|
assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0])
|
|
assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1])
|
|
|
|
# New edges should cover the expanded range
|
|
# First dimension: min should still be ~0.0, max should be ~11.0
|
|
assert running_stats._bin_edges[0][0] <= 0.0
|
|
assert running_stats._bin_edges[0][-1] >= 11.0
|
|
|
|
# Second dimension: min should be ~-11.0, max should be ~1.2
|
|
assert running_stats._bin_edges[1][0] <= -11.0
|
|
assert running_stats._bin_edges[1][-1] >= 1.2
|
|
|
|
|
|
def test_running_quantile_stats_insufficient_data_error():
|
|
"""Test error when trying to get stats with insufficient data."""
|
|
running_stats = RunningQuantileStats()
|
|
|
|
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
|
|
running_stats.get_statistics()
|
|
|
|
# Single vector should also fail
|
|
running_stats.update(np.array([[1.0]]))
|
|
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
|
|
running_stats.get_statistics()
|
|
|
|
|
|
def test_running_quantile_stats_vector_length_consistency():
|
|
"""Test error when vector lengths don't match."""
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]]))
|
|
|
|
with pytest.raises(ValueError, match="The length of new vectors does not match"):
|
|
running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length
|
|
|
|
|
|
def test_running_quantile_stats_reshape_handling():
|
|
"""Test that various input shapes are handled correctly."""
|
|
running_stats = RunningQuantileStats()
|
|
|
|
# Test 3D input (e.g., images)
|
|
data_3d = np.random.normal(0, 1, (10, 32, 32))
|
|
running_stats.update(data_3d)
|
|
|
|
assert running_stats._count == 10 * 32
|
|
assert running_stats._mean.shape == (32,)
|
|
|
|
# Test 1D input
|
|
running_stats_1d = RunningQuantileStats()
|
|
data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
|
|
running_stats_1d.update(data_1d)
|
|
|
|
assert running_stats_1d._count == 5
|
|
assert running_stats_1d._mean.shape == (1,)
|
|
|
|
|
|
def test_get_feature_stats_quantiles_enabled_by_default():
|
|
"""Test that quantiles are computed by default."""
|
|
data = np.random.normal(0, 1, (100, 5))
|
|
stats = get_feature_stats(data, axis=0, keepdims=False)
|
|
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert set(stats.keys()) == expected_keys
|
|
|
|
|
|
def test_get_feature_stats_quantiles_with_vector_data():
|
|
"""Test quantile computation with vector data."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (100, 5))
|
|
|
|
stats = get_feature_stats(data, axis=0, keepdims=False)
|
|
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert set(stats.keys()) == expected_keys
|
|
|
|
# Verify shapes
|
|
assert stats["q01"].shape == (5,)
|
|
assert stats["q99"].shape == (5,)
|
|
|
|
# Verify quantiles are reasonable
|
|
assert np.all(stats["q01"] < stats["q99"])
|
|
|
|
|
|
def test_get_feature_stats_quantiles_with_image_data():
|
|
"""Test quantile computation with image data."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width
|
|
|
|
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
|
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert set(stats.keys()) == expected_keys
|
|
|
|
# Verify shapes for images (should be (1, channels, 1, 1))
|
|
assert stats["q01"].shape == (1, 3, 1, 1)
|
|
assert stats["q50"].shape == (1, 3, 1, 1)
|
|
assert stats["q99"].shape == (1, 3, 1, 1)
|
|
|
|
|
|
def test_get_feature_stats_fixed_quantiles():
|
|
"""Test that fixed quantiles are always computed."""
|
|
data = np.random.normal(0, 1, (200, 3))
|
|
|
|
stats = get_feature_stats(data, axis=0, keepdims=False)
|
|
|
|
expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"}
|
|
assert expected_quantile_keys.issubset(set(stats.keys()))
|
|
|
|
|
|
def test_get_feature_stats_unsupported_axis_error():
|
|
"""Test error for unsupported axis configuration."""
|
|
data = np.random.normal(0, 1, (10, 5))
|
|
|
|
with pytest.raises(ValueError, match="Unsupported axis configuration"):
|
|
get_feature_stats(
|
|
data,
|
|
axis=(1, 2), # Unsupported axis
|
|
keepdims=False,
|
|
)
|
|
|
|
|
|
def test_compute_episode_stats_backward_compatibility():
|
|
"""Test that existing functionality is preserved."""
|
|
episode_data = {
|
|
"action": np.random.normal(0, 1, (100, 7)),
|
|
"observation.state": np.random.normal(0, 1, (100, 10)),
|
|
}
|
|
features = {
|
|
"action": {"dtype": "float32", "shape": (7,)},
|
|
"observation.state": {"dtype": "float32", "shape": (10,)},
|
|
}
|
|
|
|
stats = compute_episode_stats(episode_data, features)
|
|
|
|
for key in ["action", "observation.state"]:
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert set(stats[key].keys()) == expected_keys
|
|
|
|
|
|
def test_compute_episode_stats_with_custom_quantiles():
|
|
"""Test quantile computation with custom quantile values."""
|
|
np.random.seed(42)
|
|
episode_data = {
|
|
"action": np.random.normal(0, 1, (100, 7)),
|
|
"observation.state": np.random.normal(2, 1, (100, 10)),
|
|
}
|
|
features = {
|
|
"action": {"dtype": "float32", "shape": (7,)},
|
|
"observation.state": {"dtype": "float32", "shape": (10,)},
|
|
}
|
|
|
|
stats = compute_episode_stats(episode_data, features)
|
|
|
|
# Should have quantiles
|
|
for key in ["action", "observation.state"]:
|
|
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
|
assert set(stats[key].keys()) == expected_keys
|
|
|
|
# Verify shapes
|
|
assert stats[key]["q01"].shape == (features[key]["shape"][0],)
|
|
assert stats[key]["q99"].shape == (features[key]["shape"][0],)
|
|
|
|
|
|
def test_compute_episode_stats_with_image_data():
|
|
"""Test quantile computation with image features."""
|
|
image_paths = [f"image_{i}.jpg" for i in range(50)]
|
|
depth_paths = [f"depth_{i}.tiff" for i in range(50)]
|
|
episode_data = {
|
|
"observation.image": image_paths,
|
|
"observation.images.depth": depth_paths,
|
|
"action": np.random.normal(0, 1, (50, 5)),
|
|
}
|
|
features = {
|
|
"observation.image": {"dtype": "image"},
|
|
"observation.images.depth": {"dtype": "image", "info": {"is_depth_map": True}},
|
|
"action": {"dtype": "float32", "shape": (5,)},
|
|
}
|
|
|
|
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
|
stats = compute_episode_stats(episode_data, features)
|
|
|
|
# RGB image quantiles should be normalized and per-channel.
|
|
for q in ("q01", "q50", "q99"):
|
|
assert stats["observation.image"][q].shape == (3, 1, 1)
|
|
|
|
# Depth quantiles are single-channel and kept in raw (un-normalized) units.
|
|
for q in ("q01", "q50", "q99"):
|
|
assert stats["observation.images.depth"][q].shape == (1, 1, 1)
|
|
# Depth max stays in raw units (not /255, not uint8-capped); RGB is normalized.
|
|
assert stats["observation.images.depth"]["max"].item() == 1023.0
|
|
np.testing.assert_allclose(stats["observation.image"]["max"], 1.0)
|
|
|
|
# Action quantiles should have correct shape
|
|
assert stats["action"]["q01"].shape == (5,)
|
|
assert stats["action"]["q50"].shape == (5,)
|
|
assert stats["action"]["q99"].shape == (5,)
|
|
|
|
|
|
def test_compute_episode_stats_string_features_skipped():
|
|
"""Test that string features are properly skipped."""
|
|
episode_data = {
|
|
"task": ["pick_apple"] * 100, # String feature
|
|
"action": np.random.normal(0, 1, (100, 5)),
|
|
}
|
|
features = {
|
|
"task": {"dtype": "string"},
|
|
"action": {"dtype": "float32", "shape": (5,)},
|
|
}
|
|
|
|
stats = compute_episode_stats(
|
|
episode_data,
|
|
features,
|
|
)
|
|
|
|
# String features should be skipped
|
|
assert "task" not in stats
|
|
assert "action" in stats
|
|
assert "q01" in stats["action"]
|
|
|
|
|
|
def test_aggregate_feature_stats_with_quantiles():
|
|
"""Test aggregating feature stats that include quantiles."""
|
|
stats_ft_list = [
|
|
{
|
|
"min": np.array([1.0]),
|
|
"max": np.array([10.0]),
|
|
"mean": np.array([5.0]),
|
|
"std": np.array([2.0]),
|
|
"count": np.array([100]),
|
|
"q01": np.array([1.5]),
|
|
"q99": np.array([9.5]),
|
|
},
|
|
{
|
|
"min": np.array([2.0]),
|
|
"max": np.array([12.0]),
|
|
"mean": np.array([6.0]),
|
|
"std": np.array([2.5]),
|
|
"count": np.array([150]),
|
|
"q01": np.array([2.5]),
|
|
"q99": np.array([11.5]),
|
|
},
|
|
]
|
|
|
|
result = aggregate_feature_stats(stats_ft_list)
|
|
|
|
# Should preserve quantiles
|
|
assert "q01" in result
|
|
assert "q99" in result
|
|
|
|
# Verify quantile aggregation (weighted average)
|
|
expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1
|
|
expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7
|
|
|
|
np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
|
|
np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)
|
|
|
|
|
|
def test_aggregate_stats_mixed_quantiles():
|
|
"""Test aggregating stats where some have quantiles and some don't."""
|
|
stats_with_quantiles = {
|
|
"feature1": {
|
|
"min": np.array([1.0]),
|
|
"max": np.array([10.0]),
|
|
"mean": np.array([5.0]),
|
|
"std": np.array([2.0]),
|
|
"count": np.array([100]),
|
|
"q01": np.array([1.5]),
|
|
"q99": np.array([9.5]),
|
|
}
|
|
}
|
|
|
|
stats_without_quantiles = {
|
|
"feature2": {
|
|
"min": np.array([0.0]),
|
|
"max": np.array([5.0]),
|
|
"mean": np.array([2.5]),
|
|
"std": np.array([1.5]),
|
|
"count": np.array([50]),
|
|
}
|
|
}
|
|
|
|
all_stats = [stats_with_quantiles, stats_without_quantiles]
|
|
result = aggregate_stats(all_stats)
|
|
|
|
# Feature1 should keep its quantiles
|
|
assert "q01" in result["feature1"]
|
|
assert "q99" in result["feature1"]
|
|
|
|
# Feature2 should not have quantiles
|
|
assert "q01" not in result["feature2"]
|
|
assert "q99" not in result["feature2"]
|
|
|
|
|
|
def test_assert_type_and_shape_with_quantiles():
|
|
"""Test validation works correctly with quantile keys."""
|
|
# Valid stats with quantiles
|
|
valid_stats = [
|
|
{
|
|
"observation.image": {
|
|
"min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1),
|
|
"max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1),
|
|
"mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1),
|
|
"std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1),
|
|
"count": np.array([100]),
|
|
"q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1),
|
|
"q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1),
|
|
}
|
|
}
|
|
]
|
|
|
|
# Should not raise error
|
|
_assert_type_and_shape(valid_stats)
|
|
|
|
# Invalid shape for quantile
|
|
invalid_stats = [
|
|
{
|
|
"observation.image": {
|
|
"count": np.array([100]),
|
|
"q01": np.array([0.1, 0.2]), # Wrong shape for image quantile
|
|
}
|
|
}
|
|
]
|
|
|
|
with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
|
|
_assert_type_and_shape(invalid_stats)
|
|
|
|
|
|
def test_quantile_integration_single_value_quantiles():
|
|
"""Test quantile computation with single repeated value."""
|
|
data = np.ones((100, 3)) # All ones
|
|
|
|
running_stats = RunningQuantileStats()
|
|
running_stats.update(data)
|
|
|
|
stats = running_stats.get_statistics()
|
|
|
|
# All quantiles should be approximately 1.0
|
|
np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
|
np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
|
np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
|
|
|
|
|
def test_quantile_integration_fixed_quantiles():
|
|
"""Test that fixed quantiles are computed."""
|
|
np.random.seed(42)
|
|
data = np.random.normal(0, 1, (1000, 2))
|
|
|
|
stats = get_feature_stats(data, axis=0, keepdims=False)
|
|
|
|
# Check all fixed quantiles are present
|
|
assert "q01" in stats
|
|
assert "q10" in stats
|
|
assert "q50" in stats
|
|
assert "q90" in stats
|
|
assert "q99" in stats
|
|
|
|
|
|
def test_quantile_integration_large_dataset_quantiles():
|
|
"""Test quantile computation efficiency with large datasets."""
|
|
np.random.seed(42)
|
|
large_data = np.random.normal(0, 1, (10000, 5))
|
|
|
|
running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed
|
|
running_stats.update(large_data)
|
|
|
|
stats = running_stats.get_statistics()
|
|
|
|
# Should complete without issues and produce reasonable results
|
|
assert stats["count"][0] == 10000
|
|
assert len(stats["q01"]) == 5
|
|
|
|
|
|
def test_fixed_quantiles_always_computed():
|
|
"""Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed."""
|
|
np.random.seed(42)
|
|
# Test with vector data
|
|
vector_data = np.random.normal(0, 1, (100, 5))
|
|
vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False)
|
|
|
|
# Check all fixed quantiles are present
|
|
expected_quantiles = ["q01", "q10", "q50", "q90", "q99"]
|
|
for q_key in expected_quantiles:
|
|
assert q_key in vector_stats
|
|
assert vector_stats[q_key].shape == (5,)
|
|
|
|
# Test with image data
|
|
image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
|
|
image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True)
|
|
|
|
# Check all fixed quantiles are present for images
|
|
for q_key in expected_quantiles:
|
|
assert q_key in image_stats
|
|
assert image_stats[q_key].shape == (1, 3, 1, 1)
|
|
|
|
# Test with episode data
|
|
episode_data = {
|
|
"action": np.random.normal(0, 1, (100, 7)),
|
|
"observation.state": np.random.normal(0, 1, (100, 10)),
|
|
}
|
|
features = {
|
|
"action": {"dtype": "float32", "shape": (7,)},
|
|
"observation.state": {"dtype": "float32", "shape": (10,)},
|
|
}
|
|
|
|
episode_stats = compute_episode_stats(episode_data, features)
|
|
|
|
# Check all fixed quantiles are present in episode stats
|
|
for key in ["action", "observation.state"]:
|
|
for q_key in expected_quantiles:
|
|
assert q_key in episode_stats[key]
|
|
assert episode_stats[key][q_key].shape == (features[key]["shape"][0],)
|