mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
refactor(normalization): Remove unused state dict transformation methods and streamline imports
- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
This commit is contained in:
committed by
Steven Palma
parent
f02ce69df0
commit
8ff95be04c
@@ -65,8 +65,7 @@ from lerobot.policies.pi0.paligemma_with_expert import (
|
|||||||
PaliGemmaWithExpertModel,
|
PaliGemmaWithExpertModel,
|
||||||
)
|
)
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import log_model_loading_keys
|
from lerobot.utils.utils import get_safe_dtype
|
||||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
@@ -245,99 +244,6 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
    """
    Return a copy of ``state_dict`` with checkpoint keys renamed to the expected layout.

    Renames applied to every key (substring substitution, in this order):
    - ...paligemma.language_model.lm_head      -> ...paligemma.lm_head
    - ...paligemma.language_model.model        -> ...paligemma.model.language_model
    - ...paligemma.vision_tower                -> ...paligemma.model.vision_tower
    - ...paligemma.multi_modal_projector       -> ...paligemma.model.multi_modal_projector

    After renaming, if exactly one of ``lm_head.weight`` /
    ``embed_tokens.weight`` is present, the missing one is added and bound
    to the same value object (the two are treated as tied weights).

    Args:
        state_dict: Mapping from parameter name to value; it is not mutated.

    Returns:
        A new dict with renamed keys and the original value objects.
    """
    import re

    rename_rules = [
        (
            re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
            ".paligemma_with_expert.paligemma.lm_head",
        ),
        (
            re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
            ".paligemma_with_expert.paligemma.model.language_model",
        ),
        (
            re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
            ".paligemma_with_expert.paligemma.model.vision_tower",
        ),
        (
            re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
            ".paligemma_with_expert.paligemma.model.multi_modal_projector",
        ),
    ]

    def _rename(name):
        # Rules are applied cumulatively: each one sees the output of the previous.
        for regex, target in rename_rules:
            name = regex.sub(target, name)
        return name

    renamed = {_rename(name): tensor for name, tensor in state_dict.items()}

    # Locate the tied pair (lm_head.weight / embed_tokens.weight) among the renamed keys.
    head_name = None
    embed_name = None
    for name in renamed:
        if name.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
            head_name = name
        elif name.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
            embed_name = name
        if head_name and embed_name:
            break

    # When only one side of the pair exists, alias the other side to the same value.
    if head_name and not embed_name:
        embed_name = head_name.replace(".lm_head.weight", ".model.language_model.embed_tokens.weight")
        renamed[embed_name] = renamed[head_name]
    elif embed_name and not head_name:
        head_name = embed_name.replace(".model.language_model.embed_tokens.weight", ".lm_head.weight")
        renamed[head_name] = renamed[embed_name]

    return renamed
|
|
||||||
|
|
||||||
@classmethod
def _load_as_safetensor(
    cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
    """Load a safetensors file into ``model`` after renaming its keys.

    The checkpoint is read with ``safetensors.torch.load_file`` onto
    ``map_location``, its keys are passed through
    ``_transform_state_dict_keys``, and the result is handed to
    ``model.load_state_dict``. Missing and unexpected keys reported by that
    call are forwarded to ``log_model_loading_keys``.

    Args:
        model: Policy instance that receives the weights.
        model_file: Path of the safetensors file.
        map_location: Device passed to ``load_file``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object that was passed in.
    """
    from safetensors.torch import load_file

    init_logging()

    # Read the checkpoint and rename its keys in one step before loading.
    remapped = cls._transform_state_dict_keys(load_file(model_file, device=map_location))
    load_report = model.load_state_dict(remapped, strict=strict)

    log_model_loading_keys(load_report.missing_keys, load_report.unexpected_keys)
    return model
|
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
|||||||
@@ -134,6 +134,9 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
|
if ACTION in batch:
|
||||||
|
batch.pop(ACTION)
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from .batch_processor import ToBatchProcessor
|
from .batch_processor import ToBatchProcessor
|
||||||
from .device_processor import DeviceProcessor
|
from .device_processor import DeviceProcessor
|
||||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||||
from .observation_processor import VanillaObservationProcessor
|
from .observation_processor import VanillaObservationProcessor
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
ActionProcessor,
|
ActionProcessor,
|
||||||
@@ -43,6 +43,7 @@ __all__ = [
|
|||||||
"InfoProcessor",
|
"InfoProcessor",
|
||||||
"NormalizerProcessor",
|
"NormalizerProcessor",
|
||||||
"UnnormalizerProcessor",
|
"UnnormalizerProcessor",
|
||||||
|
"hotswap_stats",
|
||||||
"ObservationProcessor",
|
"ObservationProcessor",
|
||||||
"ProcessorStep",
|
"ProcessorStep",
|
||||||
"ProcessorStepRegistry",
|
"ProcessorStepRegistry",
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -10,7 +12,7 @@ from torch import Tensor
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey, RobotProcessor
|
||||||
|
|
||||||
|
|
||||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||||
@@ -402,3 +404,13 @@ class UnnormalizerProcessor:
|
|||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
    """Return a deep copy of ``robot_processor`` whose normalization steps use ``stats``.

    The processor passed in is deep-copied before anything is changed, so it is
    left untouched. In the copy, every step that is a ``NormalizerProcessor`` or
    an ``UnnormalizerProcessor`` has its ``stats`` attribute rebound to ``stats``
    and its ``_tensor_stats`` rebuilt with ``_convert_stats_to_tensors``. All
    other steps are copied and otherwise left alone.

    Args:
        robot_processor: Pipeline to copy.
        stats: New statistics, keyed by feature name and then by statistic name.
            The dict is stored by reference on each updated step (it is not
            copied), so later mutation by the caller is visible to those steps.

    Returns:
        The modified deep copy of ``robot_processor``.
    """
    swapped = deepcopy(robot_processor)
    for step in swapped.steps:
        if isinstance(step, (NormalizerProcessor, UnnormalizerProcessor)):
            step.stats = stats
            # Converted once per step so each step receives the result of its own call.
            step._tensor_stats = _convert_stats_to_tensors(stats)
    return swapped
|
||||||
|
|||||||
@@ -24,8 +24,9 @@ from lerobot.processor.normalize_processor import (
|
|||||||
NormalizerProcessor,
|
NormalizerProcessor,
|
||||||
UnnormalizerProcessor,
|
UnnormalizerProcessor,
|
||||||
_convert_stats_to_tensors,
|
_convert_stats_to_tensors,
|
||||||
|
hotswap_stats,
|
||||||
)
|
)
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
@@ -953,3 +954,377 @@ def test_unsupported_normalization_mode_error():
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||||
normalizer(transition)
|
normalizer(transition)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_basic_functionality():
    """hotswap_stats installs the new stats on both the normalizer and the unnormalizer step."""
    feature_specs = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    # Stats the pipeline starts with, and the ones swapped in afterwards.
    initial_stats = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
        "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
    }

    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=initial_stats),
            UnnormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=initial_stats),
            IdentityProcessor(),
        ]
    )

    swapped = hotswap_stats(pipeline, new_stats)

    # Raw stats on both normalization steps must be the new ones.
    assert swapped.steps[0].stats == new_stats
    assert swapped.steps[1].stats == new_stats

    # The cached tensor form must match a fresh conversion of the new stats.
    expected = _convert_stats_to_tensors(new_stats)
    for key, per_key in expected.items():
        for stat_name, wanted in per_key.items():
            torch.testing.assert_close(swapped.steps[0]._tensor_stats[key][stat_name], wanted)
            torch.testing.assert_close(swapped.steps[1]._tensor_stats[key][stat_name], wanted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_deep_copy():
    """hotswap_stats works on a deep copy and leaves the processor it was given untouched."""
    initial_stats = {
        "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
    }

    original = RobotProcessor(
        steps=[
            NormalizerProcessor(
                features={"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))},
                norm_map={FeatureType.VISUAL: NormalizationMode.MEAN_STD},
                stats=initial_stats,
            )
        ]
    )

    # Remember the exact objects held by the original before swapping.
    stats_before = original.steps[0].stats
    tensor_stats_before = original.steps[0]._tensor_stats

    swapped = hotswap_stats(original, new_stats)

    # The original still holds the very same objects, with the same content.
    assert original.steps[0].stats is stats_before
    assert original.steps[0]._tensor_stats is tensor_stats_before
    assert original.steps[0].stats == initial_stats

    # The copy carries the new stats instead.
    assert swapped.steps[0].stats == new_stats
    assert swapped.steps[0].stats is not stats_before

    # Neither the pipeline nor its step is shared between the two.
    assert swapped is not original
    assert swapped.steps[0] is not original.steps[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_only_affects_normalizer_steps():
    """Only NormalizerProcessor / UnnormalizerProcessor steps receive the new stats."""
    old_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
    }
    feature_specs = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
    }
    mode_by_type = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}

    # An identity step sits between the two normalization steps.
    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=old_stats),
            IdentityProcessor(),
            UnnormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=old_stats),
        ]
    )

    swapped = hotswap_stats(pipeline, new_stats)

    assert swapped.steps[0].stats == new_stats  # normalizer
    assert swapped.steps[2].stats == new_stats  # unnormalizer

    # The identity step must not have been given a stats attribute.
    assert not hasattr(swapped.steps[1], "stats")
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_empty_stats():
    """Swapping in an empty stats dict empties both the raw and the tensor stats."""
    initial_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    empty_stats = {}

    pipeline = RobotProcessor(
        steps=[
            NormalizerProcessor(
                features={"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))},
                norm_map={FeatureType.VISUAL: NormalizationMode.MEAN_STD},
                stats=initial_stats,
            )
        ]
    )

    swapped = hotswap_stats(pipeline, empty_stats)

    assert swapped.steps[0].stats == empty_stats
    assert swapped.steps[0]._tensor_stats == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_no_normalizer_steps():
    """A pipeline without normalization steps is still deep-copied and runs without error."""
    stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }

    # Nothing in this pipeline can accept stats.
    pipeline = RobotProcessor(steps=[IdentityProcessor(), IdentityProcessor()])

    swapped = hotswap_stats(pipeline, stats)

    # A new pipeline object comes back ...
    assert swapped is not pipeline

    # ... holding the same number of steps, each a distinct object of the same type.
    assert len(swapped.steps) == len(pipeline.steps)
    for copied_step, source_step in zip(swapped.steps, pipeline.steps):
        assert copied_step is not source_step
        assert isinstance(copied_step, type(source_step))
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_preserves_other_attributes():
    """Apart from the stats, the copied normalizer keeps its configuration."""
    initial_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
    }

    # Configuration that must survive the swap unchanged.
    config = {
        "features": {
            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        },
        "norm_map": {FeatureType.VISUAL: NormalizationMode.MEAN_STD},
        "normalize_keys": {"observation.image"},
        "eps": 1e-6,
    }

    pipeline = RobotProcessor(steps=[NormalizerProcessor(stats=initial_stats, **config)])

    swapped_step = hotswap_stats(pipeline, new_stats).steps[0]

    assert swapped_step.features == config["features"]
    assert swapped_step.norm_map == config["norm_map"]
    assert swapped_step.normalize_keys == config["normalize_keys"]
    assert swapped_step.eps == config["eps"]

    # Only the stats are expected to differ.
    assert swapped_step.stats == new_stats
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_multiple_normalizer_types():
    """Every normalizer and unnormalizer in a longer pipeline receives the new stats."""
    initial_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
        "action": {"min": np.array([-1.0]), "max": np.array([1.0])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
        "action": {"min": np.array([-2.0]), "max": np.array([2.0])},
    }
    feature_specs = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }

    # Two normalizer/unnormalizer pairs, interleaved.
    step_types = (NormalizerProcessor, UnnormalizerProcessor, NormalizerProcessor, UnnormalizerProcessor)
    pipeline = RobotProcessor(
        steps=[
            step_type(features=feature_specs, norm_map=mode_by_type, stats=initial_stats)
            for step_type in step_types
        ]
    )

    swapped = hotswap_stats(pipeline, new_stats)

    expected = _convert_stats_to_tensors(new_stats)
    for step in swapped.steps:
        assert step.stats == new_stats

        # The cached tensors must agree with a fresh conversion of the new stats.
        for key, per_key in expected.items():
            for stat_name, wanted in per_key.items():
                torch.testing.assert_close(step._tensor_stats[key][stat_name], wanted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_with_different_data_types():
    """Stats given as list, tuple, int, float, ndarray or tensor all end up as tensors."""
    initial_stats = {
        "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
    }

    # One statistic per supported input type.
    new_stats = {
        "observation.image": {
            "mean": [0.3, 0.4, 0.5],  # list
            "std": (0.1, 0.2, 0.3),  # tuple
            "min": 0,  # int
            "max": 1.0,  # float
        },
        "action": {
            "mean": np.array([0.1, 0.2]),  # numpy array
            "std": torch.tensor([0.5, 0.6]),  # torch tensor
        },
    }

    feature_specs = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    pipeline = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=initial_stats)]
    )

    swapped = hotswap_stats(pipeline, new_stats)

    assert swapped.steps[0].stats == new_stats

    # Every entry must have been converted to a tensor.
    tensor_stats = swapped.steps[0]._tensor_stats
    image_tensors = tensor_stats["observation.image"]
    for stat_name in ("mean", "std", "min", "max"):
        assert isinstance(image_tensors[stat_name], torch.Tensor)
    for stat_name in ("mean", "std"):
        assert isinstance(tensor_stats["action"][stat_name], torch.Tensor)

    # And the image values must have survived the conversion.
    torch.testing.assert_close(image_tensors["mean"], torch.tensor([0.3, 0.4, 0.5]))
    torch.testing.assert_close(image_tensors["std"], torch.tensor([0.1, 0.2, 0.3]))
    torch.testing.assert_close(image_tensors["min"], torch.tensor(0.0))
    torch.testing.assert_close(image_tensors["max"], torch.tensor(1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_hotswap_stats_functional_test():
    """A hot-swapped pipeline normalizes with the new stats, giving different outputs."""
    observation = {
        "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]),
    }
    action = torch.tensor([0.5, -0.5])
    transition = create_transition(observation=observation, action=action)

    initial_stats = {
        "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])},
        "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
    }
    new_stats = {
        "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])},
        "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])},
    }

    feature_specs = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)),
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
    }
    mode_by_type = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    original = RobotProcessor(
        steps=[NormalizerProcessor(features=feature_specs, norm_map=mode_by_type, stats=initial_stats)]
    )

    # Run the same transition through the pipeline before and after the swap.
    result_before = original(transition)
    swapped = hotswap_stats(original, new_stats)
    result_after = swapped(transition)

    # Different stats must yield different normalized values.
    assert not torch.allclose(
        result_before["observation"]["observation.image"],
        result_after["observation"]["observation.image"],
        rtol=1e-3,
        atol=1e-3,
    )
    assert not torch.allclose(result_before["action"], result_after["action"], rtol=1e-3, atol=1e-3)

    # The swapped pipeline must hold the new stats, both raw and as tensors.
    swapped_step = swapped.steps[0]
    assert swapped_step.stats == new_stats
    image_tensors = swapped_step._tensor_stats["observation.image"]
    action_tensors = swapped_step._tensor_stats["action"]
    assert torch.allclose(image_tensors["mean"], torch.tensor([0.3, 0.2]))
    assert torch.allclose(image_tensors["std"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(action_tensors["mean"], torch.tensor([0.1, -0.1]))
    assert torch.allclose(action_tensors["std"], torch.tensor([0.5, 0.5]))

    # Normalization must actually have changed the data.
    assert not torch.allclose(result_after["observation"]["observation.image"], observation["observation.image"])
    assert not torch.allclose(result_after["action"], action)
|
||||||
|
|||||||
Reference in New Issue
Block a user