mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests
This commit is contained in:
@@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict():
|
|||||||
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||||
|
|
||||||
# Verify that _tensor_stats are also correctly set to match the override stats
|
# Verify that _tensor_stats values match the override stats
|
||||||
|
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats
|
||||||
expected_tensor_stats = to_tensor(override_stats)
|
expected_tensor_stats = to_tensor(override_stats)
|
||||||
for key in expected_tensor_stats:
|
for key in expected_tensor_stats:
|
||||||
for stat_name in expected_tensor_stats[key]:
|
for stat_name in expected_tensor_stats[key]:
|
||||||
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
override_normalizer._tensor_stats[key][stat_name].squeeze(),
|
||||||
|
expected_tensor_stats[key][stat_name].squeeze(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally():
|
|||||||
# Stats should now match the original stats (normal behavior)
|
# Stats should now match the original stats (normal behavior)
|
||||||
# Check that all keys and values match
|
# Check that all keys and values match
|
||||||
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
||||||
|
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats,
|
||||||
|
# so we squeeze before comparing values.
|
||||||
for key in original_stats:
|
for key in original_stats:
|
||||||
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
||||||
for stat_name in original_stats[key]:
|
for stat_name in original_stats[key]:
|
||||||
np.testing.assert_allclose(
|
actual = new_normalizer.stats[key][stat_name]
|
||||||
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
|
expected = original_stats[key][stat_name]
|
||||||
)
|
if hasattr(actual, "squeeze"):
|
||||||
|
actual = actual.squeeze()
|
||||||
|
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
def test_stats_explicit_provided_flag_detection():
|
def test_stats_explicit_provided_flag_detection():
|
||||||
@@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict():
|
|||||||
assert ACTION in new_normalizer.stats
|
assert ACTION in new_normalizer.stats
|
||||||
|
|
||||||
# Check that values are correct (converted back from tensors)
|
# Check that values are correct (converted back from tensors)
|
||||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
|
# Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing
|
||||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
|
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5])
|
||||||
|
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2])
|
||||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
|
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
|
||||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
|
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
|
||||||
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
|
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
|
||||||
|
|||||||
Reference in New Issue
Block a user