mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
fix(datasets): avoid uint8 overflow in image stats (#3697)
* fix(datasets): avoid uint8 overflow in image stats * fix(datasets): promote stats batches dynamically
This commit is contained in:
@@ -59,6 +59,8 @@ class RunningQuantileStats:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
# Promote integer and low-precision inputs before computing squared statistics.
|
||||
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
|
||||
@@ -83,6 +83,29 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_uint8_images_preserves_std():
|
||||
data = np.array(
|
||||
[
|
||||
[
|
||||
[[0, 64], [128, 255]],
|
||||
[[255, 128], [64, 0]],
|
||||
[[32, 96], [160, 224]],
|
||||
],
|
||||
[
|
||||
[[16, 80], [144, 240]],
|
||||
[[240, 144], [80, 16]],
|
||||
[[48, 112], [176, 208]],
|
||||
],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
|
||||
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
|
||||
np.testing.assert_allclose(stats["std"], expected_std)
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
|
||||
Reference in New Issue
Block a user