refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions
+94 -35
View File
@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
import numpy as np
@@ -10,7 +25,22 @@ from lerobot.processor.normalize_processor import (
UnnormalizerProcessor,
_convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_numpy_conversion():
@@ -120,10 +150,10 @@ def test_mean_std_normalization(observation_normalizer):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check mean/std normalization
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
@@ -134,10 +164,10 @@ def test_min_max_normalization(observation_normalizer):
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check min/max normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
@@ -157,10 +187,10 @@ def test_selective_normalization(observation_stats):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Only image should be normalized
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
@@ -176,10 +206,10 @@ def test_device_compatibility(observation_stats):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
assert normalized_obs["observation.image"].device.type == "cuda"
@@ -220,10 +250,10 @@ def test_state_dict_save_load(observation_normalizer):
# Test that it works the same
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result1 = observation_normalizer(transition)[0]
result2 = new_normalizer(transition)[0]
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(result1["observation.image"], result2["observation.image"])
@@ -271,10 +301,10 @@ def test_mean_std_unnormalization(action_stats_mean_std):
)
normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# action * std + mean
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
@@ -290,10 +320,10 @@ def test_min_max_unnormalization(action_stats_min_max):
# Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# Map from [-1, 1] to [min, max]
# (action + 1) / 2 * (max - min) + min
@@ -315,10 +345,10 @@ def test_numpy_action_input(action_stats_mean_std):
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
assert isinstance(unnormalized_action, torch.Tensor)
expected = torch.tensor([1.0, -1.0, 1.0])
@@ -332,7 +362,7 @@ def test_none_action(action_stats_mean_std):
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = unnormalizer(transition)
# Should return transition unchanged
@@ -396,23 +426,31 @@ def test_combined_normalization(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = normalizer_processor(transition)
# Check normalized observations
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(processed_obs["observation.image"], expected_image)
# Check normalized action
processed_action = processed_transition[TransitionIndex.ACTION]
processed_action = processed_transition[TransitionKey.ACTION]
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
assert torch.allclose(processed_action, expected_action)
# Check other fields remain unchanged
assert processed_transition[TransitionIndex.REWARD] == 1.0
assert not processed_transition[TransitionIndex.DONE]
assert processed_transition[TransitionKey.REWARD] == 1.0
assert not processed_transition[TransitionKey.DONE]
def test_processor_from_lerobot_dataset(full_stats):
@@ -466,13 +504,21 @@ def test_integration_with_robot_processor(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = robot_processor(transition)
# Verify the processing worked
assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor)
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
# Edge case tests
@@ -482,7 +528,7 @@ def test_empty_observation():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = normalizer(transition)
assert result == transition
@@ -493,11 +539,13 @@ def test_empty_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = normalizer(transition)
# Should return observation unchanged since no stats are available
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
assert torch.allclose(
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
)
def test_partial_stats():
@@ -507,9 +555,9 @@ def test_partial_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
@@ -551,14 +599,25 @@ def test_serialization_roundtrip(full_stats):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
assert torch.allclose(result1[1], result2[1])
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()