fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)

* feat(processor): enhance normalization handling and state management

- Added support for additional normalization modes including IDENTITY.
- Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys.
- Implemented preservation of explicitly provided normalization statistics during state loading.
- Updated training script to conditionally provide dataset statistics based on resume state.
- Expanded tests to verify the correct behavior of stats override preservation and loading.

* fix(train): remove redundant comment regarding state loading

- Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity.
This commit is contained in:
Adil Zouitine
2025-09-16 16:45:13 +02:00
committed by GitHub
parent 772da63a8e
commit a7d1179aab
4 changed files with 321 additions and 11 deletions
+233 -1
View File
@@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer():
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_stats_override_preservation_in_load_state_dict():
"""
Test that explicitly provided stats are preserved during load_state_dict.
This tests the fix for the bug where stats provided via overrides were
being overwritten when load_state_dict was called.
"""
# Create original stats
original_stats = {
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
}
# Create override stats (what user wants to use)
override_stats = {
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
}
features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
# Create a normalizer with original stats and save its state
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
saved_state_dict = original_normalizer.state_dict()
# Create a new normalizer with override stats (simulating from_pretrained with overrides)
override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats)
# Verify that the override stats are initially set correctly
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
for key in override_stats:
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
for stat_name in override_stats[key]:
np.testing.assert_array_equal(
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
)
assert override_normalizer._stats_explicitly_provided is True
# This is the critical test: load_state_dict should NOT overwrite the override stats
override_normalizer.load_state_dict(saved_state_dict)
# After loading state_dict, stats should still be the override stats, not the original stats
# Check that loaded stats match override stats
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
for key in override_stats:
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
for stat_name in override_stats[key]:
np.testing.assert_array_equal(
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
)
# Compare individual arrays to avoid numpy array comparison ambiguity
for key in override_stats:
for stat_name in override_stats[key]:
assert not np.array_equal(
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
), f"Stats for {key}.{stat_name} should not match original stats"
# Verify that _tensor_stats are also correctly set to match the override stats
expected_tensor_stats = to_tensor(override_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
torch.testing.assert_close(
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
)
def test_stats_without_override_loads_normally():
"""
Test that when stats are not explicitly provided (normal case),
load_state_dict works as before.
"""
original_stats = {
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
}
features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
# Create a normalizer with original stats and save its state
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
saved_state_dict = original_normalizer.state_dict()
# Create a new normalizer without stats (simulating normal from_pretrained)
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
# Verify that stats are not explicitly provided
assert new_normalizer._stats_explicitly_provided is False
# Load state dict - this should work normally and load the saved stats
new_normalizer.load_state_dict(saved_state_dict)
# Stats should now match the original stats (normal behavior)
# Check that all keys and values match
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
for key in original_stats:
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
for stat_name in original_stats[key]:
np.testing.assert_allclose(
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
)
def test_stats_explicit_provided_flag_detection():
"""Test that the _stats_explicitly_provided flag is set correctly in different scenarios."""
features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
# Test 1: Explicitly provided stats (non-empty dict)
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
assert normalizer1._stats_explicitly_provided is True
# Test 2: Empty stats dict
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
assert normalizer2._stats_explicitly_provided is False
# Test 3: None stats
normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None)
assert normalizer3._stats_explicitly_provided is False
# Test 4: Stats not provided (defaults to None)
normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map)
assert normalizer4._stats_explicitly_provided is False
def test_pipeline_from_pretrained_with_stats_overrides():
"""
Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides.
This is an integration test that verifies the fix works in the real scenario
where users provide stat overrides when loading a pipeline.
"""
import tempfile
# Create test data
features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
original_stats = {
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
}
override_stats = {
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
}
# Create and save a pipeline with the original stats
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
identity = IdentityProcessorStep()
original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline")
with tempfile.TemporaryDirectory() as temp_dir:
# Save the pipeline
original_pipeline.save_pretrained(temp_dir)
# Load the pipeline with stat overrides
overrides = {"normalizer_processor": {"stats": override_stats}}
loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides)
# The critical test: the loaded pipeline should use override stats, not original stats
loaded_normalizer = loaded_pipeline.steps[0]
assert isinstance(loaded_normalizer, NormalizerProcessorStep)
# Check that loaded stats match override stats
assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys())
for key in override_stats:
assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys())
for stat_name in override_stats[key]:
np.testing.assert_array_equal(
loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name]
)
# Verify stats don't match original stats
for key in override_stats:
for stat_name in override_stats[key]:
assert not np.array_equal(
loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name]
), f"Stats for {key}.{stat_name} should not match original stats"
# Test that the override stats are actually used in processing
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Process with override pipeline
override_result = loaded_pipeline(transition)
# Create a reference pipeline with override stats for comparison
reference_normalizer = NormalizerProcessorStep(
features=features, norm_map=norm_map, stats=override_stats
)
reference_pipeline = DataProcessorPipeline(
steps=[reference_normalizer, identity],
to_transition=identity_transition,
to_output=identity_transition,
)
_ = reference_pipeline(transition)
# The critical part was verified above: loaded_normalizer.stats == override_stats
# This confirms that override stats are preserved during load_state_dict.
# Let's just verify the pipeline processes data successfully.
assert "action" in override_result
assert isinstance(override_result["action"], torch.Tensor)
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
from lerobot.processor import DeviceProcessorStep