refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions
+49 -33
View File
@@ -23,6 +23,22 @@ from lerobot.processor import (
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_process_single_image():
@@ -33,10 +49,10 @@ def test_process_single_image():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
@@ -60,10 +76,10 @@ def test_process_image_dict():
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
@@ -82,10 +98,10 @@ def test_process_batched_image():
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
@@ -98,7 +114,7 @@ def test_invalid_image_format():
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
@@ -111,7 +127,7 @@ def test_invalid_image_dtype():
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
@@ -122,10 +138,10 @@ def test_no_pixels_in_observation():
processor = ImageProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve other data unchanged
assert "other_data" in processed_obs
@@ -136,7 +152,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
assert result == transition
@@ -167,10 +183,10 @@ def test_process_environment_state():
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
@@ -188,10 +204,10 @@ def test_process_agent_pos():
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
@@ -211,10 +227,10 @@ def test_process_batched_states():
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
@@ -229,10 +245,10 @@ def test_process_both_states():
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
@@ -251,10 +267,10 @@ def test_no_states_in_observation():
processor = StateProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve data unchanged
np.testing.assert_array_equal(processed_obs, observation)
@@ -275,10 +291,10 @@ def test_complete_observation_processing():
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
@@ -303,10 +319,10 @@ def test_image_only_processing():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
@@ -318,10 +334,10 @@ def test_state_only_processing():
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
@@ -332,10 +348,10 @@ def test_empty_observation():
processor = VanillaObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert processed_obs == {}
@@ -369,8 +385,8 @@ def test_equivalent_to_original_function():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
@@ -396,8 +412,8 @@ def test_equivalent_with_image_dict():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())