diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 438ac7fba..09765c130 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -59,6 +59,8 @@ class RunningQuantileStats: batch: An array where all dimensions except the last are batch dimensions. """ batch = batch.reshape(-1, batch.shape[-1]) + # Promote integer and low-precision inputs before computing squared statistics. + batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False) num_elements, vector_length = batch.shape if self._count == 0: diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 70ba42378..0f5abfb95 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -83,6 +83,29 @@ def test_get_feature_stats_images(): assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape +def test_get_feature_stats_uint8_images_preserves_std(): + data = np.array( + [ + [ + [[0, 64], [128, 255]], + [[255, 128], [64, 0]], + [[32, 96], [160, 224]], + ], + [ + [[16, 80], [144, 240]], + [[240, 144], [80, 16]], + [[48, 112], [176, 208]], + ], + ], + dtype=np.uint8, + ) + + stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) + + expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1) + np.testing.assert_allclose(stats["std"], expected_std) + + def test_get_feature_stats_axis_0_keepdims(sample_array): expected = { "min": np.array([[1, 2, 3]]),