From a7d1179aab8f3c7aa9d6f3eb76409bbbc676642a Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Tue, 16 Sep 2025 16:45:13 +0200 Subject: [PATCH] fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958) * feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity. --- .../processor/migrate_policy_normalization.py | 27 ++ src/lerobot/processor/normalize_processor.py | 55 +++- src/lerobot/scripts/train.py | 16 +- tests/processor/test_normalize_processor.py | 234 +++++++++++++++++- 4 files changed, 321 insertions(+), 11 deletions(-) diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index c4e25a515..131f799d6 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -88,6 +88,10 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str "unnormalize.", # Must come after unnormalize_* patterns "input_normalizer.", "output_normalizer.", + "normalize_inputs.", + "unnormalize_outputs.", + "normalize_targets.", + "unnormalize_targets.", ] # Process each key in state_dict @@ -168,6 +172,8 @@ def detect_features_and_norm_modes( mode = NormalizationMode.MEAN_STD elif mode_str == "MIN_MAX":
mode = NormalizationMode.MIN_MAX + elif mode_str == "IDENTITY": + mode = NormalizationMode.IDENTITY else: print( f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'" @@ -276,6 +282,26 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str return new_state_dict +def clean_state_dict( + state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod" +) -> dict[str, torch.Tensor]: + """ + Remove a substring (e.g. '._orig_mod') from all keys in a state dict. + + Args: + state_dict (dict): The original state dict. + remove_str (str): The substring to remove from the keys. + + Returns: + dict: A new state dict with cleaned keys. + """ + new_state_dict = {} + for k, v in state_dict.items(): + new_k = k.replace(remove_str, "") + new_state_dict[new_k] = v + return new_state_dict + + def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: """ Converts a feature dictionary from the old config format to the new `PolicyFeature` format. @@ -405,6 +431,7 @@ def main(): # Remove normalization layers from state_dict print("Removing normalization layers from model...") new_state_dict = remove_normalization_layers(state_dict) + new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod") removed_keys = set(state_dict.keys()) - set(new_state_dict.keys()) if removed_keys: diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 698bb3c92..bece54f0b 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -43,6 +43,30 @@ class _NormalizationMixin: be inherited by concrete `ProcessorStep` implementations and should not be used directly. 
+ **Stats Override Preservation:** + When stats are explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved even when + `load_state_dict()` is called. This allows users to override normalization + statistics from saved models while keeping the rest of the model state intact. + + Examples: + ```python + # Common use case: Override with dataset stats + from lerobot.datasets import LeRobotDataset + + dataset = LeRobotDataset("my_dataset") + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}} + ) + # dataset.meta.stats will be used, not the stats from the saved model + + # Custom stats override + custom_stats = {"action": {"mean": [0.0], "std": [1.0]}} + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": custom_stats}} + ) + ``` + Attributes: features: A dictionary mapping feature names to `PolicyFeature` objects, defining the data structure to be processed. @@ -57,6 +81,8 @@ class _NormalizationMixin: normalization to specific observation features. _tensor_stats: An internal dictionary holding the normalization statistics as PyTorch tensors. + _stats_explicitly_provided: Internal flag tracking whether stats were explicitly + provided during construction (used for override preservation). """ features: dict[str, PolicyFeature] @@ -68,6 +94,7 @@ class _NormalizationMixin: normalize_observation_keys: set[str] | None = None _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + _stats_explicitly_provided: bool = field(default=False, init=False, repr=False) def __post_init__(self): """ @@ -78,6 +105,8 @@ class _NormalizationMixin: lists) and converts the provided `stats` dictionary into a dictionary of tensors (`_tensor_stats`) on the specified device. 
""" + # Track if stats were explicitly provided (not None and not empty) + self._stats_explicitly_provided = self.stats is not None and bool(self.stats) # Robust JSON deserialization handling (guard empty maps). if self.features: first_val = next(iter(self.features.values())) @@ -145,10 +174,33 @@ class _NormalizationMixin: The loaded tensors are moved to the processor's configured device. + **Stats Override Preservation:** + If stats were explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved and the state + dictionary is ignored. This allows users to override normalization statistics + while still loading the rest of the model state. + + This behavior is crucial for scenarios where users want to adapt a pretrained + model to a new dataset with different statistics without retraining the entire + model. + Args: state: A flat state dictionary with keys in the format `'feature_name.stat_name'`. + + Note: + When stats are preserved due to explicit provision, only the tensor + representation is updated to ensure consistency with the current device + and dtype settings. 
""" + # If stats were explicitly provided during construction, preserve them + if self._stats_explicitly_provided and self.stats is not None: + # Don't load from state_dict, keep the explicitly provided stats + # But ensure _tensor_stats is properly initialized + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + return + + # Normal behavior: load stats from state_dict self._tensor_stats.clear() for flat_key, tensor in state.items(): key, stat_name = flat_key.rsplit(".", 1) @@ -159,7 +211,6 @@ class _NormalizationMixin: # Reconstruct the original stats dict from tensor stats for compatibility with to() method # and other functions that rely on self.stats - self.stats = {} for key, tensor_dict in self._tensor_stats.items(): self.stats[key] = {} @@ -446,5 +497,5 @@ def hotswap_stats( if isinstance(step, _NormalizationMixin): step.stats = stats # Re-initialize tensor_stats on the correct device. - step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) + step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment] return rp diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 316069eb9..485fc9275 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -26,7 +26,6 @@ from torch.optim import Optimizer from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle @@ -177,8 +176,15 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) + + # Create processors - only provide dataset_stats if not resuming from saved processors + processor_kwargs = {} + if not (cfg.resume and cfg.policy.pretrained_path): + # Only 
provide dataset_stats when not resuming from saved processor state + processor_kwargs["dataset_stats"] = dataset.meta.stats + preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats + policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs ) logging.info("Creating optimizer and scheduler") @@ -189,12 +195,6 @@ def train(cfg: TrainPipelineConfig): if cfg.resume: step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - preprocessor.from_pretrained( - cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" - ) - postprocessor.from_pretrained( - cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" - ) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 0a28320ae..fb19aa073 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_stats_override_preservation_in_load_state_dict(): + """ + Test that explicitly provided stats are preserved during load_state_dict. + + This tests the fix for the bug where stats provided via overrides were + being overwritten when load_state_dict was called. 
+ """ + # Create original stats + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # Create override stats (what user wants to use) + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer with override stats (simulating from_pretrained with overrides) + override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats) + + # Verify that the override stats are initially set correctly + assert set(override_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + assert override_normalizer._stats_explicitly_provided is True + + # This is the critical test: load_state_dict should NOT overwrite the override stats + override_normalizer.load_state_dict(saved_state_dict) + + # After loading state_dict, stats should still be the override stats, not the original stats + # Check that loaded stats match override stats + assert set(override_normalizer.stats.keys()) 
== set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + # Compare individual arrays to avoid numpy array comparison ambiguity + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + override_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Verify that _tensor_stats are also correctly set to match the override stats + expected_tensor_stats = to_tensor(override_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor): + torch.testing.assert_close( + override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_stats_without_override_loads_normally(): + """ + Test that when stats are not explicitly provided (normal case), + load_state_dict works as before. 
+ """ + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer without stats (simulating normal from_pretrained) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + + # Verify that stats are not explicitly provided + assert new_normalizer._stats_explicitly_provided is False + + # Load state dict - this should work normally and load the saved stats + new_normalizer.load_state_dict(saved_state_dict) + + # Stats should now match the original stats (normal behavior) + # Check that all keys and values match + assert set(new_normalizer.stats.keys()) == set(original_stats.keys()) + for key in original_stats: + assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys()) + for stat_name in original_stats[key]: + np.testing.assert_allclose( + new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6 + ) + + +def test_stats_explicit_provided_flag_detection(): + """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + # Test 1: Explicitly provided stats (non-empty dict) + stats = {"observation.image": {"mean": [0.5], 
"std": [0.2]}} + normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + assert normalizer1._stats_explicitly_provided is True + + # Test 2: Empty stats dict + normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + assert normalizer2._stats_explicitly_provided is False + + # Test 3: None stats + normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None) + assert normalizer3._stats_explicitly_provided is False + + # Test 4: Stats not provided (defaults to None) + normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map) + assert normalizer4._stats_explicitly_provided is False + + +def test_pipeline_from_pretrained_with_stats_overrides(): + """ + Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides. + + This is an integration test that verifies the fix works in the real scenario + where users provide stat overrides when loading a pipeline. + """ + import tempfile + + # Create test data + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + # Create and save a pipeline with the original stats + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + identity = IdentityProcessorStep() + original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], 
name="test_pipeline") + + with tempfile.TemporaryDirectory() as temp_dir: + # Save the pipeline + original_pipeline.save_pretrained(temp_dir) + + # Load the pipeline with stat overrides + overrides = {"normalizer_processor": {"stats": override_stats}} + + loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides) + + # The critical test: the loaded pipeline should use override stats, not original stats + loaded_normalizer = loaded_pipeline.steps[0] + assert isinstance(loaded_normalizer, NormalizerProcessorStep) + + # Check that loaded stats match override stats + assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + + # Verify stats don't match original stats + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Test that the override stats are actually used in processing + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process with override pipeline + override_result = loaded_pipeline(transition) + + # Create a reference pipeline with override stats for comparison + reference_normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=override_stats + ) + reference_pipeline = DataProcessorPipeline( + steps=[reference_normalizer, identity], + to_transition=identity_transition, + to_output=identity_transition, + ) + _ = reference_pipeline(transition) + + # The critical part was verified above: 
loaded_normalizer.stats == override_stats + # This confirms that override stats are preserved during load_state_dict. + # Let's just verify the pipeline processes data successfully. + assert "action" in override_result + assert isinstance(override_result["action"], torch.Tensor) + + def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" from lerobot.processor import DeviceProcessorStep