mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from lerobot.processor.pipeline import (
|
||||
RobotProcessor,
|
||||
TransitionIndex,
|
||||
TransitionKey,
|
||||
_default_batch_to_transition,
|
||||
_default_transition_to_batch,
|
||||
)
|
||||
@@ -63,27 +63,27 @@ def test_batch_to_transition_observation_grouping():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionIndex.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionIndex.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionIndex.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionIndex.OBSERVATION]
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(
|
||||
transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
)
|
||||
assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||
assert transition[TransitionIndex.REWARD] == 1.5
|
||||
assert transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {"episode": 42}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"episode": 42}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
@@ -94,15 +94,15 @@ def test_transition_to_batch_observation_flattening():
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = (
|
||||
observation_dict, # observation
|
||||
"action_data", # action
|
||||
1.5, # reward
|
||||
True, # done
|
||||
False, # truncated
|
||||
{"episode": 42}, # info
|
||||
{}, # complementary_data
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {"episode": 42},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
batch = _default_transition_to_batch(transition)
|
||||
|
||||
@@ -137,14 +137,14 @@ def test_no_observation_keys():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Observation should be None when no observation.* keys
|
||||
assert transition[TransitionIndex.OBSERVATION] is None
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||
assert transition[TransitionIndex.REWARD] == 2.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {"test": "no_obs"}
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -162,15 +162,15 @@ def test_minimal_batch():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionIndex.ACTION] == "minimal_action"
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionIndex.REWARD] == 0.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -189,13 +189,13 @@ def test_empty_batch():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# All fields should have defaults
|
||||
assert transition[TransitionIndex.OBSERVATION] is None
|
||||
assert transition[TransitionIndex.ACTION] is None
|
||||
assert transition[TransitionIndex.REWARD] == 0.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
assert transition[TransitionKey.ACTION] is None
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -256,33 +256,27 @@ def test_custom_converter():
|
||||
# Custom converter that modifies the reward
|
||||
tr = _default_batch_to_transition(batch)
|
||||
# Double the reward
|
||||
reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0
|
||||
return (
|
||||
tr[TransitionIndex.OBSERVATION],
|
||||
tr[TransitionIndex.ACTION],
|
||||
reward,
|
||||
tr[TransitionIndex.DONE],
|
||||
tr[TransitionIndex.TRUNCATED],
|
||||
tr[TransitionIndex.INFO],
|
||||
tr[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
||||
new_tr = tr.copy()
|
||||
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
|
||||
return new_tr
|
||||
|
||||
def to_batch(tr):
|
||||
# Custom converter that adds a custom field
|
||||
batch = _default_transition_to_batch(tr)
|
||||
batch["custom_field"] = "custom_value"
|
||||
return batch
|
||||
|
||||
proc = RobotProcessor([], to_transition=to_tr, to_output=to_batch)
|
||||
batch = _dummy_batch()
|
||||
out = proc(batch)
|
||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
# Check that custom modifications were applied
|
||||
assert out["next.reward"] == batch["next.reward"] * 2
|
||||
assert out["custom_field"] == "custom_value"
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
}
|
||||
|
||||
# Check that observation.* keys are still preserved
|
||||
original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")}
|
||||
result = processor(batch)
|
||||
|
||||
assert set(original_obs_keys.keys()) == set(output_obs_keys.keys())
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
|
||||
Reference in New Issue
Block a user