mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
Enhance processing architecture with new components
- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
This commit is contained in:
@@ -20,6 +20,7 @@ from .observation_processor import (
|
|||||||
StateProcessor,
|
StateProcessor,
|
||||||
)
|
)
|
||||||
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
|
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
|
||||||
|
from .rename_processor import RenameProcessor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RobotProcessor",
|
"RobotProcessor",
|
||||||
@@ -29,4 +30,5 @@ __all__ = [
|
|||||||
"StateProcessor",
|
"StateProcessor",
|
||||||
"ObservationProcessor",
|
"ObservationProcessor",
|
||||||
"NormalizationProcessor",
|
"NormalizationProcessor",
|
||||||
|
"RenameProcessor",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class ObservationNormalizer:
|
|||||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||||
self._tensor_stats.clear()
|
self._tensor_stats.clear()
|
||||||
for flat_key, tensor in state.items():
|
for flat_key, tensor in state.items():
|
||||||
key, stat_name = flat_key.split(".", 1)
|
key, stat_name = flat_key.rsplit(".", 1)
|
||||||
if key not in self._tensor_stats:
|
if key not in self._tensor_stats:
|
||||||
self._tensor_stats[key] = {}
|
self._tensor_stats[key] = {}
|
||||||
self._tensor_stats[key][stat_name] = tensor
|
self._tensor_stats[key][stat_name] = tensor
|
||||||
@@ -382,7 +382,7 @@ class NormalizationProcessor:
|
|||||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||||
self._tensor_stats.clear()
|
self._tensor_stats.clear()
|
||||||
for flat_key, tensor in state.items():
|
for flat_key, tensor in state.items():
|
||||||
key, stat_name = flat_key.split(".", 1)
|
key, stat_name = flat_key.rsplit(".", 1)
|
||||||
if key not in self._tensor_stats:
|
if key not in self._tensor_stats:
|
||||||
self._tensor_stats[key] = {}
|
self._tensor_stats[key] = {}
|
||||||
self._tensor_stats[key][stat_name] = tensor
|
self._tensor_stats[key][stat_name] = tensor
|
||||||
|
|||||||
@@ -464,3 +464,95 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
profile_results[step_name] = avg_time
|
profile_results[step_name] = avg_time
|
||||||
|
|
||||||
return profile_results
|
return profile_results
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationProcessor:
|
||||||
|
def observation(self, observation):
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = transition[TransitionIndex.OBSERVATION]
|
||||||
|
observation = self.observation(observation)
|
||||||
|
transition = (observation, *transition[TransitionIndex.ACTION :])
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class ActionProcessor:
|
||||||
|
def action(self, action):
|
||||||
|
return action
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition[TransitionIndex.ACTION]
|
||||||
|
action = self.action(action)
|
||||||
|
transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :])
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class RewardProcessor:
|
||||||
|
def reward(self, reward):
|
||||||
|
return reward
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
reward = transition[TransitionIndex.REWARD]
|
||||||
|
reward = self.reward(reward)
|
||||||
|
transition = (
|
||||||
|
transition[TransitionIndex.OBSERVATION],
|
||||||
|
transition[TransitionIndex.ACTION],
|
||||||
|
reward,
|
||||||
|
*transition[TransitionIndex.DONE :],
|
||||||
|
)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class DoneProcessor:
|
||||||
|
def done(self, done):
|
||||||
|
return done
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
done = transition[TransitionIndex.DONE]
|
||||||
|
done = self.done(done)
|
||||||
|
transition = (
|
||||||
|
transition[TransitionIndex.OBSERVATION],
|
||||||
|
transition[TransitionIndex.ACTION],
|
||||||
|
transition[TransitionIndex.REWARD],
|
||||||
|
done,
|
||||||
|
*transition[TransitionIndex.TRUNCATED :],
|
||||||
|
)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class TruncatedProcessor:
|
||||||
|
def truncated(self, truncated):
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
truncated = transition[TransitionIndex.TRUNCATED]
|
||||||
|
truncated = self.truncated(truncated)
|
||||||
|
transition = (
|
||||||
|
transition[TransitionIndex.OBSERVATION],
|
||||||
|
transition[TransitionIndex.ACTION],
|
||||||
|
transition[TransitionIndex.REWARD],
|
||||||
|
transition[TransitionIndex.DONE],
|
||||||
|
truncated,
|
||||||
|
*transition[TransitionIndex.INFO :],
|
||||||
|
)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class InfoProcessor:
|
||||||
|
def info(self, info):
|
||||||
|
return info
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
info = transition[TransitionIndex.INFO]
|
||||||
|
info = self.info(info)
|
||||||
|
transition = (
|
||||||
|
transition[TransitionIndex.OBSERVATION],
|
||||||
|
transition[TransitionIndex.ACTION],
|
||||||
|
transition[TransitionIndex.REWARD],
|
||||||
|
transition[TransitionIndex.DONE],
|
||||||
|
transition[TransitionIndex.TRUNCATED],
|
||||||
|
info,
|
||||||
|
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
|
||||||
|
)
|
||||||
|
return transition
|
||||||
|
|||||||
@@ -0,0 +1,477 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor.normalize_processor import (
|
||||||
|
ActionUnnormalizer,
|
||||||
|
NormalizationProcessor,
|
||||||
|
ObservationNormalizer,
|
||||||
|
_convert_stats_to_tensors,
|
||||||
|
)
|
||||||
|
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_conversion():
|
||||||
|
stats = {
|
||||||
|
"observation.image": {
|
||||||
|
"mean": np.array([0.5, 0.5, 0.5]),
|
||||||
|
"std": np.array([0.2, 0.2, 0.2]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_stats = _convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||||
|
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||||
|
assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5]))
|
||||||
|
assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_conversion():
|
||||||
|
stats = {
|
||||||
|
"action": {
|
||||||
|
"mean": torch.tensor([0.0, 0.0]),
|
||||||
|
"std": torch.tensor([1.0, 1.0]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_stats = _convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||||
|
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_scalar_conversion():
|
||||||
|
stats = {
|
||||||
|
"reward": {
|
||||||
|
"mean": 0.5,
|
||||||
|
"std": 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_stats = _convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
||||||
|
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_conversion():
|
||||||
|
stats = {
|
||||||
|
"observation.state": {
|
||||||
|
"min": [0.0, -1.0, -2.0],
|
||||||
|
"max": [1.0, 1.0, 2.0],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_stats = _convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
||||||
|
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsupported_type():
|
||||||
|
stats = {
|
||||||
|
"bad_key": {
|
||||||
|
"mean": "string_value",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with pytest.raises(TypeError, match="Unsupported type"):
|
||||||
|
_convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for ObservationNormalizer tests
|
||||||
|
@pytest.fixture
|
||||||
|
def observation_stats():
|
||||||
|
return {
|
||||||
|
"observation.image": {
|
||||||
|
"mean": np.array([0.5, 0.5, 0.5]),
|
||||||
|
"std": np.array([0.2, 0.2, 0.2]),
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"min": np.array([0.0, -1.0]),
|
||||||
|
"max": np.array([1.0, 1.0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def observation_normalizer(observation_stats):
|
||||||
|
return ObservationNormalizer(stats=observation_stats)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mean_std_normalization(observation_normalizer):
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
normalized_transition = observation_normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check mean/std normalization
|
||||||
|
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], expected_image)
|
||||||
|
|
||||||
|
|
||||||
|
def test_min_max_normalization(observation_normalizer):
|
||||||
|
observation = {
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
normalized_transition = observation_normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check min/max normalization to [-1, 1]
|
||||||
|
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
|
||||||
|
# For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0
|
||||||
|
expected_state = torch.tensor([0.0, 0.0])
|
||||||
|
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_selective_normalization(observation_stats):
|
||||||
|
normalizer = ObservationNormalizer(stats=observation_stats, normalize_keys={"observation.image"})
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Only image should be normalized
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
|
||||||
|
# State should remain unchanged
|
||||||
|
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_stats_error(observation_stats):
|
||||||
|
normalizer = ObservationNormalizer(
|
||||||
|
stats={"observation.image": observation_stats["observation.image"]},
|
||||||
|
normalize_keys={"observation.image", "observation.missing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.5, 0.5, 0.5]),
|
||||||
|
"observation.missing": torch.tensor([1.0, 2.0]),
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="Stats not found for requested key 'observation.missing'"):
|
||||||
|
normalizer(transition)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_type,input_value,expected_type",
|
||||||
|
[
|
||||||
|
("numpy", np.array([0.7, 0.5, 0.3], dtype=np.float32), torch.Tensor),
|
||||||
|
("torch", torch.tensor([0.7, 0.5, 0.3]), torch.Tensor),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_input_types(observation_normalizer, input_type, input_value, expected_type):
|
||||||
|
observation = {
|
||||||
|
"observation.image": input_value,
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
normalized_transition = observation_normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||||
|
assert isinstance(normalized_obs["observation.image"], expected_type)
|
||||||
|
assert torch.allclose(normalized_obs["observation.image"], expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
|
def test_device_compatibility(observation_stats):
|
||||||
|
normalizer = ObservationNormalizer(stats=observation_stats)
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
normalized_transition = normalizer(transition)
|
||||||
|
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
assert normalized_obs["observation.image"].device.type == "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_lerobot_dataset():
|
||||||
|
# Mock dataset
|
||||||
|
mock_dataset = Mock()
|
||||||
|
mock_dataset.meta.stats = {
|
||||||
|
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||||
|
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = ObservationNormalizer.from_lerobot_dataset(mock_dataset)
|
||||||
|
|
||||||
|
# Check that action stats are filtered out
|
||||||
|
assert "observation.image" in normalizer._tensor_stats
|
||||||
|
assert "action" not in normalizer._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_save_load(observation_normalizer):
|
||||||
|
# Save state
|
||||||
|
state_dict = observation_normalizer.state_dict()
|
||||||
|
|
||||||
|
# Create new normalizer and load state
|
||||||
|
new_normalizer = ObservationNormalizer(stats={})
|
||||||
|
new_normalizer.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# Test that it works the same
|
||||||
|
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result1 = observation_normalizer(transition)[0]
|
||||||
|
result2 = new_normalizer(transition)[0]
|
||||||
|
|
||||||
|
assert torch.allclose(result1["observation.image"], result2["observation.image"])
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for ActionUnnormalizer tests
|
||||||
|
@pytest.fixture
|
||||||
|
def action_stats_mean_std():
|
||||||
|
return {
|
||||||
|
"mean": np.array([0.0, 0.0, 0.0]),
|
||||||
|
"std": np.array([1.0, 2.0, 0.5]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def action_stats_min_max():
|
||||||
|
return {
|
||||||
|
"min": np.array([-1.0, -2.0, 0.0]),
|
||||||
|
"max": np.array([1.0, 2.0, 1.0]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||||
|
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||||
|
|
||||||
|
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||||
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
|
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||||
|
|
||||||
|
# action * std + mean
|
||||||
|
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
|
||||||
|
assert torch.allclose(unnormalized_action, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_min_max_unnormalization(action_stats_min_max):
|
||||||
|
unnormalizer = ActionUnnormalizer(action_stats=action_stats_min_max)
|
||||||
|
|
||||||
|
# Actions in [-1, 1]
|
||||||
|
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||||
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
|
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||||
|
|
||||||
|
# Map from [-1, 1] to [min, max]
|
||||||
|
# (action + 1) / 2 * (max - min) + min
|
||||||
|
expected = torch.tensor(
|
||||||
|
[
|
||||||
|
(0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0
|
||||||
|
(-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0
|
||||||
|
(1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert torch.allclose(unnormalized_action, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_action_input(action_stats_mean_std):
|
||||||
|
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||||
|
|
||||||
|
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||||
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
|
|
||||||
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||||
|
|
||||||
|
assert isinstance(unnormalized_action, torch.Tensor)
|
||||||
|
expected = torch.tensor([1.0, -1.0, 1.0])
|
||||||
|
assert torch.allclose(unnormalized_action, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_action(action_stats_mean_std):
|
||||||
|
unnormalizer = ActionUnnormalizer(action_stats=action_stats_mean_std)
|
||||||
|
|
||||||
|
transition = (None, None, None, None, None, None, None)
|
||||||
|
result = unnormalizer(transition)
|
||||||
|
|
||||||
|
# Should return transition unchanged
|
||||||
|
assert result == transition
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_from_lerobot_dataset():
|
||||||
|
# Mock dataset
|
||||||
|
mock_dataset = Mock()
|
||||||
|
mock_dataset.meta.stats = {
|
||||||
|
"action": {"mean": [0.0], "std": [1.0]},
|
||||||
|
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||||
|
}
|
||||||
|
|
||||||
|
unnormalizer = ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
||||||
|
|
||||||
|
assert "mean" in unnormalizer._tensor_stats
|
||||||
|
assert "std" in unnormalizer._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_action_stats_error():
|
||||||
|
mock_dataset = Mock()
|
||||||
|
mock_dataset.meta.stats = {
|
||||||
|
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Dataset does not contain action statistics"):
|
||||||
|
ActionUnnormalizer.from_lerobot_dataset(mock_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_stats_error():
|
||||||
|
unnormalizer = ActionUnnormalizer(action_stats={"invalid": [1.0]})
|
||||||
|
|
||||||
|
action = torch.tensor([1.0])
|
||||||
|
transition = (None, action, None, None, None, None, None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Action stats must contain"):
|
||||||
|
unnormalizer(transition)
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for NormalizationProcessor tests
|
||||||
|
@pytest.fixture
|
||||||
|
def full_stats():
|
||||||
|
return {
|
||||||
|
"observation.image": {
|
||||||
|
"mean": np.array([0.5, 0.5, 0.5]),
|
||||||
|
"std": np.array([0.2, 0.2, 0.2]),
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"min": np.array([0.0, -1.0]),
|
||||||
|
"max": np.array([1.0, 1.0]),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": np.array([0.0, 0.0]),
|
||||||
|
"std": np.array([1.0, 2.0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def normalization_processor(full_stats):
|
||||||
|
return NormalizationProcessor(stats=full_stats)
|
||||||
|
|
||||||
|
|
||||||
|
def test_combined_normalization_unnormalization(normalization_processor):
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = (observation, action, 1.0, False, False, {}, {})
|
||||||
|
|
||||||
|
processed_transition = normalization_processor(transition)
|
||||||
|
|
||||||
|
# Check normalized observations
|
||||||
|
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||||
|
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||||
|
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
||||||
|
|
||||||
|
# Check unnormalized action
|
||||||
|
processed_action = processed_transition[TransitionIndex.ACTION]
|
||||||
|
expected_action = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0])
|
||||||
|
assert torch.allclose(processed_action, expected_action)
|
||||||
|
|
||||||
|
# Check other fields remain unchanged
|
||||||
|
assert processed_transition[TransitionIndex.REWARD] == 1.0
|
||||||
|
assert not processed_transition[TransitionIndex.DONE]
|
||||||
|
|
||||||
|
|
||||||
|
def test_disable_action_unnormalization(full_stats):
|
||||||
|
processor = NormalizationProcessor(stats=full_stats, unnormalize_action=False)
|
||||||
|
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = (None, action, None, None, None, None, None)
|
||||||
|
|
||||||
|
processed_transition = processor(transition)
|
||||||
|
|
||||||
|
# Action should remain unchanged
|
||||||
|
assert torch.allclose(processed_transition[TransitionIndex.ACTION], action)
|
||||||
|
|
||||||
|
|
||||||
|
def test_processor_from_lerobot_dataset(full_stats):
|
||||||
|
# Mock dataset
|
||||||
|
mock_dataset = Mock()
|
||||||
|
mock_dataset.meta.stats = full_stats
|
||||||
|
|
||||||
|
processor = NormalizationProcessor.from_lerobot_dataset(
|
||||||
|
mock_dataset, normalize_keys={"observation.image"}, unnormalize_action=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert processor.normalize_keys == {"observation.image"}
|
||||||
|
assert processor.unnormalize_action
|
||||||
|
assert "observation.image" in processor._tensor_stats
|
||||||
|
assert "action" in processor._tensor_stats
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config(full_stats):
|
||||||
|
processor = NormalizationProcessor(
|
||||||
|
stats=full_stats, normalize_keys={"observation.image"}, unnormalize_action=False, eps=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
config = processor.get_config()
|
||||||
|
assert config == {"normalize_keys": ["observation.image"], "unnormalize_action": False, "eps": 1e-6}
|
||||||
|
|
||||||
|
|
||||||
|
def test_integration_with_robot_processor(normalization_processor):
|
||||||
|
"""Test integration with RobotProcessor pipeline"""
|
||||||
|
robot_processor = RobotProcessor([normalization_processor])
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
|
"observation.state": torch.tensor([0.5, 0.0]),
|
||||||
|
}
|
||||||
|
action = torch.tensor([1.0, -0.5])
|
||||||
|
transition = (observation, action, 1.0, False, False, {}, {})
|
||||||
|
|
||||||
|
processed_transition = robot_processor(transition)
|
||||||
|
|
||||||
|
# Verify the processing worked
|
||||||
|
assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict)
|
||||||
|
assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Edge case tests
|
||||||
|
def test_empty_observation():
|
||||||
|
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
|
normalizer = ObservationNormalizer(stats=stats)
|
||||||
|
|
||||||
|
transition = (None, None, None, None, None, None, None)
|
||||||
|
result = normalizer(transition)
|
||||||
|
|
||||||
|
assert result == transition
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_stats():
|
||||||
|
normalizer = ObservationNormalizer(stats={})
|
||||||
|
observation = {"observation.image": torch.tensor([0.5])}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = normalizer(transition)
|
||||||
|
# Should return observation unchanged
|
||||||
|
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_stats():
|
||||||
|
stats = {
|
||||||
|
"observation.image": {"mean": [0.5]}, # Missing std
|
||||||
|
}
|
||||||
|
normalizer = ObservationNormalizer(stats=stats)
|
||||||
|
observation = {"observation.image": torch.tensor([0.7])}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="must contain either"):
|
||||||
|
normalizer(transition)
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex
|
||||||
|
from lerobot.processor.rename_processor import RenameProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_renaming():
|
||||||
|
"""Test basic key renaming functionality."""
|
||||||
|
rename_map = {
|
||||||
|
"old_key1": "new_key1",
|
||||||
|
"old_key2": "new_key2",
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"old_key1": torch.tensor([1.0, 2.0]),
|
||||||
|
"old_key2": np.array([3.0, 4.0]),
|
||||||
|
"unchanged_key": "keep_me",
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check renamed keys
|
||||||
|
assert "new_key1" in processed_obs
|
||||||
|
assert "new_key2" in processed_obs
|
||||||
|
assert "old_key1" not in processed_obs
|
||||||
|
assert "old_key2" not in processed_obs
|
||||||
|
|
||||||
|
# Check values are preserved
|
||||||
|
torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
|
||||||
|
np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))
|
||||||
|
|
||||||
|
# Check unchanged key is preserved
|
||||||
|
assert processed_obs["unchanged_key"] == "keep_me"
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_rename_map():
|
||||||
|
"""Test processor with empty rename map (should pass through unchanged)."""
|
||||||
|
processor = RenameProcessor(rename_map={})
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"key1": torch.tensor([1.0]),
|
||||||
|
"key2": "value2",
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# All keys should be unchanged
|
||||||
|
assert processed_obs.keys() == observation.keys()
|
||||||
|
torch.testing.assert_close(processed_obs["key1"], observation["key1"])
|
||||||
|
assert processed_obs["key2"] == observation["key2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_observation():
|
||||||
|
"""Test processor with None observation."""
|
||||||
|
processor = RenameProcessor(rename_map={"old": "new"})
|
||||||
|
|
||||||
|
transition = (None, None, None, None, None, None, None)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should return transition unchanged
|
||||||
|
assert result == transition
|
||||||
|
|
||||||
|
|
||||||
|
def test_overlapping_rename():
|
||||||
|
"""Test renaming when new names might conflict."""
|
||||||
|
rename_map = {
|
||||||
|
"a": "b",
|
||||||
|
"b": "c", # This creates a potential conflict
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"a": 1,
|
||||||
|
"b": 2,
|
||||||
|
"x": 3,
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check that renaming happens correctly
|
||||||
|
assert "a" not in processed_obs
|
||||||
|
assert processed_obs["b"] == 1 # 'a' renamed to 'b'
|
||||||
|
assert processed_obs["c"] == 2 # original 'b' renamed to 'c'
|
||||||
|
assert processed_obs["x"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_rename():
|
||||||
|
"""Test renaming only some keys."""
|
||||||
|
rename_map = {
|
||||||
|
"observation.state": "observation.proprio_state",
|
||||||
|
"pixels": "observation.image",
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.state": torch.randn(10),
|
||||||
|
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
|
||||||
|
"reward": 1.0,
|
||||||
|
"info": {"episode": 1},
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check renamed keys
|
||||||
|
assert "observation.proprio_state" in processed_obs
|
||||||
|
assert "observation.image" in processed_obs
|
||||||
|
assert "observation.state" not in processed_obs
|
||||||
|
assert "pixels" not in processed_obs
|
||||||
|
|
||||||
|
# Check unchanged keys
|
||||||
|
assert processed_obs["reward"] == 1.0
|
||||||
|
assert processed_obs["info"] == {"episode": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config():
|
||||||
|
"""Test configuration serialization."""
|
||||||
|
rename_map = {
|
||||||
|
"old1": "new1",
|
||||||
|
"old2": "new2",
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
config = processor.get_config()
|
||||||
|
assert config == {"rename_map": rename_map}
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict():
|
||||||
|
"""Test state dict (should be empty for RenameProcessor)."""
|
||||||
|
processor = RenameProcessor(rename_map={"old": "new"})
|
||||||
|
|
||||||
|
state = processor.state_dict()
|
||||||
|
assert state == {}
|
||||||
|
|
||||||
|
# Load state dict should work even with empty dict
|
||||||
|
processor.load_state_dict({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_integration_with_robot_processor():
|
||||||
|
"""Test integration with RobotProcessor pipeline."""
|
||||||
|
rename_map = {
|
||||||
|
"agent_pos": "observation.state",
|
||||||
|
"pixels": "observation.image",
|
||||||
|
}
|
||||||
|
rename_processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
pipeline = RobotProcessor([rename_processor])
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||||
|
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
|
||||||
|
"other_data": "preserve_me",
|
||||||
|
}
|
||||||
|
transition = (observation, None, 0.5, False, False, {}, {})
|
||||||
|
|
||||||
|
result = pipeline(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check renaming worked through pipeline
|
||||||
|
assert "observation.state" in processed_obs
|
||||||
|
assert "observation.image" in processed_obs
|
||||||
|
assert "agent_pos" not in processed_obs
|
||||||
|
assert "pixels" not in processed_obs
|
||||||
|
assert processed_obs["other_data"] == "preserve_me"
|
||||||
|
|
||||||
|
# Check other transition elements unchanged
|
||||||
|
assert result[TransitionIndex.REWARD] == 0.5
|
||||||
|
assert result[TransitionIndex.DONE] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_and_load_pretrained():
|
||||||
|
"""Test saving and loading processor with RobotProcessor."""
|
||||||
|
rename_map = {
|
||||||
|
"old_state": "observation.state",
|
||||||
|
"old_image": "observation.image",
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Save pipeline
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check files were created
|
||||||
|
config_path = Path(tmp_dir) / "processor.json"
|
||||||
|
assert config_path.exists()
|
||||||
|
|
||||||
|
# No state files should be created for RenameProcessor
|
||||||
|
state_files = list(Path(tmp_dir).glob("*.safetensors"))
|
||||||
|
assert len(state_files) == 0
|
||||||
|
|
||||||
|
# Load pipeline
|
||||||
|
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
assert loaded_pipeline.name == "TestRenameProcessor"
|
||||||
|
assert len(loaded_pipeline) == 1
|
||||||
|
|
||||||
|
# Check that loaded processor works correctly
|
||||||
|
loaded_processor = loaded_pipeline.steps[0]
|
||||||
|
assert isinstance(loaded_processor, RenameProcessor)
|
||||||
|
assert loaded_processor.rename_map == rename_map
|
||||||
|
|
||||||
|
# Test functionality after loading
|
||||||
|
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = loaded_pipeline(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
assert "observation.state" in processed_obs
|
||||||
|
assert "observation.image" in processed_obs
|
||||||
|
assert processed_obs["observation.state"] == [1, 2, 3]
|
||||||
|
assert processed_obs["observation.image"] == "image_data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_functionality():
|
||||||
|
"""Test that RenameProcessor is properly registered."""
|
||||||
|
# Check that it's registered
|
||||||
|
assert "rename_processor" in ProcessorStepRegistry.list()
|
||||||
|
|
||||||
|
# Get from registry
|
||||||
|
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
||||||
|
assert retrieved_class is RenameProcessor
|
||||||
|
|
||||||
|
# Create instance from registry
|
||||||
|
instance = retrieved_class(rename_map={"old": "new"})
|
||||||
|
assert isinstance(instance, RenameProcessor)
|
||||||
|
assert instance.rename_map == {"old": "new"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_based_save_load():
|
||||||
|
"""Test save/load using registry name instead of module path."""
|
||||||
|
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
||||||
|
pipeline = RobotProcessor([processor])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Save and load
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Verify config uses registry name
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(Path(tmp_dir) / "processor.json") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
assert "registry_name" in config["steps"][0]
|
||||||
|
assert config["steps"][0]["registry_name"] == "rename_processor"
|
||||||
|
assert "class" not in config["steps"][0] # Should use registry, not module path
|
||||||
|
|
||||||
|
# Load should work
|
||||||
|
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
loaded_processor = loaded_pipeline.steps[0]
|
||||||
|
assert isinstance(loaded_processor, RenameProcessor)
|
||||||
|
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_chained_rename_processors():
|
||||||
|
"""Test multiple RenameProcessors in a pipeline."""
|
||||||
|
# First processor: rename raw keys to intermediate format
|
||||||
|
processor1 = RenameProcessor(
|
||||||
|
rename_map={
|
||||||
|
"pos": "agent_position",
|
||||||
|
"img": "camera_image",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second processor: rename to final format
|
||||||
|
processor2 = RenameProcessor(
|
||||||
|
rename_map={
|
||||||
|
"agent_position": "observation.state",
|
||||||
|
"camera_image": "observation.image",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = RobotProcessor([processor1, processor2])
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"pos": np.array([1.0, 2.0]),
|
||||||
|
"img": "image_data",
|
||||||
|
"extra": "keep_me",
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
# Step through to see intermediate results
|
||||||
|
results = list(pipeline.step_through(transition))
|
||||||
|
|
||||||
|
# After first processor
|
||||||
|
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
|
||||||
|
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# After second processor
|
||||||
|
final_obs = results[2][TransitionIndex.OBSERVATION]
|
||||||
|
assert "observation.state" in final_obs
|
||||||
|
assert "observation.image" in final_obs
|
||||||
|
assert final_obs["extra"] == "keep_me"
|
||||||
|
|
||||||
|
# Original keys should be gone
|
||||||
|
assert "pos" not in final_obs
|
||||||
|
assert "img" not in final_obs
|
||||||
|
assert "agent_position" not in final_obs
|
||||||
|
assert "camera_image" not in final_obs
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_observation_rename():
|
||||||
|
"""Test renaming with nested observation structures."""
|
||||||
|
rename_map = {
|
||||||
|
"observation.images.left": "observation.camera.left_view",
|
||||||
|
"observation.images.right": "observation.camera.right_view",
|
||||||
|
"observation.proprio": "observation.proprioception",
|
||||||
|
}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"observation.images.left": torch.randn(3, 64, 64),
|
||||||
|
"observation.images.right": torch.randn(3, 64, 64),
|
||||||
|
"observation.proprio": torch.randn(7),
|
||||||
|
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check renames
|
||||||
|
assert "observation.camera.left_view" in processed_obs
|
||||||
|
assert "observation.camera.right_view" in processed_obs
|
||||||
|
assert "observation.proprioception" in processed_obs
|
||||||
|
|
||||||
|
# Check unchanged key
|
||||||
|
assert "observation.gripper" in processed_obs
|
||||||
|
|
||||||
|
# Check old keys removed
|
||||||
|
assert "observation.images.left" not in processed_obs
|
||||||
|
assert "observation.images.right" not in processed_obs
|
||||||
|
assert "observation.proprio" not in processed_obs
|
||||||
|
|
||||||
|
|
||||||
|
def test_value_types_preserved():
|
||||||
|
"""Test that various value types are preserved during renaming."""
|
||||||
|
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
||||||
|
processor = RenameProcessor(rename_map=rename_map)
|
||||||
|
|
||||||
|
tensor_value = torch.randn(3, 3)
|
||||||
|
array_value = np.random.rand(2, 2)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
"old_tensor": tensor_value,
|
||||||
|
"old_array": array_value,
|
||||||
|
"old_scalar": 42,
|
||||||
|
"old_string": "hello",
|
||||||
|
"old_dict": {"nested": "value"},
|
||||||
|
"old_list": [1, 2, 3],
|
||||||
|
}
|
||||||
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check that values and types are preserved
|
||||||
|
assert torch.equal(processed_obs["new_tensor"], tensor_value)
|
||||||
|
assert np.array_equal(processed_obs["new_array"], array_value)
|
||||||
|
assert processed_obs["new_scalar"] == 42
|
||||||
|
assert processed_obs["old_string"] == "hello"
|
||||||
|
assert processed_obs["old_dict"] == {"nested": "value"}
|
||||||
|
assert processed_obs["old_list"] == [1, 2, 3]
|
||||||
Reference in New Issue
Block a user