mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Refactor normalization components and update tests
- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes.
This commit is contained in:
@@ -5,9 +5,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.normalize_processor import (
|
||||
ActionUnnormalizer,
|
||||
NormalizationProcessor,
|
||||
ObservationNormalizer,
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
@@ -77,7 +76,7 @@ def test_unsupported_type():
|
||||
_convert_stats_to_tensors(stats)
|
||||
|
||||
|
||||
# Fixtures for ObservationNormalizer tests
|
||||
# Fixtures for observation normalisation tests using NormalizerProcessor
|
||||
@pytest.fixture
|
||||
def observation_stats():
|
||||
return {
|
||||
@@ -94,7 +93,8 @@ def observation_stats():
|
||||
|
||||
@pytest.fixture
|
||||
def observation_normalizer(observation_stats):
|
||||
return ObservationNormalizer(stats=observation_stats)
|
||||
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
||||
return NormalizerProcessor(stats=observation_stats)
|
||||
|
||||
|
||||
def test_mean_std_normalization(observation_normalizer):
|
||||
@@ -129,7 +129,7 @@ def test_min_max_normalization(observation_normalizer):
|
||||
|
||||
|
||||
def test_selective_normalization(observation_stats):
|
||||
normalizer = ObservationNormalizer(stats=observation_stats, normalize_keys={"observation.image"})
|
||||
normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"})
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -146,46 +146,9 @@ def test_selective_normalization(observation_stats):
|
||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||
|
||||
|
||||
def test_missing_stats_error(observation_stats):
|
||||
normalizer = ObservationNormalizer(
|
||||
stats={"observation.image": observation_stats["observation.image"]},
|
||||
normalize_keys={"observation.image", "observation.missing"},
|
||||
)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.5, 0.5, 0.5]),
|
||||
"observation.missing": torch.tensor([1.0, 2.0]),
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
with pytest.raises(KeyError, match="Stats not found for requested key 'observation.missing'"):
|
||||
normalizer(transition)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_type,input_value,expected_type",
|
||||
[
|
||||
("numpy", np.array([0.7, 0.5, 0.3], dtype=np.float32), torch.Tensor),
|
||||
("torch", torch.tensor([0.7, 0.5, 0.3]), torch.Tensor),
|
||||
],
|
||||
)
|
||||
def test_input_types(observation_normalizer, input_type, input_value, expected_type):
|
||||
observation = {
|
||||
"observation.image": input_value,
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
normalized_transition = observation_normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert isinstance(normalized_obs["observation.image"], expected_type)
|
||||
assert torch.allclose(normalized_obs["observation.image"], expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_device_compatibility(observation_stats):
|
||||
normalizer = ObservationNormalizer(stats=observation_stats)
|
||||
normalizer = NormalizerProcessor(stats=observation_stats)
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||
}
|
||||
@@ -205,11 +168,11 @@ def test_from_lerobot_dataset():
|
||||
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out
|
||||
}
|
||||
|
||||
normalizer = ObservationNormalizer.from_lerobot_dataset(mock_dataset)
|
||||
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||
|
||||
# Check that action stats are filtered out
|
||||
# Both observation and action statistics should be present in tensor stats
|
||||
assert "observation.image" in normalizer._tensor_stats
|
||||
assert "action" not in normalizer._tensor_stats
|
||||
assert "action" in normalizer._tensor_stats
|
||||
|
||||
|
||||
def test_state_dict_save_load(observation_normalizer):
|
||||
@@ -217,7 +180,7 @@ def test_state_dict_save_load(observation_normalizer):
|
||||
state_dict = observation_normalizer.state_dict()
|
||||
|
||||
# Create new normalizer and load state
|
||||
new_normalizer = ObservationNormalizer(stats={})
|
||||
new_normalizer = NormalizerProcessor(stats={})
|
||||
new_normalizer.load_state_dict(state_dict)
|
||||
|
||||
# Test that it works the same
|
||||
@@ -248,7 +211,7 @@ def action_stats_min_max():
|
||||
|
||||
|
||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||
transition = (None, normalized_action, None, None, None, None, None)
|
||||
@@ -262,7 +225,7 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
|
||||
|
||||
def test_min_max_unnormalization(action_stats_min_max):
|
||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_min_max)
|
||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max})
|
||||
|
||||
# Actions in [-1, 1]
|
||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||
@@ -284,7 +247,7 @@ def test_min_max_unnormalization(action_stats_min_max):
|
||||
|
||||
|
||||
def test_numpy_action_input(action_stats_mean_std):
|
||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||
|
||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||
transition = (None, normalized_action, None, None, None, None, None)
|
||||
@@ -298,7 +261,7 @@ def test_numpy_action_input(action_stats_mean_std):
|
||||
|
||||
|
||||
def test_none_action(action_stats_mean_std):
|
||||
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
result = unnormalizer(transition)
|
||||
@@ -308,40 +271,13 @@ def test_none_action(action_stats_mean_std):
|
||||
|
||||
|
||||
def test_action_from_lerobot_dataset():
|
||||
# Mock dataset
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {
|
||||
"action": {"mean": [0.0], "std": [1.0]},
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
}
|
||||
|
||||
unnormalizer = ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
||||
|
||||
assert "mean" in unnormalizer._tensor_stats
|
||||
assert "std" in unnormalizer._tensor_stats
|
||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||
|
||||
|
||||
def test_missing_action_stats_error():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset does not contain action statistics"):
|
||||
ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
||||
|
||||
|
||||
def test_invalid_stats_error():
|
||||
unnormalizer = ActionUnnormalizer(action_stats={"invalid": [1.0]})
|
||||
|
||||
action = torch.tensor([1.0])
|
||||
transition = (None, action, None, None, None, None, None)
|
||||
|
||||
with pytest.raises(ValueError, match="Action stats must contain"):
|
||||
unnormalizer(transition)
|
||||
|
||||
|
||||
# Fixtures for NormalizationProcessor tests
|
||||
# Fixtures for NormalizerProcessor tests
|
||||
@pytest.fixture
|
||||
def full_stats():
|
||||
return {
|
||||
@@ -361,11 +297,11 @@ def full_stats():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def normalization_processor(full_stats):
|
||||
return NormalizationProcessor(stats=full_stats)
|
||||
def normalizer_processor(full_stats):
|
||||
return NormalizerProcessor(stats=full_stats)
|
||||
|
||||
|
||||
def test_combined_normalization_unnormalization(normalization_processor):
|
||||
def test_combined_normalization(normalizer_processor):
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
@@ -373,16 +309,16 @@ def test_combined_normalization_unnormalization(normalization_processor):
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = (observation, action, 1.0, False, False, {}, {})
|
||||
|
||||
processed_transition = normalization_processor(transition)
|
||||
processed_transition = normalizer_processor(transition)
|
||||
|
||||
# Check normalized observations
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
||||
|
||||
# Check unnormalized action
|
||||
# Check normalized action
|
||||
processed_action = processed_transition[TransitionIndex.ACTION]
|
||||
expected_action = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0])
|
||||
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
|
||||
assert torch.allclose(processed_action, expected_action)
|
||||
|
||||
# Check other fields remain unchanged
|
||||
@@ -390,45 +326,28 @@ def test_combined_normalization_unnormalization(normalization_processor):
|
||||
assert not processed_transition[TransitionIndex.DONE]
|
||||
|
||||
|
||||
def test_disable_action_unnormalization(full_stats):
|
||||
processor = NormalizationProcessor(stats=full_stats, unnormalize_action=False)
|
||||
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = (None, action, None, None, None, None, None)
|
||||
|
||||
processed_transition = processor(transition)
|
||||
|
||||
# Action should remain unchanged
|
||||
assert torch.allclose(processed_transition[TransitionIndex.ACTION], action)
|
||||
|
||||
|
||||
def test_processor_from_lerobot_dataset(full_stats):
|
||||
# Mock dataset
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = full_stats
|
||||
|
||||
processor = NormalizationProcessor.from_lerobot_dataset(
|
||||
mock_dataset, normalize_keys={"observation.image"}, unnormalize_action=True
|
||||
)
|
||||
processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"})
|
||||
|
||||
assert processor.normalize_keys == {"observation.image"}
|
||||
assert processor.unnormalize_action
|
||||
assert "observation.image" in processor._tensor_stats
|
||||
assert "action" in processor._tensor_stats
|
||||
|
||||
|
||||
def test_get_config(full_stats):
|
||||
processor = NormalizationProcessor(
|
||||
stats=full_stats, normalize_keys={"observation.image"}, unnormalize_action=False, eps=1e-6
|
||||
)
|
||||
processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"normalize_keys": ["observation.image"], "unnormalize_action": False, "eps": 1e-6}
|
||||
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
|
||||
|
||||
|
||||
def test_integration_with_robot_processor(normalization_processor):
|
||||
def test_integration_with_robot_processor(normalizer_processor):
|
||||
"""Test integration with RobotProcessor pipeline"""
|
||||
robot_processor = RobotProcessor([normalization_processor])
|
||||
robot_processor = RobotProcessor([normalizer_processor])
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -447,7 +366,7 @@ def test_integration_with_robot_processor(normalization_processor):
|
||||
# Edge case tests
|
||||
def test_empty_observation():
|
||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
normalizer = ObservationNormalizer(stats=stats)
|
||||
normalizer = NormalizerProcessor(stats=stats)
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
result = normalizer(transition)
|
||||
@@ -456,7 +375,7 @@ def test_empty_observation():
|
||||
|
||||
|
||||
def test_empty_stats():
|
||||
normalizer = ObservationNormalizer(stats={})
|
||||
normalizer = NormalizerProcessor(stats={})
|
||||
observation = {"observation.image": torch.tensor([0.5])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
@@ -466,12 +385,20 @@ def test_empty_stats():
|
||||
|
||||
|
||||
def test_partial_stats():
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5]}, # Missing std
|
||||
}
|
||||
normalizer = ObservationNormalizer(stats=stats)
|
||||
"""If statistics are incomplete, the value should pass through unchanged."""
|
||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||
normalizer = NormalizerProcessor(stats=stats)
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
with pytest.raises(ValueError, match="must contain either"):
|
||||
normalizer(transition)
|
||||
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
|
||||
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
||||
|
||||
|
||||
def test_missing_action_stats_no_error():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
|
||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
||||
# The tensor stats should not contain the 'action' key
|
||||
assert "action" not in processor._tensor_stats
|
||||
|
||||
Reference in New Issue
Block a user