mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)
* feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity.
This commit is contained in:
@@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer():
|
||||
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_stats_override_preservation_in_load_state_dict():
|
||||
"""
|
||||
Test that explicitly provided stats are preserved during load_state_dict.
|
||||
|
||||
This tests the fix for the bug where stats provided via overrides were
|
||||
being overwritten when load_state_dict was called.
|
||||
"""
|
||||
# Create original stats
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# Create override stats (what user wants to use)
|
||||
override_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create a normalizer with original stats and save its state
|
||||
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
saved_state_dict = original_normalizer.state_dict()
|
||||
|
||||
# Create a new normalizer with override stats (simulating from_pretrained with overrides)
|
||||
override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats)
|
||||
|
||||
# Verify that the override stats are initially set correctly
|
||||
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
assert override_normalizer._stats_explicitly_provided is True
|
||||
|
||||
# This is the critical test: load_state_dict should NOT overwrite the override stats
|
||||
override_normalizer.load_state_dict(saved_state_dict)
|
||||
|
||||
# After loading state_dict, stats should still be the override stats, not the original stats
|
||||
# Check that loaded stats match override stats
|
||||
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
# Compare individual arrays to avoid numpy array comparison ambiguity
|
||||
for key in override_stats:
|
||||
for stat_name in override_stats[key]:
|
||||
assert not np.array_equal(
|
||||
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||
|
||||
# Verify that _tensor_stats are also correctly set to match the override stats
|
||||
expected_tensor_stats = to_tensor(override_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
||||
torch.testing.assert_close(
|
||||
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
)
|
||||
|
||||
|
||||
def test_stats_without_override_loads_normally():
|
||||
"""
|
||||
Test that when stats are not explicitly provided (normal case),
|
||||
load_state_dict works as before.
|
||||
"""
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create a normalizer with original stats and save its state
|
||||
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
saved_state_dict = original_normalizer.state_dict()
|
||||
|
||||
# Create a new normalizer without stats (simulating normal from_pretrained)
|
||||
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
|
||||
# Verify that stats are not explicitly provided
|
||||
assert new_normalizer._stats_explicitly_provided is False
|
||||
|
||||
# Load state dict - this should work normally and load the saved stats
|
||||
new_normalizer.load_state_dict(saved_state_dict)
|
||||
|
||||
# Stats should now match the original stats (normal behavior)
|
||||
# Check that all keys and values match
|
||||
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
||||
for key in original_stats:
|
||||
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
||||
for stat_name in original_stats[key]:
|
||||
np.testing.assert_allclose(
|
||||
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
|
||||
)
|
||||
|
||||
|
||||
def test_stats_explicit_provided_flag_detection():
|
||||
"""Test that the _stats_explicitly_provided flag is set correctly in different scenarios."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
# Test 1: Explicitly provided stats (non-empty dict)
|
||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
assert normalizer1._stats_explicitly_provided is True
|
||||
|
||||
# Test 2: Empty stats dict
|
||||
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
assert normalizer2._stats_explicitly_provided is False
|
||||
|
||||
# Test 3: None stats
|
||||
normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None)
|
||||
assert normalizer3._stats_explicitly_provided is False
|
||||
|
||||
# Test 4: Stats not provided (defaults to None)
|
||||
normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map)
|
||||
assert normalizer4._stats_explicitly_provided is False
|
||||
|
||||
|
||||
def test_pipeline_from_pretrained_with_stats_overrides():
|
||||
"""
|
||||
Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides.
|
||||
|
||||
This is an integration test that verifies the fix works in the real scenario
|
||||
where users provide stat overrides when loading a pipeline.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
# Create test data
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
override_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
# Create and save a pipeline with the original stats
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
identity = IdentityProcessorStep()
|
||||
original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save the pipeline
|
||||
original_pipeline.save_pretrained(temp_dir)
|
||||
|
||||
# Load the pipeline with stat overrides
|
||||
overrides = {"normalizer_processor": {"stats": override_stats}}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides)
|
||||
|
||||
# The critical test: the loaded pipeline should use override stats, not original stats
|
||||
loaded_normalizer = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_normalizer, NormalizerProcessorStep)
|
||||
|
||||
# Check that loaded stats match override stats
|
||||
assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
|
||||
# Verify stats don't match original stats
|
||||
for key in override_stats:
|
||||
for stat_name in override_stats[key]:
|
||||
assert not np.array_equal(
|
||||
loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||
|
||||
# Test that the override stats are actually used in processing
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process with override pipeline
|
||||
override_result = loaded_pipeline(transition)
|
||||
|
||||
# Create a reference pipeline with override stats for comparison
|
||||
reference_normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=override_stats
|
||||
)
|
||||
reference_pipeline = DataProcessorPipeline(
|
||||
steps=[reference_normalizer, identity],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
_ = reference_pipeline(transition)
|
||||
|
||||
# The critical part was verified above: loaded_normalizer.stats == override_stats
|
||||
# This confirms that override stats are preserved during load_state_dict.
|
||||
# Let's just verify the pipeline processes data successfully.
|
||||
assert "action" in override_result
|
||||
assert isinstance(override_result["action"], torch.Tensor)
|
||||
|
||||
|
||||
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
||||
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
|
||||
from lerobot.processor import DeviceProcessorStep
|
||||
|
||||
Reference in New Issue
Block a user