refactor(converters): gather converters and refactor the logic (#1833)

* refactor(converters): move batch transition functions to converters module

- Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns.
- Updated references in `RobotProcessor` to use the new location of these functions.
- Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields.
- Removed redundant tests from `pipeline.py` to streamline the test suite.

* refactor(processor): reorganize EnvTransition and TransitionKey definitions

- Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability.
- Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase.

* refactor(converters): rename and update dataset frame conversion functions

- Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming.
- Updated references in `record.py`, `pipeline.py`, and tests to use the new function name.
- Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability.
- Adjusted related tests to ensure correct functionality with the new naming conventions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(processor): solve conflict artefacts

* refactor(converters): remove unused identity function and update type hints for merge_transitions

* refactor(processor): remove unused identity import and clean up gym_manipulator.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-09-02 15:33:38 +02:00
committed by GitHub
parent 2c802ac134
commit 645c87e3a9
10 changed files with 421 additions and 341 deletions
+16 -20
View File
@@ -1,11 +1,7 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
from lerobot.processor.converters import batch_to_transition, transition_to_batch
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def _dummy_batch():
@@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip():
def test_batch_to_transition_observation_grouping():
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
"""Test that batch_to_transition correctly groups observation.* keys."""
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping():
"info": {"episode": 42},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():
def test_transition_to_batch_observation_flattening():
"""Test that _default_transition_to_batch correctly flattens observation dict."""
"""Test that transition_to_batch correctly flattens observation dict."""
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
batch = transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch
@@ -134,7 +130,7 @@ def test_no_observation_keys():
"info": {"test": "no_obs"},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionKey.OBSERVATION] is None
@@ -147,7 +143,7 @@ def test_no_observation_keys():
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
@@ -159,7 +155,7 @@ def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
@@ -173,7 +169,7 @@ def test_minimal_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
@@ -186,7 +182,7 @@ def test_empty_batch():
"""Test behavior with empty batch."""
batch = {}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionKey.OBSERVATION] is None
@@ -198,7 +194,7 @@ def test_empty_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
"info": {"episode_length": 200, "success": True},
}
transition = _default_batch_to_transition(batch)
reconstructed_batch = _default_transition_to_batch(transition)
transition = batch_to_transition(batch)
reconstructed_batch = transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")}
@@ -254,7 +250,7 @@ def test_custom_converter():
def to_tr(batch):
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
tr = batch_to_transition(batch)
# Double the reward
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
@@ -262,7 +258,7 @@ def test_custom_converter():
return new_tr
def to_batch(tr):
batch = _default_transition_to_batch(tr)
batch = transition_to_batch(tr)
return batch
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)