mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
refactor(normalization): Remove unused state dict transformation methods and streamline imports
- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
This commit is contained in:
committed by
Steven Palma
parent
f02ce69df0
commit
8ff95be04c
@@ -65,8 +65,7 @@ from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -245,99 +244,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||
) -> "PI0Policy":
|
||||
"""Override to apply key transformations before loading."""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
init_logging()
|
||||
# Load the state dict from file safely
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
|
||||
# Apply key transformations
|
||||
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Load the transformed state dict
|
||||
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||
|
||||
# Log message
|
||||
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
@@ -141,6 +141,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
||||
@@ -134,6 +134,9 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from .batch_processor import ToBatchProcessor
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
@@ -43,6 +43,7 @@ __all__ = [
|
||||
"InfoProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"hotswap_stats",
|
||||
"ObservationProcessor",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -10,7 +12,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey, RobotProcessor
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
@@ -402,3 +404,13 @@ class UnnormalizerProcessor:
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||
robot_processor = deepcopy(robot_processor)
|
||||
for step in robot_processor.steps:
|
||||
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
|
||||
step: NormalizerProcessor | UnnormalizerProcessor
|
||||
step.stats = stats
|
||||
step._tensor_stats = _convert_stats_to_tensors(stats)
|
||||
return robot_processor
|
||||
|
||||
@@ -24,8 +24,9 @@ from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -953,3 +954,377 @@ def test_unsupported_normalization_mode_error():
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
normalizer(transition)
|
||||
|
||||
|
||||
def test_hotswap_stats_basic_functionality():
|
||||
"""Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps."""
|
||||
# Create initial stats
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# Create new stats for hotswapping
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
"action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
# Create features and norm_map
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create processors
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
identity = IdentityProcessor()
|
||||
|
||||
# Create robot processor
|
||||
robot_processor = RobotProcessor(steps=[normalizer, unnormalizer, identity])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
|
||||
# Check that normalizer and unnormalizer have new stats
|
||||
assert new_processor.steps[0].stats == new_stats
|
||||
assert new_processor.steps[1].stats == new_stats
|
||||
|
||||
# Check that tensor stats are updated correctly
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
new_processor.steps[0]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
new_processor.steps[1]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
)
|
||||
|
||||
|
||||
def test_hotswap_stats_deep_copy():
|
||||
"""Test that hotswap_stats creates a deep copy and doesn't modify the original processor."""
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])},
|
||||
}
|
||||
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Store reference to original stats
|
||||
original_stats_reference = original_processor.steps[0].stats
|
||||
original_tensor_stats_reference = original_processor.steps[0]._tensor_stats
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(original_processor, new_stats)
|
||||
|
||||
# Original processor should be unchanged
|
||||
assert original_processor.steps[0].stats is original_stats_reference
|
||||
assert original_processor.steps[0]._tensor_stats is original_tensor_stats_reference
|
||||
assert original_processor.steps[0].stats == initial_stats
|
||||
|
||||
# New processor should have new stats
|
||||
assert new_processor.steps[0].stats == new_stats
|
||||
assert new_processor.steps[0].stats is not original_stats_reference
|
||||
|
||||
# Processors should be different objects
|
||||
assert new_processor is not original_processor
|
||||
assert new_processor.steps[0] is not original_processor.steps[0]
|
||||
|
||||
|
||||
def test_hotswap_stats_only_affects_normalizer_steps():
|
||||
"""Test that hotswap_stats only modifies NormalizerProcessor and UnnormalizerProcessor steps."""
|
||||
stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
# Create mixed steps
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
identity = IdentityProcessor()
|
||||
|
||||
robot_processor = RobotProcessor(steps=[normalizer, identity, unnormalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
|
||||
# Check that only normalizer and unnormalizer steps are affected
|
||||
assert new_processor.steps[0].stats == new_stats # normalizer
|
||||
assert new_processor.steps[2].stats == new_stats # unnormalizer
|
||||
|
||||
# Identity processor should remain unchanged (and it doesn't have stats attribute)
|
||||
assert not hasattr(new_processor.steps[1], "stats")
|
||||
|
||||
|
||||
def test_hotswap_stats_empty_stats():
|
||||
"""Test hotswap_stats with empty stats dictionary."""
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
|
||||
empty_stats = {}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Hotswap with empty stats
|
||||
new_processor = hotswap_stats(robot_processor, empty_stats)
|
||||
|
||||
# Should update to empty stats
|
||||
assert new_processor.steps[0].stats == empty_stats
|
||||
assert new_processor.steps[0]._tensor_stats == {}
|
||||
|
||||
|
||||
def test_hotswap_stats_no_normalizer_steps():
|
||||
"""Test hotswap_stats with a processor that has no normalizer/unnormalizer steps."""
|
||||
stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
|
||||
# Create processor with only identity steps
|
||||
robot_processor = RobotProcessor(steps=[IdentityProcessor(), IdentityProcessor()])
|
||||
|
||||
# Hotswap stats - should work without error
|
||||
new_processor = hotswap_stats(robot_processor, stats)
|
||||
|
||||
# Should return a different object (deep copy)
|
||||
assert new_processor is not robot_processor
|
||||
|
||||
# Steps should be deep copied but unchanged
|
||||
assert len(new_processor.steps) == len(robot_processor.steps)
|
||||
for i, step in enumerate(new_processor.steps):
|
||||
assert step is not robot_processor.steps[i] # Different objects
|
||||
assert isinstance(step, type(robot_processor.steps[i])) # Same type
|
||||
|
||||
|
||||
def test_hotswap_stats_preserves_other_attributes():
|
||||
"""Test that hotswap_stats preserves other processor attributes like features and norm_map."""
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalize_keys = {"observation.image"}
|
||||
eps = 1e-6
|
||||
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps
|
||||
)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
|
||||
# Check that other attributes are preserved
|
||||
new_normalizer = new_processor.steps[0]
|
||||
assert new_normalizer.features == features
|
||||
assert new_normalizer.norm_map == norm_map
|
||||
assert new_normalizer.normalize_keys == normalize_keys
|
||||
assert new_normalizer.eps == eps
|
||||
|
||||
# But stats should be updated
|
||||
assert new_normalizer.stats == new_stats
|
||||
|
||||
|
||||
def test_hotswap_stats_multiple_normalizer_types():
|
||||
"""Test hotswap_stats with multiple normalizer and unnormalizer steps."""
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
"action": {"min": np.array([-1.0]), "max": np.array([1.0])},
|
||||
}
|
||||
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])},
|
||||
"action": {"min": np.array([-2.0]), "max": np.array([2.0])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Create multiple normalizers and unnormalizers
|
||||
normalizer1 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
normalizer2 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer1 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer2 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
|
||||
robot_processor = RobotProcessor(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
|
||||
# All normalizer/unnormalizer steps should be updated
|
||||
for step in new_processor.steps:
|
||||
assert step.stats == new_stats
|
||||
|
||||
# Check tensor stats conversion
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
step._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
)
|
||||
|
||||
|
||||
def test_hotswap_stats_with_different_data_types():
|
||||
"""Test hotswap_stats with various data types in stats."""
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
|
||||
# New stats with different data types (int, float, list, tuple)
|
||||
new_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.3, 0.4, 0.5], # list
|
||||
"std": (0.1, 0.2, 0.3), # tuple
|
||||
"min": 0, # int
|
||||
"max": 1.0, # float
|
||||
},
|
||||
"action": {
|
||||
"mean": np.array([0.1, 0.2]), # numpy array
|
||||
"std": torch.tensor([0.5, 0.6]), # torch tensor
|
||||
},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
|
||||
# Check that stats are updated
|
||||
assert new_processor.steps[0].stats == new_stats
|
||||
|
||||
# Check that tensor conversion worked correctly
|
||||
tensor_stats = new_processor.steps[0]._tensor_stats
|
||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["action"]["std"], torch.Tensor)
|
||||
|
||||
# Check values
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5]))
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3]))
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0))
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0))
|
||||
|
||||
|
||||
def test_hotswap_stats_functional_test():
|
||||
"""Test that hotswapped processor actually works functionally."""
|
||||
# Create test data
|
||||
observation = {
|
||||
"observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]),
|
||||
}
|
||||
action = torch.tensor([0.5, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Initial stats
|
||||
initial_stats = {
|
||||
"observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])},
|
||||
"action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])},
|
||||
}
|
||||
|
||||
# New stats
|
||||
new_stats = {
|
||||
"observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])},
|
||||
"action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
# Create original processor
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
# Process with original stats
|
||||
original_result = original_processor(transition)
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(original_processor, new_stats)
|
||||
|
||||
# Process with new stats
|
||||
new_result = new_processor(transition)
|
||||
|
||||
# Results should be different since normalization changed
|
||||
assert not torch.allclose(
|
||||
original_result["observation"]["observation.image"],
|
||||
new_result["observation"]["observation.image"],
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Verify that the new processor is actually using the new stats by checking internal state
|
||||
assert new_processor.steps[0].stats == new_stats
|
||||
assert torch.allclose(
|
||||
new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2])
|
||||
)
|
||||
assert torch.allclose(
|
||||
new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2])
|
||||
)
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1]))
|
||||
assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5]))
|
||||
|
||||
# Test that normalization actually happens (output should not equal input)
|
||||
assert not torch.allclose(
|
||||
new_result["observation"]["observation.image"], observation["observation.image"]
|
||||
)
|
||||
assert not torch.allclose(new_result["action"], action)
|
||||
|
||||
Reference in New Issue
Block a user