refactor(normalization): Remove unused state dict transformation methods and streamline imports

- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
This commit is contained in:
Adil Zouitine
2025-08-01 08:55:38 +02:00
committed by Steven Palma
parent f02ce69df0
commit 8ff95be04c
6 changed files with 398 additions and 98 deletions
+1 -95
View File
@@ -65,8 +65,7 @@ from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertModel, PaliGemmaWithExpertModel,
) )
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import log_model_loading_keys from lerobot.utils.utils import get_safe_dtype
from lerobot.utils.utils import get_safe_dtype, init_logging
def create_sinusoidal_pos_embedding( def create_sinusoidal_pos_embedding(
@@ -245,99 +244,6 @@ class PI0Policy(PreTrainedPolicy):
"""This should be called whenever the environment is reset.""" """This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
    """Rename checkpoint keys so they line up with the current module layout.

    Every key goes through these substring rewrites, applied in order:
    - ``.paligemma_with_expert.paligemma.language_model.lm_head`` ->
      ``.paligemma_with_expert.paligemma.lm_head``
    - ``.paligemma_with_expert.paligemma.language_model.model`` ->
      ``.paligemma_with_expert.paligemma.model.language_model``
    - ``.paligemma_with_expert.paligemma.vision_tower`` ->
      ``.paligemma_with_expert.paligemma.model.vision_tower``
    - ``.paligemma_with_expert.paligemma.multi_modal_projector`` ->
      ``.paligemma_with_expert.paligemma.model.multi_modal_projector``

    After renaming, if only one of the ``lm_head.weight`` /
    ``embed_tokens.weight`` entries is present, the missing entry is added and
    bound to the very same value object (tied weights).

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        A newly built dict with the renamed keys. ``state_dict`` itself is not
        written to; values are carried over by reference.
    """
    import re

    rename_rules = [
        (re.compile(source_pattern), target)
        for source_pattern, target in (
            (
                r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head",
                ".paligemma_with_expert.paligemma.lm_head",
            ),
            (
                r"\.paligemma_with_expert\.paligemma\.language_model\.model",
                ".paligemma_with_expert.paligemma.model.language_model",
            ),
            (
                r"\.paligemma_with_expert\.paligemma\.vision_tower",
                ".paligemma_with_expert.paligemma.model.vision_tower",
            ),
            (
                r"\.paligemma_with_expert\.paligemma\.multi_modal_projector",
                ".paligemma_with_expert.paligemma.model.multi_modal_projector",
            ),
        )
    ]

    def _rename(name):
        # Rules are chained: each one sees the output of the previous one.
        for rule, target in rename_rules:
            name = rule.sub(target, name)
        return name

    renamed = {_rename(name): value for name, value in state_dict.items()}

    # Tied weights: collect the keys for each side of the lm_head / embed_tokens pair.
    lm_head_names = [
        name for name in renamed if name.endswith(".paligemma_with_expert.paligemma.lm_head.weight")
    ]
    embed_names = [
        name
        for name in renamed
        if name.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight")
    ]

    if lm_head_names and not embed_names:
        # Only lm_head exists: mirror the last match under the embed_tokens name.
        source = lm_head_names[-1]
        alias = source.replace(".lm_head.weight", ".model.language_model.embed_tokens.weight")
        renamed[alias] = renamed[source]
    elif embed_names and not lm_head_names:
        # Only embed_tokens exists: mirror the last match under the lm_head name.
        source = embed_names[-1]
        alias = source.replace(".model.language_model.embed_tokens.weight", ".lm_head.weight")
        renamed[alias] = renamed[source]

    return renamed
@classmethod
def _load_as_safetensor(
    cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
    """Load a safetensors file into ``model`` after remapping its keys.

    Args:
        model: Policy instance that receives the weights.
        model_file: Path of the safetensors file to read.
        map_location: Device passed to ``load_file``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object, after ``load_state_dict`` has run.
    """
    from safetensors.torch import load_file

    init_logging()

    # Read the raw weights, then rename keys before handing them to the model.
    raw_weights = load_file(model_file, device=map_location)
    load_result = model.load_state_dict(cls._transform_state_dict_keys(raw_weights), strict=strict)

    # Report which keys were missing from / unexpected in the checkpoint.
    log_model_loading_keys(load_result.missing_keys, load_result.unexpected_keys)
    return model
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
return self.parameters() return self.parameters()
@@ -141,6 +141,9 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.image_features: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
@@ -134,6 +134,9 @@ class VQBeTPolicy(PreTrainedPolicy):
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key. # NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
+2 -1
View File
@@ -16,7 +16,7 @@
from .batch_processor import ToBatchProcessor from .batch_processor import ToBatchProcessor
from .device_processor import DeviceProcessor from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
from .observation_processor import VanillaObservationProcessor from .observation_processor import VanillaObservationProcessor
from .pipeline import ( from .pipeline import (
ActionProcessor, ActionProcessor,
@@ -43,6 +43,7 @@ __all__ = [
"InfoProcessor", "InfoProcessor",
"NormalizerProcessor", "NormalizerProcessor",
"UnnormalizerProcessor", "UnnormalizerProcessor",
"hotswap_stats",
"ObservationProcessor", "ObservationProcessor",
"ProcessorStep", "ProcessorStep",
"ProcessorStepRegistry", "ProcessorStepRegistry",
+13 -1
View File
@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -10,7 +12,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey, RobotProcessor
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -402,3 +404,13 @@ class UnnormalizerProcessor:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
    """Return a deep copy of ``robot_processor`` whose (un)normalizer steps use ``stats``.

    The input processor is left untouched. On the copy, every step that is a
    ``NormalizerProcessor`` or ``UnnormalizerProcessor`` gets ``stats`` assigned
    (by reference, not copied) and its ``_tensor_stats`` rebuilt from ``stats``
    via ``_convert_stats_to_tensors``. All other steps are only deep-copied.

    Args:
        robot_processor: Processor pipeline to copy.
        stats: Replacement statistics, keyed by feature name then statistic name.

    Returns:
        The deep-copied processor with the replaced statistics.
    """
    swapped = deepcopy(robot_processor)
    stats_step_types = (NormalizerProcessor, UnnormalizerProcessor)
    for step in swapped.steps:
        if not isinstance(step, stats_step_types):
            continue
        step.stats = stats
        # Keep the tensor cache in sync with the raw stats just assigned.
        step._tensor_stats = _convert_stats_to_tensors(stats)
    return swapped
+376 -1
View File
@@ -24,8 +24,9 @@ from lerobot.processor.normalize_processor import (
NormalizerProcessor, NormalizerProcessor,
UnnormalizerProcessor, UnnormalizerProcessor,
_convert_stats_to_tensors, _convert_stats_to_tensors,
hotswap_stats,
) )
from lerobot.processor.pipeline import RobotProcessor, TransitionKey from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
def create_transition( def create_transition(
@@ -953,3 +954,377 @@ def test_unsupported_normalization_mode_error():
with pytest.raises(ValueError, match="Unsupported normalization mode"): with pytest.raises(ValueError, match="Unsupported normalization mode"):
normalizer(transition) normalizer(transition)
def test_hotswap_stats_basic_functionality():
    """hotswap_stats swaps ``stats`` and ``_tensor_stats`` on normalizer and unnormalizer steps."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    # Stats the pipeline starts with.
    stats_before = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    # Stats to swap in.
    stats_after = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
        "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
    }

    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before),
            UnnormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before),
            IdentityProcessor(),
        ]
    )

    swapped = hotswap_stats(pipeline, stats_after)
    norm_step = swapped.steps[0]
    unnorm_step = swapped.steps[1]

    # Raw stats were replaced on both stat-carrying steps.
    assert norm_step.stats == stats_after
    assert unnorm_step.stats == stats_after

    # The tensor cache matches a fresh conversion of the new stats.
    expected = _convert_stats_to_tensors(stats_after)
    for feature_name, per_stat in expected.items():
        for stat_name, wanted in per_stat.items():
            torch.testing.assert_close(norm_step._tensor_stats[feature_name][stat_name], wanted)
            torch.testing.assert_close(unnorm_step._tensor_stats[feature_name][stat_name], wanted)
def test_hotswap_stats_deep_copy():
    """hotswap_stats works on a deep copy and leaves the processor it was given untouched."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    mode_by_type = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
    stats_before = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
    }
    stats_after = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
    }

    source = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before)]
    )

    # Remember the exact objects held by the source step before the swap.
    stats_obj_before = source.steps[0].stats
    tensor_stats_obj_before = source.steps[0]._tensor_stats

    swapped = hotswap_stats(source, stats_after)

    # The source step still holds the very same objects and values.
    assert source.steps[0].stats is stats_obj_before
    assert source.steps[0]._tensor_stats is tensor_stats_obj_before
    assert source.steps[0].stats == stats_before

    # The returned processor carries the new stats.
    assert swapped.steps[0].stats == stats_after
    assert swapped.steps[0].stats is not stats_obj_before

    # Neither the processor nor its step is shared with the source.
    assert swapped is not source
    assert swapped.steps[0] is not source.steps[0]
def test_hotswap_stats_only_affects_normalizer_steps():
    """Only NormalizerProcessor / UnnormalizerProcessor steps receive the new stats."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    mode_by_type = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
    stats_before = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    stats_after = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
    }

    # Pipeline layout: normalizer, identity, unnormalizer.
    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before),
            IdentityProcessor(),
            UnnormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before),
        ]
    )

    swapped = hotswap_stats(pipeline, stats_after)
    norm_step, identity_step, unnorm_step = swapped.steps[0], swapped.steps[1], swapped.steps[2]

    # Both stat-carrying steps were updated.
    assert norm_step.stats == stats_after
    assert unnorm_step.stats == stats_after

    # The identity step did not gain a ``stats`` attribute.
    assert not hasattr(identity_step, "stats")
def test_hotswap_stats_empty_stats():
    """Swapping in an empty dict clears both ``stats`` and ``_tensor_stats``."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    mode_by_type = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
    stats_before = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    no_stats = {}

    pipeline = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before)]
    )

    swapped_step = hotswap_stats(pipeline, no_stats).steps[0]

    # Both the raw stats and the tensor cache end up empty.
    assert swapped_step.stats == no_stats
    assert swapped_step._tensor_stats == {}
def test_hotswap_stats_no_normalizer_steps():
    """A pipeline without normalizer/unnormalizer steps is still copied and returned without error."""
    replacement_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }

    # Nothing in this pipeline carries stats.
    pipeline = RobotProcessor(steps=[IdentityProcessor(), IdentityProcessor()])

    swapped = hotswap_stats(pipeline, replacement_stats)

    # A new processor object comes back.
    assert swapped is not pipeline

    # Same number of steps; each one is a distinct object of the same type.
    assert len(swapped.steps) == len(pipeline.steps)
    for copied_step, source_step in zip(swapped.steps, pipeline.steps):
        assert copied_step is not source_step
        assert isinstance(copied_step, type(source_step))
def test_hotswap_stats_preserves_other_attributes():
    """Attributes other than the stats (features, norm_map, normalize_keys, eps) survive the swap."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    mode_by_type = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
    keys_to_normalize = {"observation.image"}
    epsilon = 1e-6
    stats_before = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    stats_after = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
    }

    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(
                features=feature_spec,
                norm_map=mode_by_type,
                stats=stats_before,
                normalize_keys=keys_to_normalize,
                eps=epsilon,
            )
        ]
    )

    swapped_step = hotswap_stats(pipeline, stats_after).steps[0]

    # Configuration carried over unchanged.
    assert swapped_step.features == feature_spec
    assert swapped_step.norm_map == mode_by_type
    assert swapped_step.normalize_keys == keys_to_normalize
    assert swapped_step.eps == epsilon

    # Stats are the new ones.
    assert swapped_step.stats == stats_after
def test_hotswap_stats_multiple_normalizer_types():
    """Every normalizer and unnormalizer in a pipeline with several of each is updated."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    stats_before = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
        "action": {"min": np.array([-1.0]), "max": np.array([1.0])},
    }
    stats_after = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
        "action": {"min": np.array([-2.0]), "max": np.array([2.0])},
    }

    def _make(step_cls):
        # All steps share the same configuration and initial stats.
        return step_cls(features=feature_spec, norm_map=mode_by_type, stats=stats_before)

    first_norm = _make(NormalizerProcessor)
    second_norm = _make(NormalizerProcessor)
    first_unnorm = _make(UnnormalizerProcessor)
    second_unnorm = _make(UnnormalizerProcessor)

    # Interleaved order: normalizer, unnormalizer, normalizer, unnormalizer.
    pipeline = RobotProcessor(steps=[first_norm, first_unnorm, second_norm, second_unnorm])

    swapped = hotswap_stats(pipeline, stats_after)

    for step in swapped.steps:
        assert step.stats == stats_after

        # Tensor cache of this step matches a fresh conversion of the new stats.
        expected = _convert_stats_to_tensors(stats_after)
        for feature_name, per_stat in expected.items():
            for stat_name, wanted in per_stat.items():
                torch.testing.assert_close(step._tensor_stats[feature_name][stat_name], wanted)
def test_hotswap_stats_with_different_data_types():
    """Stats given as lists, tuples, scalars, numpy arrays and tensors are all converted to tensors."""
    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }
    stats_before = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    # One value per supported container / scalar kind.
    stats_after = {
        "observation.image": {
            "mean": [0.3, 0.4, 0.5],  # list
            "std": (0.1, 0.2, 0.3),  # tuple
            "min": 0,  # int
            "max": 1.0,  # float
        },
        "action": {
            "mean": np.array([0.1, 0.2]),  # numpy array
            "std": torch.tensor([0.5, 0.6]),  # torch tensor
        },
    }

    pipeline = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before)]
    )

    swapped_step = hotswap_stats(pipeline, stats_after).steps[0]

    # Raw stats were replaced.
    assert swapped_step.stats == stats_after

    # Every entry of the tensor cache is a torch.Tensor.
    converted = swapped_step._tensor_stats
    for feature_name, stat_name in (
        ("observation.image", "mean"),
        ("observation.image", "std"),
        ("observation.image", "min"),
        ("observation.image", "max"),
        ("action", "mean"),
        ("action", "std"),
    ):
        assert isinstance(converted[feature_name][stat_name], torch.Tensor)

    # Converted image stats hold the expected values.
    image_stats = converted["observation.image"]
    torch.testing.assert_close(image_stats["mean"], torch.tensor([0.3, 0.4, 0.5]))
    torch.testing.assert_close(image_stats["std"], torch.tensor([0.1, 0.2, 0.3]))
    torch.testing.assert_close(image_stats["min"], torch.tensor(0.0))
    torch.testing.assert_close(image_stats["max"], torch.tensor(1.0))
def test_hotswap_stats_functional_test():
    """A hot-swapped processor produces different outputs than the original on the same transition."""
    # Input transition.
    observation = {
        "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]),
    }
    action = torch.tensor([0.5, -0.5])
    transition = create_transition(observation=observation, action=action)

    feature_spec = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }
    stats_before = {
        "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    stats_after = {
        "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])},
        "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])},
    }

    source = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_spec, norm_map=mode_by_type, stats=stats_before)]
    )

    # Run the same transition through the original and the hot-swapped pipeline.
    result_before = source(transition)
    swapped = hotswap_stats(source, stats_after)
    result_after = swapped(transition)

    # Different stats must lead to different normalized values.
    assert not torch.allclose(
        result_before["observation"]["observation.image"],
        result_after["observation"]["observation.image"],
        rtol=1e-3,
        atol=1e-3,
    )
    assert not torch.allclose(result_before["action"], result_after["action"], rtol=1e-3, atol=1e-3)

    # Internal state of the swapped step reflects the new stats.
    swapped_step = swapped.steps[0]
    assert swapped_step.stats == stats_after
    for feature_name, stat_name, wanted in (
        ("observation.image", "mean", [0.3, 0.2]),
        ("observation.image", "std", [0.1, 0.2]),
        ("action", "mean", [0.1, -0.1]),
        ("action", "std", [0.5, 0.5]),
    ):
        assert torch.allclose(swapped_step._tensor_stats[feature_name][stat_name], torch.tensor(wanted))

    # Normalization really changed the data: outputs differ from the raw inputs.
    assert not torch.allclose(
        result_after["observation"]["observation.image"], observation["observation.image"]
    )
    assert not torch.allclose(result_after["action"], action)