mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)
* feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity.
This commit is contained in:
@@ -88,6 +88,10 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
"unnormalize.", # Must come after unnormalize_* patterns
|
||||
"input_normalizer.",
|
||||
"output_normalizer.",
|
||||
"normalalize_inputs.",
|
||||
"unnormalize_outputs.",
|
||||
"normalize_targets.",
|
||||
"unnormalize_targets.",
|
||||
]
|
||||
|
||||
# Process each key in state_dict
|
||||
@@ -168,6 +172,8 @@ def detect_features_and_norm_modes(
|
||||
mode = NormalizationMode.MEAN_STD
|
||||
elif mode_str == "MIN_MAX":
|
||||
mode = NormalizationMode.MIN_MAX
|
||||
elif mode_str == "IDENTITY":
|
||||
mode = NormalizationMode.IDENTITY
|
||||
else:
|
||||
print(
|
||||
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
|
||||
@@ -276,6 +282,26 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def clean_state_dict(
|
||||
state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod"
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Remove a substring (e.g. '._orig_mod') from all keys in a state dict.
|
||||
|
||||
Args:
|
||||
state_dict (dict): The original state dict.
|
||||
remove_str (str): The substring to remove from the keys.
|
||||
|
||||
Returns:
|
||||
dict: A new state dict with cleaned keys.
|
||||
"""
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
new_k = k.replace(remove_str, "")
|
||||
new_state_dict[new_k] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
@@ -405,6 +431,7 @@ def main():
|
||||
# Remove normalization layers from state_dict
|
||||
print("Removing normalization layers from model...")
|
||||
new_state_dict = remove_normalization_layers(state_dict)
|
||||
new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod")
|
||||
|
||||
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
||||
if removed_keys:
|
||||
|
||||
@@ -43,6 +43,30 @@ class _NormalizationMixin:
|
||||
be inherited by concrete `ProcessorStep` implementations and should not be used
|
||||
directly.
|
||||
|
||||
**Stats Override Preservation:**
|
||||
When stats are explicitly provided during construction (e.g., via overrides in
|
||||
`DataProcessorPipeline.from_pretrained()`), they are preserved even when
|
||||
`load_state_dict()` is called. This allows users to override normalization
|
||||
statistics from saved models while keeping the rest of the model state intact.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Common use case: Override with dataset stats
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("my_dataset")
|
||||
pipeline = DataProcessorPipeline.from_pretrained(
|
||||
"model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}}
|
||||
)
|
||||
# dataset.meta.stats will be used, not the stats from the saved model
|
||||
|
||||
# Custom stats override
|
||||
custom_stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||
pipeline = DataProcessorPipeline.from_pretrained(
|
||||
"model_path", overrides={"normalizer_processor": {"stats": custom_stats}}
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
|
||||
the data structure to be processed.
|
||||
@@ -57,6 +81,8 @@ class _NormalizationMixin:
|
||||
normalization to specific observation features.
|
||||
_tensor_stats: An internal dictionary holding the normalization statistics as
|
||||
PyTorch tensors.
|
||||
_stats_explicitly_provided: Internal flag tracking whether stats were explicitly
|
||||
provided during construction (used for override preservation).
|
||||
"""
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
@@ -68,6 +94,7 @@ class _NormalizationMixin:
|
||||
normalize_observation_keys: set[str] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
_stats_explicitly_provided: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
@@ -78,6 +105,8 @@ class _NormalizationMixin:
|
||||
lists) and converts the provided `stats` dictionary into a dictionary of
|
||||
tensors (`_tensor_stats`) on the specified device.
|
||||
"""
|
||||
# Track if stats were explicitly provided (not None and not empty)
|
||||
self._stats_explicitly_provided = self.stats is not None and bool(self.stats)
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
@@ -145,10 +174,33 @@ class _NormalizationMixin:
|
||||
|
||||
The loaded tensors are moved to the processor's configured device.
|
||||
|
||||
**Stats Override Preservation:**
|
||||
If stats were explicitly provided during construction (e.g., via overrides in
|
||||
`DataProcessorPipeline.from_pretrained()`), they are preserved and the state
|
||||
dictionary is ignored. This allows users to override normalization statistics
|
||||
while still loading the rest of the model state.
|
||||
|
||||
This behavior is crucial for scenarios where users want to adapt a pretrained
|
||||
model to a new dataset with different statistics without retraining the entire
|
||||
model.
|
||||
|
||||
Args:
|
||||
state: A flat state dictionary with keys in the format
|
||||
`'feature_name.stat_name'`.
|
||||
|
||||
Note:
|
||||
When stats are preserved due to explicit provision, only the tensor
|
||||
representation is updated to ensure consistency with the current device
|
||||
and dtype settings.
|
||||
"""
|
||||
# If stats were explicitly provided during construction, preserve them
|
||||
if self._stats_explicitly_provided and self.stats is not None:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
@@ -159,7 +211,6 @@ class _NormalizationMixin:
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
self.stats = {}
|
||||
for key, tensor_dict in self._tensor_stats.items():
|
||||
self.stats[key] = {}
|
||||
@@ -446,5 +497,5 @@ def hotswap_stats(
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
|
||||
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
|
||||
return rp
|
||||
|
||||
@@ -26,7 +26,6 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
@@ -177,8 +176,15 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
if not (cfg.resume and cfg.policy.pretrained_path):
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
@@ -189,12 +195,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
postprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
@@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer():
|
||||
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_stats_override_preservation_in_load_state_dict():
|
||||
"""
|
||||
Test that explicitly provided stats are preserved during load_state_dict.
|
||||
|
||||
This tests the fix for the bug where stats provided via overrides were
|
||||
being overwritten when load_state_dict was called.
|
||||
"""
|
||||
# Create original stats
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# Create override stats (what user wants to use)
|
||||
override_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create a normalizer with original stats and save its state
|
||||
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
saved_state_dict = original_normalizer.state_dict()
|
||||
|
||||
# Create a new normalizer with override stats (simulating from_pretrained with overrides)
|
||||
override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats)
|
||||
|
||||
# Verify that the override stats are initially set correctly
|
||||
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
assert override_normalizer._stats_explicitly_provided is True
|
||||
|
||||
# This is the critical test: load_state_dict should NOT overwrite the override stats
|
||||
override_normalizer.load_state_dict(saved_state_dict)
|
||||
|
||||
# After loading state_dict, stats should still be the override stats, not the original stats
|
||||
# Check that loaded stats match override stats
|
||||
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
# Compare individual arrays to avoid numpy array comparison ambiguity
|
||||
for key in override_stats:
|
||||
for stat_name in override_stats[key]:
|
||||
assert not np.array_equal(
|
||||
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||
|
||||
# Verify that _tensor_stats are also correctly set to match the override stats
|
||||
expected_tensor_stats = to_tensor(override_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
||||
torch.testing.assert_close(
|
||||
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
)
|
||||
|
||||
|
||||
def test_stats_without_override_loads_normally():
|
||||
"""
|
||||
Test that when stats are not explicitly provided (normal case),
|
||||
load_state_dict works as before.
|
||||
"""
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create a normalizer with original stats and save its state
|
||||
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
saved_state_dict = original_normalizer.state_dict()
|
||||
|
||||
# Create a new normalizer without stats (simulating normal from_pretrained)
|
||||
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
|
||||
# Verify that stats are not explicitly provided
|
||||
assert new_normalizer._stats_explicitly_provided is False
|
||||
|
||||
# Load state dict - this should work normally and load the saved stats
|
||||
new_normalizer.load_state_dict(saved_state_dict)
|
||||
|
||||
# Stats should now match the original stats (normal behavior)
|
||||
# Check that all keys and values match
|
||||
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
||||
for key in original_stats:
|
||||
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
||||
for stat_name in original_stats[key]:
|
||||
np.testing.assert_allclose(
|
||||
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
|
||||
)
|
||||
|
||||
|
||||
def test_stats_explicit_provided_flag_detection():
|
||||
"""Test that the _stats_explicitly_provided flag is set correctly in different scenarios."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
# Test 1: Explicitly provided stats (non-empty dict)
|
||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
assert normalizer1._stats_explicitly_provided is True
|
||||
|
||||
# Test 2: Empty stats dict
|
||||
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
assert normalizer2._stats_explicitly_provided is False
|
||||
|
||||
# Test 3: None stats
|
||||
normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None)
|
||||
assert normalizer3._stats_explicitly_provided is False
|
||||
|
||||
# Test 4: Stats not provided (defaults to None)
|
||||
normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map)
|
||||
assert normalizer4._stats_explicitly_provided is False
|
||||
|
||||
|
||||
def test_pipeline_from_pretrained_with_stats_overrides():
|
||||
"""
|
||||
Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides.
|
||||
|
||||
This is an integration test that verifies the fix works in the real scenario
|
||||
where users provide stat overrides when loading a pipeline.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
# Create test data
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
original_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
override_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
# Create and save a pipeline with the original stats
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||
identity = IdentityProcessorStep()
|
||||
original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save the pipeline
|
||||
original_pipeline.save_pretrained(temp_dir)
|
||||
|
||||
# Load the pipeline with stat overrides
|
||||
overrides = {"normalizer_processor": {"stats": override_stats}}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides)
|
||||
|
||||
# The critical test: the loaded pipeline should use override stats, not original stats
|
||||
loaded_normalizer = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_normalizer, NormalizerProcessorStep)
|
||||
|
||||
# Check that loaded stats match override stats
|
||||
assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys())
|
||||
for key in override_stats:
|
||||
assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||
for stat_name in override_stats[key]:
|
||||
np.testing.assert_array_equal(
|
||||
loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||
)
|
||||
|
||||
# Verify stats don't match original stats
|
||||
for key in override_stats:
|
||||
for stat_name in override_stats[key]:
|
||||
assert not np.array_equal(
|
||||
loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||
|
||||
# Test that the override stats are actually used in processing
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process with override pipeline
|
||||
override_result = loaded_pipeline(transition)
|
||||
|
||||
# Create a reference pipeline with override stats for comparison
|
||||
reference_normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=override_stats
|
||||
)
|
||||
reference_pipeline = DataProcessorPipeline(
|
||||
steps=[reference_normalizer, identity],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
_ = reference_pipeline(transition)
|
||||
|
||||
# The critical part was verified above: loaded_normalizer.stats == override_stats
|
||||
# This confirms that override stats are preserved during load_state_dict.
|
||||
# Let's just verify the pipeline processes data successfully.
|
||||
assert "action" in override_result
|
||||
assert isinstance(override_result["action"], torch.Tensor)
|
||||
|
||||
|
||||
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
||||
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
|
||||
from lerobot.processor import DeviceProcessorStep
|
||||
|
||||
Reference in New Issue
Block a user