refactor(converters): gather converters and refactor the logic (#1833)

* refactor(converters): move batch transition functions to converters module

- Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns.
- Updated references in `RobotProcessor` to use the new location of these functions.
- Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields.
- Removed redundant tests from `pipeline.py` to streamline the test suite.

* refactor(processor): reorganize EnvTransition and TransitionKey definitions

- Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability.
- Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase.

* refactor(converters): rename and update dataset frame conversion functions

- Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming.
- Updated references in `record.py`, `pipeline.py`, and tests to use the new function name.
- Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability.
- Adjusted related tests to ensure correct functionality with the new naming conventions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(processor): solve conflict artefacts

* refactor(converters): remove unused identity function and update type hints for merge_transitions

* refactor(processor): remove unused identity import and clean up gym_manipulator.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-09-02 15:33:38 +02:00
committed by GitHub
parent 2c802ac134
commit 645c87e3a9
10 changed files with 421 additions and 341 deletions
+16 -20
View File
@@ -1,11 +1,7 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
from lerobot.processor.converters import batch_to_transition, transition_to_batch
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def _dummy_batch():
@@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip():
def test_batch_to_transition_observation_grouping():
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
"""Test that batch_to_transition correctly groups observation.* keys."""
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping():
"info": {"episode": 42},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():
def test_transition_to_batch_observation_flattening():
"""Test that _default_transition_to_batch correctly flattens observation dict."""
"""Test that transition_to_batch correctly flattens observation dict."""
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
batch = transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch
@@ -134,7 +130,7 @@ def test_no_observation_keys():
"info": {"test": "no_obs"},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionKey.OBSERVATION] is None
@@ -147,7 +143,7 @@ def test_no_observation_keys():
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
@@ -159,7 +155,7 @@ def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
@@ -173,7 +169,7 @@ def test_minimal_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
@@ -186,7 +182,7 @@ def test_empty_batch():
"""Test behavior with empty batch."""
batch = {}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionKey.OBSERVATION] is None
@@ -198,7 +194,7 @@ def test_empty_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
"info": {"episode_length": 200, "success": True},
}
transition = _default_batch_to_transition(batch)
reconstructed_batch = _default_transition_to_batch(transition)
transition = batch_to_transition(batch)
reconstructed_batch = transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")}
@@ -254,7 +250,7 @@ def test_custom_converter():
def to_tr(batch):
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
tr = batch_to_transition(batch)
# Double the reward
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
@@ -262,7 +258,7 @@ def test_custom_converter():
return new_tr
def to_batch(tr):
batch = _default_transition_to_batch(tr)
batch = transition_to_batch(tr)
return batch
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
+119 -3
View File
@@ -3,11 +3,13 @@ import pytest
import torch
from lerobot.processor.converters import (
to_dataset_frame,
batch_to_transition,
to_output_robot_action,
to_tensor,
to_transition_robot_observation,
to_transition_teleop_action,
transition_to_batch,
transition_to_dataset_frame,
)
from lerobot.processor.pipeline import TransitionKey
@@ -107,7 +109,7 @@ def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
assert out["gripper.pos"] == pytest.approx(33.0)
def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata():
# Fabricate dataset features (as stored in dataset.meta["features"])
features = {
# Action vector: 3 elements in specific order
@@ -160,7 +162,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
}
# Directly call the refactored function
batch = to_dataset_frame([teleop_transition, robot_transition], features)
batch = transition_to_dataset_frame([teleop_transition, robot_transition], features)
# Images passthrough
assert "observation.images.front" in batch
@@ -377,3 +379,117 @@ def test_to_tensor_unsupported_type():
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
to_tensor(object())
def create_transition(
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def test_batch_to_transition_with_index_fields():
"""Test that batch_to_transition handles index and task_index fields correctly."""
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
}
transition = batch_to_transition(batch)
# Check basic transition structure
assert TransitionKey.OBSERVATION in transition
assert TransitionKey.ACTION in transition
assert TransitionKey.COMPLEMENTARY_DATA in transition
# Check that index and task_index are in complementary_data
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
assert "index" in comp_data
assert "task_index" in comp_data
assert "task" in comp_data
# Verify values
assert torch.equal(comp_data["index"], batch["index"])
assert torch.equal(comp_data["task_index"], batch["task_index"])
assert comp_data["task"] == batch["task"]
def testtransition_to_batch_with_index_fields():
"""Test that transition_to_batch handles index and task_index fields correctly."""
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
complementary_data={
"task": ["navigate"],
"index": torch.tensor([100], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
},
)
batch = transition_to_batch(transition)
# Check that index and task_index are in the batch
assert "index" in batch
assert "task_index" in batch
assert "task" in batch
# Verify values
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
def test_batch_to_transition_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
transition = batch_to_transition(batch)
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Should have task but not index/task_index
assert "task" in comp_data
assert "index" not in comp_data
assert "task_index" not in comp_data
def test_transition_to_batch_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)
batch = transition_to_batch(transition)
# Should have task but not index/task_index
assert "task" in batch
assert "index" not in batch
assert "task_index" not in batch
-103
View File
@@ -1659,109 +1659,6 @@ def test_state_file_naming_with_multiple_processors():
assert loaded_post.steps[0].window_size == 10
def test_default_batch_to_transition_with_index_fields():
"""Test that _default_batch_to_transition handles index and task_index fields correctly."""
from lerobot.processor.pipeline import _default_batch_to_transition
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
}
transition = _default_batch_to_transition(batch)
# Check basic transition structure
assert TransitionKey.OBSERVATION in transition
assert TransitionKey.ACTION in transition
assert TransitionKey.COMPLEMENTARY_DATA in transition
# Check that index and task_index are in complementary_data
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
assert "index" in comp_data
assert "task_index" in comp_data
assert "task" in comp_data
# Verify values
assert torch.equal(comp_data["index"], batch["index"])
assert torch.equal(comp_data["task_index"], batch["task_index"])
assert comp_data["task"] == batch["task"]
def test_default_transition_to_batch_with_index_fields():
"""Test that _default_transition_to_batch handles index and task_index fields correctly."""
from lerobot.processor.pipeline import _default_transition_to_batch
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
complementary_data={
"task": ["navigate"],
"index": torch.tensor([100], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
},
)
batch = _default_transition_to_batch(transition)
# Check that index and task_index are in the batch
assert "index" in batch
assert "task_index" in batch
assert "task" in batch
# Verify values
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
def test_batch_to_transition_without_index_fields():
"""Test that conversion works without index and task_index fields."""
from lerobot.processor.pipeline import _default_batch_to_transition
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
transition = _default_batch_to_transition(batch)
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Should have task but not index/task_index
assert "task" in comp_data
assert "index" not in comp_data
assert "task_index" not in comp_data
def test_transition_to_batch_without_index_fields():
"""Test that conversion works without index and task_index fields."""
from lerobot.processor.pipeline import _default_transition_to_batch
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)
batch = _default_transition_to_batch(transition)
# Should have task but not index/task_index
assert "task" in batch
assert "index" not in batch
assert "task_index" not in batch
def test_override_with_device_strings():
"""Test overriding device parameters with string values."""