refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions
+107 -73
View File
@@ -18,7 +18,7 @@ import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import Any
import numpy as np
import pytest
@@ -26,6 +26,22 @@ import torch
import torch.nn as nn
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
def create_transition(
    observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
    """Build an EnvTransition dictionary keyed by ``TransitionKey`` members.

    Each argument is stored under its matching key. ``info`` and
    ``complementary_data`` are swapped for a fresh empty dict when ``None``
    is passed, so separate calls never share the same mutable object.
    """
    # Normalise the two dict-valued fields before assembling the transition.
    if info is None:
        info = {}
    if complementary_data is None:
        complementary_data = {}

    # Keys are inserted in the same order as the transition components are listed.
    transition = {}
    transition[TransitionKey.OBSERVATION] = observation
    transition[TransitionKey.ACTION] = action
    transition[TransitionKey.REWARD] = reward
    transition[TransitionKey.DONE] = done
    transition[TransitionKey.TRUNCATED] = truncated
    transition[TransitionKey.INFO] = info
    transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
    return transition
@dataclass
@@ -45,14 +61,16 @@ class MockStep:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
# Create a new transition with updated complementary_data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
@@ -79,12 +97,14 @@ class MockStepWithoutOptionalMethods:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
reward = reward * self.multiplier
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return new_transition
return (obs, action, reward, done, truncated, info, comp_data)
return transition
@dataclass
@@ -105,7 +125,7 @@ class MockStepWithTensorState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
# Update running mean
@@ -143,7 +163,7 @@ def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor()
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result == transition
@@ -155,15 +175,15 @@ def test_single_step_pipeline():
step = MockStep("test_step")
pipeline = RobotProcessor([step])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
@@ -172,46 +192,46 @@ def test_multiple_steps_pipeline():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotProcessor([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type (tuple instead of dict)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
# Test with wrong type (string)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline("not a dict")
def test_step_through():
"""Test step_through method with tuple input."""
"""Test step_through method with dict input."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1
assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2
# Ensure all results are tuples (same format as input)
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, tuple)
assert len(result) == 7
assert isinstance(result, dict)
assert all(isinstance(k, TransitionKey) for k in result.keys())
def test_step_through_with_dict():
@@ -279,7 +299,7 @@ def test_hooks():
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
assert before_calls == [0]
@@ -292,15 +312,16 @@ def test_hook_modification():
pipeline = RobotProcessor([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
obs, action, reward, done, truncated, info, comp_data = transition
return (obs, action, 42.0, done, truncated, info, comp_data)
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = 42.0
return new_transition
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
def test_reset():
@@ -316,7 +337,7 @@ def test_reset():
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
pipeline(transition)
@@ -335,7 +356,7 @@ def test_profile_steps():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
profile_results = pipeline.profile_steps(transition, num_runs=10)
@@ -397,10 +418,10 @@ def test_step_without_optional_methods():
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor([step])
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
@@ -419,7 +440,7 @@ def test_mixed_json_and_tensor_state():
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check state
@@ -466,7 +487,7 @@ class MockModuleStep(nn.Module):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition and update running mean."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
if obs is not None and isinstance(obs, torch.Tensor):
# Process observation through linear layer
@@ -509,7 +530,7 @@ def test_to_device_with_state_dict():
# Process some transitions to populate state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check initial device (should be CPU)
@@ -551,7 +572,7 @@ def test_to_device_with_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
pipeline(transition)
# Check initial device
@@ -575,7 +596,7 @@ def test_to_device_with_module():
# Verify the module still works after transfer
obs_cuda = torch.randn(2, 5, device="cuda:0")
transition = (obs_cuda, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_cuda, reward=1.0)
pipeline(transition) # Should not raise an error
@@ -589,7 +610,7 @@ def test_to_device_mixed_steps():
# Process some data
for i in range(5):
transition = (torch.randn(2, 10), None, float(i), False, False, {}, {})
transition = create_transition(observation=torch.randn(2, 10), reward=float(i))
pipeline(transition)
# Check initial state
@@ -630,7 +651,7 @@ def test_to_device_preserves_functionality():
# Process initial data
rewards = [1.0, 2.0, 3.0]
for r in rewards:
transition = (None, None, r, False, False, {}, {})
transition = create_transition(reward=r)
pipeline(transition)
# Check state before transfer
@@ -645,7 +666,7 @@ def test_to_device_preserves_functionality():
assert step.running_count == initial_count
# Process more data to ensure functionality
transition = (None, None, 4.0, False, False, {}, {})
transition = create_transition(reward=4.0)
_ = pipeline(transition)
assert step.running_count == 4
@@ -700,7 +721,8 @@ class MockNonModuleStepWithState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition using tensor operations."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim:
# Perform some tensor operations
@@ -718,7 +740,12 @@ class MockNonModuleStepWithState:
comp_data[f"{self.name}_mean_output"] = output.mean().item()
comp_data[f"{self.name}_steps"] = self.step_count.item()
return (obs, action, reward, done, truncated, info, comp_data)
# Return updated transition
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
return transition
def get_config(self) -> dict[str, Any]:
return {
@@ -763,9 +790,9 @@ def test_to_device_non_module_class():
# Process some data to populate state
for i in range(3):
obs = torch.randn(2, 5)
transition = (obs, None, float(i), False, False, {}, {})
transition = create_transition(observation=obs, reward=float(i))
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert f"{non_module_step.name}_steps" in comp_data
# Verify all tensors are on CPU initially
@@ -811,9 +838,9 @@ def test_to_device_non_module_class():
# Test that step still works on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=1.0)
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Verify processing worked
assert comp_data[f"{non_module_step.name}_steps"] == 4
@@ -835,7 +862,7 @@ def test_to_device_module_vs_non_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
_ = pipeline(transition)
# Check initial devices
@@ -860,7 +887,7 @@ def test_to_device_module_vs_non_module():
# Process data on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 2.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=2.0)
_ = pipeline(transition)
# Verify both steps processed the data
@@ -889,7 +916,8 @@ class MockStepWithNonSerializableParam:
self.env = env # Non-serializable parameter (like gym.Env)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Use the env parameter if provided
if self.env is not None:
@@ -897,10 +925,14 @@ class MockStepWithNonSerializableParam:
comp_data[f"{self.name}_env_info"] = str(self.env)
# Apply multiplier to reward
new_transition = transition.copy()
if reward is not None:
reward = reward * self.multiplier
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
if comp_data:
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Note: env is intentionally NOT included here as it's not serializable
@@ -928,13 +960,15 @@ class RegisteredMockStep:
device: str = "cpu"
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["registered_step_value"] = self.value
comp_data["registered_step_device"] = self.device
return (obs, action, reward, done, truncated, info, comp_data)
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
@@ -993,18 +1027,18 @@ def test_from_pretrained_with_overrides():
assert loaded_pipeline.name == "TestOverrides"
# Test the loaded steps
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# Check that overrides were applied
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert "env_step_env_info" in comp_data
assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)"
assert comp_data["registered_step_value"] == 200
assert comp_data["registered_step_device"] == "cuda"
# Check that multiplier override was applied
assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier)
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier)
def test_from_pretrained_with_partial_overrides():
@@ -1024,13 +1058,13 @@ def test_from_pretrained_with_partial_overrides():
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# The reward should be affected by both steps, both getting the override
# First step: 1.0 * 5.0 = 5.0 (overridden)
# Second step: 5.0 * 5.0 = 25.0 (also overridden)
assert result[2] == 25.0
assert result[TransitionKey.REWARD] == 25.0
def test_from_pretrained_invalid_override_key():
@@ -1082,10 +1116,10 @@ def test_from_pretrained_registered_step_override():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test that overrides were applied
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["registered_step_value"] == 999
assert comp_data["registered_step_device"] == "cuda"
@@ -1110,13 +1144,13 @@ def test_from_pretrained_mixed_registered_and_unregistered():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test both steps
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)"
assert comp_data["registered_step_value"] == 777
assert result[2] == 8.0 # 2.0 * 4.0
assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0
def test_from_pretrained_no_overrides():
@@ -1133,10 +1167,10 @@ def test_from_pretrained_no_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works (env will be None)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 3.0 # 1.0 * 3.0
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0
def test_from_pretrained_empty_overrides():
@@ -1153,10 +1187,10 @@ def test_from_pretrained_empty_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works normally
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 2.0
assert result[TransitionKey.REWARD] == 2.0
def test_from_pretrained_override_instantiation_error():
@@ -1185,7 +1219,7 @@ def test_from_pretrained_with_state_and_overrides():
# Process some data to create state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
with tempfile.TemporaryDirectory() as tmp_dir: