diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index cd5c75005..e046adb0d 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict(): override_normalizer.stats[key][stat_name], original_stats[key][stat_name] ), f"Stats for {key}.{stat_name} should not match original stats" - # Verify that _tensor_stats are also correctly set to match the override stats + # Verify that _tensor_stats values match the override stats + # Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats expected_tensor_stats = to_tensor(override_stats) for key in expected_tensor_stats: for stat_name in expected_tensor_stats[key]: if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor): torch.testing.assert_close( - override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + override_normalizer._tensor_stats[key][stat_name].squeeze(), + expected_tensor_stats[key][stat_name].squeeze(), ) @@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally(): # Stats should now match the original stats (normal behavior) # Check that all keys and values match assert set(new_normalizer.stats.keys()) == set(original_stats.keys()) + # Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats, + # so we squeeze before comparing values. for key in original_stats: assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys()) for stat_name in original_stats[key]: - np.testing.assert_allclose( - new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6 - ) + actual = new_normalizer.stats[key][stat_name] + expected = original_stats[key][stat_name] + if hasattr(actual, "squeeze"): + actual = actual.squeeze() + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) def test_stats_explicit_provided_flag_detection(): @@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict(): assert ACTION in new_normalizer.stats # Check that values are correct (converted back from tensors) - np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) - np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) + # Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])