mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
@@ -23,6 +23,22 @@ from lerobot.processor import (
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_process_single_image():
|
||||
@@ -33,10 +49,10 @@ def test_process_single_image():
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -60,10 +76,10 @@ def test_process_image_dict():
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": {"camera1": image1, "camera2": image2}}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both images were processed
|
||||
assert "observation.images.camera1" in processed_obs
|
||||
@@ -82,10 +98,10 @@ def test_process_batched_image():
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
@@ -98,7 +114,7 @@ def test_invalid_image_format():
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected channel-last images"):
|
||||
processor(transition)
|
||||
@@ -111,7 +127,7 @@ def test_invalid_image_dtype():
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
|
||||
processor(transition)
|
||||
@@ -122,10 +138,10 @@ def test_no_pixels_in_observation():
|
||||
processor = ImageProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve other data unchanged
|
||||
assert "other_data" in processed_obs
|
||||
@@ -136,7 +152,7 @@ def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = ImageProcessor()
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
assert result == transition
|
||||
@@ -167,10 +183,10 @@ def test_process_environment_state():
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
@@ -188,10 +204,10 @@ def test_process_agent_pos():
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
@@ -211,10 +227,10 @@ def test_process_batched_states():
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
@@ -229,10 +245,10 @@ def test_process_both_states():
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
@@ -251,10 +267,10 @@ def test_no_states_in_observation():
|
||||
processor = StateProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve data unchanged
|
||||
np.testing.assert_array_equal(processed_obs, observation)
|
||||
@@ -275,10 +291,10 @@ def test_complete_observation_processing():
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -303,10 +319,10 @@ def test_image_only_processing():
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
@@ -318,10 +334,10 @@ def test_state_only_processing():
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
@@ -332,10 +348,10 @@ def test_empty_observation():
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert processed_obs == {}
|
||||
|
||||
@@ -369,8 +385,8 @@ def test_equivalent_to_original_function():
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
@@ -396,8 +412,8 @@ def test_equivalent_with_image_dict():
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
Reference in New Issue
Block a user