fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)

* feat(processor): enhance normalization handling and state management

- Added support for additional normalization modes including IDENTITY.
- Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys.
- Implemented preservation of explicitly provided normalization statistics during state loading.
- Updated training script to conditionally provide dataset statistics based on resume state.
- Expanded tests to verify the correct behavior of stats override preservation and loading.

* fix(train): remove redundant comment regarding state loading

- Removed a comment noting that the preprocessor and postprocessor state is already loaded when resuming training, since it added no useful information.
This commit is contained in:
Adil Zouitine
2025-09-16 16:45:13 +02:00
committed by GitHub
parent 772da63a8e
commit a7d1179aab
4 changed files with 321 additions and 11 deletions
@@ -88,6 +88,10 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
"unnormalize.", # Must come after unnormalize_* patterns
"input_normalizer.",
"output_normalizer.",
"normalize_inputs.",
"unnormalize_outputs.",
"normalize_targets.",
"unnormalize_targets.",
]
# Process each key in state_dict
@@ -168,6 +172,8 @@ def detect_features_and_norm_modes(
mode = NormalizationMode.MEAN_STD
elif mode_str == "MIN_MAX":
mode = NormalizationMode.MIN_MAX
elif mode_str == "IDENTITY":
mode = NormalizationMode.IDENTITY
else:
print(
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
@@ -276,6 +282,26 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
return new_state_dict
def clean_state_dict(
    state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod"
) -> dict[str, torch.Tensor]:
    """
    Strip a substring from every key of a state dict.

    Useful for removing naming artifacts such as the '._orig_mod' infix
    (presumably left by wrapped/compiled modules — confirm against the
    checkpoints being migrated).

    Args:
        state_dict (dict): State dict whose keys may contain `remove_str`.
        remove_str (str): Substring to strip from each key.

    Returns:
        dict: A new state dict with the cleaned keys; values are untouched.
    """
    return {key.replace(remove_str, ""): value for key, value in state_dict.items()}
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
@@ -405,6 +431,7 @@ def main():
# Remove normalization layers from state_dict
print("Removing normalization layers from model...")
new_state_dict = remove_normalization_layers(state_dict)
new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod")
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
if removed_keys:
+53 -2
View File
@@ -43,6 +43,30 @@ class _NormalizationMixin:
be inherited by concrete `ProcessorStep` implementations and should not be used
directly.
**Stats Override Preservation:**
When stats are explicitly provided during construction (e.g., via overrides in
`DataProcessorPipeline.from_pretrained()`), they are preserved even when
`load_state_dict()` is called. This allows users to override normalization
statistics from saved models while keeping the rest of the model state intact.
Examples:
```python
# Common use case: Override with dataset stats
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("my_dataset")
pipeline = DataProcessorPipeline.from_pretrained(
"model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}}
)
# dataset.meta.stats will be used, not the stats from the saved model
# Custom stats override
custom_stats = {"action": {"mean": [0.0], "std": [1.0]}}
pipeline = DataProcessorPipeline.from_pretrained(
"model_path", overrides={"normalizer_processor": {"stats": custom_stats}}
)
```
Attributes:
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
the data structure to be processed.
@@ -57,6 +81,8 @@ class _NormalizationMixin:
normalization to specific observation features.
_tensor_stats: An internal dictionary holding the normalization statistics as
PyTorch tensors.
_stats_explicitly_provided: Internal flag tracking whether stats were explicitly
provided during construction (used for override preservation).
"""
features: dict[str, PolicyFeature]
@@ -68,6 +94,7 @@ class _NormalizationMixin:
normalize_observation_keys: set[str] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
_stats_explicitly_provided: bool = field(default=False, init=False, repr=False)
def __post_init__(self):
"""
@@ -78,6 +105,8 @@ class _NormalizationMixin:
lists) and converts the provided `stats` dictionary into a dictionary of
tensors (`_tensor_stats`) on the specified device.
"""
# Track if stats were explicitly provided (not None and not empty)
self._stats_explicitly_provided = self.stats is not None and bool(self.stats)
# Robust JSON deserialization handling (guard empty maps).
if self.features:
first_val = next(iter(self.features.values()))
@@ -145,10 +174,33 @@ class _NormalizationMixin:
The loaded tensors are moved to the processor's configured device.
**Stats Override Preservation:**
If stats were explicitly provided during construction (e.g., via overrides in
`DataProcessorPipeline.from_pretrained()`), they are preserved and the state
dictionary is ignored. This allows users to override normalization statistics
while still loading the rest of the model state.
This behavior is crucial for scenarios where users want to adapt a pretrained
model to a new dataset with different statistics without retraining the entire
model.
Args:
state: A flat state dictionary with keys in the format
`'feature_name.stat_name'`.
Note:
When stats are preserved due to explicit provision, only the tensor
representation is updated to ensure consistency with the current device
and dtype settings.
"""
# If stats were explicitly provided during construction, preserve them
if self._stats_explicitly_provided and self.stats is not None:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
return
# Normal behavior: load stats from state_dict
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
@@ -159,7 +211,6 @@ class _NormalizationMixin:
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
self.stats = {}
for key, tensor_dict in self._tensor_stats.items():
self.stats[key] = {}
@@ -446,5 +497,5 @@ def hotswap_stats(
if isinstance(step, _NormalizationMixin):
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
return rp
+8 -8
View File
@@ -26,7 +26,6 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
@@ -177,8 +176,15 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
if not (cfg.resume and cfg.policy.pretrained_path):
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
)
logging.info("Creating optimizer and scheduler")
@@ -189,12 +195,6 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
)
postprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
+233 -1
View File
@@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer():
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_stats_override_preservation_in_load_state_dict():
    """
    Verify that stats supplied explicitly at construction survive load_state_dict.

    Regression test for the bug where stats provided via overrides were
    clobbered by a subsequent load_state_dict call.
    """
    # Stats baked into the normalizer whose state we will save.
    original_stats = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    # Stats the user explicitly wants to use instead.
    override_stats = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
        "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
    }
    features = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    def assert_stats_match(actual, expected):
        # Compare key-by-key to avoid numpy's ambiguous truthiness on arrays.
        assert set(actual.keys()) == set(expected.keys())
        for key, stat_dict in expected.items():
            assert set(actual[key].keys()) == set(stat_dict.keys())
            for stat_name, value in stat_dict.items():
                np.testing.assert_array_equal(actual[key][stat_name], value)

    # Save the state of a normalizer built with the original stats.
    saved_state_dict = NormalizerProcessorStep(
        features=features, norm_map=norm_map, stats=original_stats
    ).state_dict()

    # Build a normalizer with override stats (as from_pretrained with overrides would).
    override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats)

    # Sanity check: overrides are in place and flagged as explicitly provided.
    assert_stats_match(override_normalizer.stats, override_stats)
    assert override_normalizer._stats_explicitly_provided is True

    # Critical check: load_state_dict must NOT replace the override stats.
    override_normalizer.load_state_dict(saved_state_dict)
    assert_stats_match(override_normalizer.stats, override_stats)

    # And the stats must differ from what was stored in the state dict.
    for key, stat_dict in override_stats.items():
        for stat_name in stat_dict:
            assert not np.array_equal(
                override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
            ), f"Stats for {key}.{stat_name} should not match original stats"

    # _tensor_stats must mirror the override stats as tensors.
    expected_tensor_stats = to_tensor(override_stats)
    for key, stat_dict in expected_tensor_stats.items():
        for stat_name, value in stat_dict.items():
            if isinstance(value, torch.Tensor):
                torch.testing.assert_close(override_normalizer._tensor_stats[key][stat_name], value)
def test_stats_without_override_loads_normally():
    """
    When stats are not explicitly provided, load_state_dict must restore the
    stats stored in the state dict (the pre-existing, normal behavior).
    """
    original_stats = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    features = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    # Save the state of a normalizer built with the original stats.
    saved_state_dict = NormalizerProcessorStep(
        features=features, norm_map=norm_map, stats=original_stats
    ).state_dict()

    # An empty stats dict counts as "not explicitly provided".
    fresh_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
    assert fresh_normalizer._stats_explicitly_provided is False

    # Loading must populate the stats from the saved state.
    fresh_normalizer.load_state_dict(saved_state_dict)
    assert set(fresh_normalizer.stats.keys()) == set(original_stats.keys())
    for key, stat_dict in original_stats.items():
        assert set(fresh_normalizer.stats[key].keys()) == set(stat_dict.keys())
        for stat_name, value in stat_dict.items():
            np.testing.assert_allclose(
                fresh_normalizer.stats[key][stat_name], value, rtol=1e-6, atol=1e-6
            )
def test_stats_explicit_provided_flag_detection():
    """_stats_explicitly_provided must be True only when a non-empty stats dict is passed."""
    features = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}

    # Case 1: a non-empty stats dict sets the flag.
    stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
    with_stats = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
    assert with_stats._stats_explicitly_provided is True

    # Cases 2-4: empty dict, explicit None, and omitted stats all leave it cleared.
    empty_stats = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
    assert empty_stats._stats_explicitly_provided is False

    none_stats = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None)
    assert none_stats._stats_explicitly_provided is False

    default_stats = NormalizerProcessorStep(features=features, norm_map=norm_map)
    assert default_stats._stats_explicitly_provided is False
def test_pipeline_from_pretrained_with_stats_overrides():
    """
    Integration test: DataProcessorPipeline.from_pretrained with stats overrides.

    Saves a pipeline containing a normalizer with one set of stats, reloads it
    with an override for "normalizer_processor", and verifies that the override
    stats — not the saved ones — end up on the loaded normalizer, and that the
    loaded pipeline still processes a transition end to end.

    Fix over the previous version: a "reference" pipeline was built and run but
    its result was discarded and never compared to anything; that dead code is
    removed.
    """
    import tempfile

    features = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }
    # Stats saved with the pipeline.
    original_stats = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    # Stats supplied via overrides at load time.
    override_stats = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
        "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
    }

    # Create and save a pipeline with the original stats.
    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
    identity = IdentityProcessorStep()
    original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline")

    with tempfile.TemporaryDirectory() as temp_dir:
        original_pipeline.save_pretrained(temp_dir)

        # Reload with the normalizer's stats overridden.
        overrides = {"normalizer_processor": {"stats": override_stats}}
        loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides)

        loaded_normalizer = loaded_pipeline.steps[0]
        assert isinstance(loaded_normalizer, NormalizerProcessorStep)

        # The loaded normalizer must carry the override stats ...
        assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys())
        for key in override_stats:
            assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys())
            for stat_name in override_stats[key]:
                np.testing.assert_array_equal(
                    loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name]
                )
        # ... and not the stats that were saved with the pipeline.
        for key in override_stats:
            for stat_name in override_stats[key]:
                assert not np.array_equal(
                    loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name]
                ), f"Stats for {key}.{stat_name} should not match original stats"

        # The loaded pipeline must still process a transition successfully.
        observation = {
            "observation.image": torch.tensor([0.7, 0.5, 0.3]),
        }
        action = torch.tensor([1.0, -0.5])
        transition = create_transition(observation=observation, action=action)
        override_result = loaded_pipeline(transition)
        assert "action" in override_result
        assert isinstance(override_result["action"], torch.Tensor)
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
from lerobot.processor import DeviceProcessorStep