refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions
+41 -25
View File
@@ -13,14 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
import numpy as np
import torch
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_basic_renaming():
@@ -36,10 +50,10 @@ def test_basic_renaming():
"old_key2": np.array([3.0, 4.0]),
"unchanged_key": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "new_key1" in processed_obs
@@ -63,10 +77,10 @@ def test_empty_rename_map():
"key1": torch.tensor([1.0]),
"key2": "value2",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# All keys should be unchanged
assert processed_obs.keys() == observation.keys()
@@ -78,7 +92,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
# Should return transition unchanged
@@ -98,10 +112,10 @@ def test_overlapping_rename():
"b": 2,
"x": 3,
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that renaming happens correctly
assert "a" not in processed_obs
@@ -124,10 +138,10 @@ def test_partial_rename():
"reward": 1.0,
"info": {"episode": 1},
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "observation.proprio_state" in processed_obs
@@ -178,10 +192,12 @@ def test_integration_with_robot_processor():
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
"other_data": "preserve_me",
}
transition = (observation, None, 0.5, False, False, {}, {})
transition = create_transition(
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
)
result = pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
@@ -191,8 +207,8 @@ def test_integration_with_robot_processor():
assert processed_obs["other_data"] == "preserve_me"
# Check other transition elements unchanged
assert result[TransitionIndex.REWARD] == 0.5
assert result[TransitionIndex.DONE] is False
assert result[TransitionKey.REWARD] == 0.5
assert result[TransitionKey.DONE] is False
def test_save_and_load_pretrained():
@@ -229,10 +245,10 @@ def test_save_and_load_pretrained():
# Test functionality after loading
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = loaded_pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
@@ -306,17 +322,17 @@ def test_chained_rename_processors():
"img": "image_data",
"extra": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
# Step through to see intermediate results
results = list(pipeline.step_through(transition))
# After first processor
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
# After second processor
final_obs = results[2][TransitionIndex.OBSERVATION]
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert final_obs["extra"] == "keep_me"
@@ -343,10 +359,10 @@ def test_nested_observation_rename():
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renames
assert "observation.camera.left_view" in processed_obs
@@ -378,10 +394,10 @@ def test_value_types_preserved():
"old_dict": {"nested": "value"},
"old_list": [1, 2, 3],
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that values and types are preserved
assert torch.equal(processed_obs["new_tensor"], tensor_value)