fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

This commit is contained in:
Khalil Meftah
2026-04-15 16:12:08 +02:00
parent 7a1c9e74c3
commit 23bece96a4
+14 -7
View File
@@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict():
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
), f"Stats for {key}.{stat_name} should not match original stats"
# Verify that _tensor_stats are also correctly set to match the override stats
# Verify that _tensor_stats values match the override stats
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats
expected_tensor_stats = to_tensor(override_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
torch.testing.assert_close(
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
override_normalizer._tensor_stats[key][stat_name].squeeze(),
expected_tensor_stats[key][stat_name].squeeze(),
)
@@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally():
# Stats should now match the original stats (normal behavior)
# Check that all keys and values match
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats,
# so we squeeze before comparing values.
for key in original_stats:
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
for stat_name in original_stats[key]:
np.testing.assert_allclose(
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
)
actual = new_normalizer.stats[key][stat_name]
expected = original_stats[key][stat_name]
if hasattr(actual, "squeeze"):
actual = actual.squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
def test_stats_explicit_provided_flag_detection():
@@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict():
assert ACTION in new_normalizer.stats
# Check that values are correct (converted back from tensors)
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
# Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])