mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)
* feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity.
This commit is contained in:
@@ -88,6 +88,10 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
|
|||||||
"unnormalize.", # Must come after unnormalize_* patterns
|
"unnormalize.", # Must come after unnormalize_* patterns
|
||||||
"input_normalizer.",
|
"input_normalizer.",
|
||||||
"output_normalizer.",
|
"output_normalizer.",
|
||||||
|
"normalalize_inputs.",
|
||||||
|
"unnormalize_outputs.",
|
||||||
|
"normalize_targets.",
|
||||||
|
"unnormalize_targets.",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Process each key in state_dict
|
# Process each key in state_dict
|
||||||
@@ -168,6 +172,8 @@ def detect_features_and_norm_modes(
|
|||||||
mode = NormalizationMode.MEAN_STD
|
mode = NormalizationMode.MEAN_STD
|
||||||
elif mode_str == "MIN_MAX":
|
elif mode_str == "MIN_MAX":
|
||||||
mode = NormalizationMode.MIN_MAX
|
mode = NormalizationMode.MIN_MAX
|
||||||
|
elif mode_str == "IDENTITY":
|
||||||
|
mode = NormalizationMode.IDENTITY
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
|
f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
|
||||||
@@ -276,6 +282,26 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
|||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def clean_state_dict(
|
||||||
|
state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod"
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Remove a substring (e.g. '._orig_mod') from all keys in a state dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict): The original state dict.
|
||||||
|
remove_str (str): The substring to remove from the keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A new state dict with cleaned keys.
|
||||||
|
"""
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
new_k = k.replace(remove_str, "")
|
||||||
|
new_state_dict[new_k] = v
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||||
"""
|
"""
|
||||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||||
@@ -405,6 +431,7 @@ def main():
|
|||||||
# Remove normalization layers from state_dict
|
# Remove normalization layers from state_dict
|
||||||
print("Removing normalization layers from model...")
|
print("Removing normalization layers from model...")
|
||||||
new_state_dict = remove_normalization_layers(state_dict)
|
new_state_dict = remove_normalization_layers(state_dict)
|
||||||
|
new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod")
|
||||||
|
|
||||||
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
||||||
if removed_keys:
|
if removed_keys:
|
||||||
|
|||||||
@@ -43,6 +43,30 @@ class _NormalizationMixin:
|
|||||||
be inherited by concrete `ProcessorStep` implementations and should not be used
|
be inherited by concrete `ProcessorStep` implementations and should not be used
|
||||||
directly.
|
directly.
|
||||||
|
|
||||||
|
**Stats Override Preservation:**
|
||||||
|
When stats are explicitly provided during construction (e.g., via overrides in
|
||||||
|
`DataProcessorPipeline.from_pretrained()`), they are preserved even when
|
||||||
|
`load_state_dict()` is called. This allows users to override normalization
|
||||||
|
statistics from saved models while keeping the rest of the model state intact.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
# Common use case: Override with dataset stats
|
||||||
|
from lerobot.datasets import LeRobotDataset
|
||||||
|
|
||||||
|
dataset = LeRobotDataset("my_dataset")
|
||||||
|
pipeline = DataProcessorPipeline.from_pretrained(
|
||||||
|
"model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}}
|
||||||
|
)
|
||||||
|
# dataset.meta.stats will be used, not the stats from the saved model
|
||||||
|
|
||||||
|
# Custom stats override
|
||||||
|
custom_stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||||
|
pipeline = DataProcessorPipeline.from_pretrained(
|
||||||
|
"model_path", overrides={"normalizer_processor": {"stats": custom_stats}}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
|
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
|
||||||
the data structure to be processed.
|
the data structure to be processed.
|
||||||
@@ -57,6 +81,8 @@ class _NormalizationMixin:
|
|||||||
normalization to specific observation features.
|
normalization to specific observation features.
|
||||||
_tensor_stats: An internal dictionary holding the normalization statistics as
|
_tensor_stats: An internal dictionary holding the normalization statistics as
|
||||||
PyTorch tensors.
|
PyTorch tensors.
|
||||||
|
_stats_explicitly_provided: Internal flag tracking whether stats were explicitly
|
||||||
|
provided during construction (used for override preservation).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
features: dict[str, PolicyFeature]
|
features: dict[str, PolicyFeature]
|
||||||
@@ -68,6 +94,7 @@ class _NormalizationMixin:
|
|||||||
normalize_observation_keys: set[str] | None = None
|
normalize_observation_keys: set[str] | None = None
|
||||||
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
|
_stats_explicitly_provided: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""
|
"""
|
||||||
@@ -78,6 +105,8 @@ class _NormalizationMixin:
|
|||||||
lists) and converts the provided `stats` dictionary into a dictionary of
|
lists) and converts the provided `stats` dictionary into a dictionary of
|
||||||
tensors (`_tensor_stats`) on the specified device.
|
tensors (`_tensor_stats`) on the specified device.
|
||||||
"""
|
"""
|
||||||
|
# Track if stats were explicitly provided (not None and not empty)
|
||||||
|
self._stats_explicitly_provided = self.stats is not None and bool(self.stats)
|
||||||
# Robust JSON deserialization handling (guard empty maps).
|
# Robust JSON deserialization handling (guard empty maps).
|
||||||
if self.features:
|
if self.features:
|
||||||
first_val = next(iter(self.features.values()))
|
first_val = next(iter(self.features.values()))
|
||||||
@@ -145,10 +174,33 @@ class _NormalizationMixin:
|
|||||||
|
|
||||||
The loaded tensors are moved to the processor's configured device.
|
The loaded tensors are moved to the processor's configured device.
|
||||||
|
|
||||||
|
**Stats Override Preservation:**
|
||||||
|
If stats were explicitly provided during construction (e.g., via overrides in
|
||||||
|
`DataProcessorPipeline.from_pretrained()`), they are preserved and the state
|
||||||
|
dictionary is ignored. This allows users to override normalization statistics
|
||||||
|
while still loading the rest of the model state.
|
||||||
|
|
||||||
|
This behavior is crucial for scenarios where users want to adapt a pretrained
|
||||||
|
model to a new dataset with different statistics without retraining the entire
|
||||||
|
model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: A flat state dictionary with keys in the format
|
state: A flat state dictionary with keys in the format
|
||||||
`'feature_name.stat_name'`.
|
`'feature_name.stat_name'`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When stats are preserved due to explicit provision, only the tensor
|
||||||
|
representation is updated to ensure consistency with the current device
|
||||||
|
and dtype settings.
|
||||||
"""
|
"""
|
||||||
|
# If stats were explicitly provided during construction, preserve them
|
||||||
|
if self._stats_explicitly_provided and self.stats is not None:
|
||||||
|
# Don't load from state_dict, keep the explicitly provided stats
|
||||||
|
# But ensure _tensor_stats is properly initialized
|
||||||
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normal behavior: load stats from state_dict
|
||||||
self._tensor_stats.clear()
|
self._tensor_stats.clear()
|
||||||
for flat_key, tensor in state.items():
|
for flat_key, tensor in state.items():
|
||||||
key, stat_name = flat_key.rsplit(".", 1)
|
key, stat_name = flat_key.rsplit(".", 1)
|
||||||
@@ -159,7 +211,6 @@ class _NormalizationMixin:
|
|||||||
|
|
||||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||||
# and other functions that rely on self.stats
|
# and other functions that rely on self.stats
|
||||||
|
|
||||||
self.stats = {}
|
self.stats = {}
|
||||||
for key, tensor_dict in self._tensor_stats.items():
|
for key, tensor_dict in self._tensor_stats.items():
|
||||||
self.stats[key] = {}
|
self.stats[key] = {}
|
||||||
@@ -446,5 +497,5 @@ def hotswap_stats(
|
|||||||
if isinstance(step, _NormalizationMixin):
|
if isinstance(step, _NormalizationMixin):
|
||||||
step.stats = stats
|
step.stats = stats
|
||||||
# Re-initialize tensor_stats on the correct device.
|
# Re-initialize tensor_stats on the correct device.
|
||||||
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
|
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
|
||||||
return rp
|
return rp
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from torch.optim import Optimizer
|
|||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
@@ -177,8 +176,15 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||||
|
processor_kwargs = {}
|
||||||
|
if not (cfg.resume and cfg.policy.pretrained_path):
|
||||||
|
# Only provide dataset_stats when not resuming from saved processor state
|
||||||
|
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
@@ -189,12 +195,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||||
preprocessor.from_pretrained(
|
|
||||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
|
||||||
)
|
|
||||||
postprocessor.from_pretrained(
|
|
||||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|||||||
@@ -1530,7 +1530,239 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer():
|
|||||||
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
|
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
def test_stats_override_preservation_in_load_state_dict():
|
||||||
|
"""
|
||||||
|
Test that explicitly provided stats are preserved during load_state_dict.
|
||||||
|
|
||||||
|
This tests the fix for the bug where stats provided via overrides were
|
||||||
|
being overwritten when load_state_dict was called.
|
||||||
|
"""
|
||||||
|
# Create original stats
|
||||||
|
original_stats = {
|
||||||
|
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||||
|
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create override stats (what user wants to use)
|
||||||
|
override_stats = {
|
||||||
|
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||||
|
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||||
|
}
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a normalizer with original stats and save its state
|
||||||
|
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||||
|
saved_state_dict = original_normalizer.state_dict()
|
||||||
|
|
||||||
|
# Create a new normalizer with override stats (simulating from_pretrained with overrides)
|
||||||
|
override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats)
|
||||||
|
|
||||||
|
# Verify that the override stats are initially set correctly
|
||||||
|
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||||
|
for key in override_stats:
|
||||||
|
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||||
|
for stat_name in override_stats[key]:
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||||
|
)
|
||||||
|
assert override_normalizer._stats_explicitly_provided is True
|
||||||
|
|
||||||
|
# This is the critical test: load_state_dict should NOT overwrite the override stats
|
||||||
|
override_normalizer.load_state_dict(saved_state_dict)
|
||||||
|
|
||||||
|
# After loading state_dict, stats should still be the override stats, not the original stats
|
||||||
|
# Check that loaded stats match override stats
|
||||||
|
assert set(override_normalizer.stats.keys()) == set(override_stats.keys())
|
||||||
|
for key in override_stats:
|
||||||
|
assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||||
|
for stat_name in override_stats[key]:
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
override_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||||
|
)
|
||||||
|
# Compare individual arrays to avoid numpy array comparison ambiguity
|
||||||
|
for key in override_stats:
|
||||||
|
for stat_name in override_stats[key]:
|
||||||
|
assert not np.array_equal(
|
||||||
|
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||||
|
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||||
|
|
||||||
|
# Verify that _tensor_stats are also correctly set to match the override stats
|
||||||
|
expected_tensor_stats = to_tensor(override_stats)
|
||||||
|
for key in expected_tensor_stats:
|
||||||
|
for stat_name in expected_tensor_stats[key]:
|
||||||
|
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
||||||
|
torch.testing.assert_close(
|
||||||
|
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stats_without_override_loads_normally():
|
||||||
|
"""
|
||||||
|
Test that when stats are not explicitly provided (normal case),
|
||||||
|
load_state_dict works as before.
|
||||||
|
"""
|
||||||
|
original_stats = {
|
||||||
|
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||||
|
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||||
|
}
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a normalizer with original stats and save its state
|
||||||
|
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||||
|
saved_state_dict = original_normalizer.state_dict()
|
||||||
|
|
||||||
|
# Create a new normalizer without stats (simulating normal from_pretrained)
|
||||||
|
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||||
|
|
||||||
|
# Verify that stats are not explicitly provided
|
||||||
|
assert new_normalizer._stats_explicitly_provided is False
|
||||||
|
|
||||||
|
# Load state dict - this should work normally and load the saved stats
|
||||||
|
new_normalizer.load_state_dict(saved_state_dict)
|
||||||
|
|
||||||
|
# Stats should now match the original stats (normal behavior)
|
||||||
|
# Check that all keys and values match
|
||||||
|
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
||||||
|
for key in original_stats:
|
||||||
|
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
||||||
|
for stat_name in original_stats[key]:
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stats_explicit_provided_flag_detection():
|
||||||
|
"""Test that the _stats_explicitly_provided flag is set correctly in different scenarios."""
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||||
|
}
|
||||||
|
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||||
|
|
||||||
|
# Test 1: Explicitly provided stats (non-empty dict)
|
||||||
|
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
|
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
assert normalizer1._stats_explicitly_provided is True
|
||||||
|
|
||||||
|
# Test 2: Empty stats dict
|
||||||
|
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||||
|
assert normalizer2._stats_explicitly_provided is False
|
||||||
|
|
||||||
|
# Test 3: None stats
|
||||||
|
normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None)
|
||||||
|
assert normalizer3._stats_explicitly_provided is False
|
||||||
|
|
||||||
|
# Test 4: Stats not provided (defaults to None)
|
||||||
|
normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map)
|
||||||
|
assert normalizer4._stats_explicitly_provided is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_from_pretrained_with_stats_overrides():
|
||||||
|
"""
|
||||||
|
Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides.
|
||||||
|
|
||||||
|
This is an integration test that verifies the fix works in the real scenario
|
||||||
|
where users provide stat overrides when loading a pipeline.
|
||||||
|
"""
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
features = {
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)),
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
original_stats = {
|
||||||
|
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||||
|
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||||
|
}
|
||||||
|
|
||||||
|
override_stats = {
|
||||||
|
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||||
|
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create and save a pipeline with the original stats
|
||||||
|
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats)
|
||||||
|
identity = IdentityProcessorStep()
|
||||||
|
original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Save the pipeline
|
||||||
|
original_pipeline.save_pretrained(temp_dir)
|
||||||
|
|
||||||
|
# Load the pipeline with stat overrides
|
||||||
|
overrides = {"normalizer_processor": {"stats": override_stats}}
|
||||||
|
|
||||||
|
loaded_pipeline = DataProcessorPipeline.from_pretrained(temp_dir, overrides=overrides)
|
||||||
|
|
||||||
|
# The critical test: the loaded pipeline should use override stats, not original stats
|
||||||
|
loaded_normalizer = loaded_pipeline.steps[0]
|
||||||
|
assert isinstance(loaded_normalizer, NormalizerProcessorStep)
|
||||||
|
|
||||||
|
# Check that loaded stats match override stats
|
||||||
|
assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys())
|
||||||
|
for key in override_stats:
|
||||||
|
assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys())
|
||||||
|
for stat_name in override_stats[key]:
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify stats don't match original stats
|
||||||
|
for key in override_stats:
|
||||||
|
for stat_name in override_stats[key]:
|
||||||
|
assert not np.array_equal(
|
||||||
|
loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||||
|
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||||
|
|
||||||
|
# Test that the override stats are actually used in processing
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = create_transition(observation=observation, action=action)
|
||||||
|
|
||||||
|
# Process with override pipeline
|
||||||
|
override_result = loaded_pipeline(transition)
|
||||||
|
|
||||||
|
# Create a reference pipeline with override stats for comparison
|
||||||
|
reference_normalizer = NormalizerProcessorStep(
|
||||||
|
features=features, norm_map=norm_map, stats=override_stats
|
||||||
|
)
|
||||||
|
reference_pipeline = DataProcessorPipeline(
|
||||||
|
steps=[reference_normalizer, identity],
|
||||||
|
to_transition=identity_transition,
|
||||||
|
to_output=identity_transition,
|
||||||
|
)
|
||||||
|
_ = reference_pipeline(transition)
|
||||||
|
|
||||||
|
# The critical part was verified above: loaded_normalizer.stats == override_stats
|
||||||
|
# This confirms that override stats are preserved during load_state_dict.
|
||||||
|
# Let's just verify the pipeline processes data successfully.
|
||||||
|
assert "action" in override_result
|
||||||
|
assert isinstance(override_result["action"], torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
||||||
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
|
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
|
||||||
from lerobot.processor import DeviceProcessorStep
|
from lerobot.processor import DeviceProcessorStep
|
||||||
|
|||||||
Reference in New Issue
Block a user