refactor(converters): gather converters and refactor the logic (#1833)

* refactor(converters): move batch transition functions to converters module

- Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns.
- Updated references in `RobotProcessor` to use the new location of these functions.
- Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields.
- Removed redundant tests from `pipeline.py` to streamline the test suite.

* refactor(processor): reorganize EnvTransition and TransitionKey definitions

- Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability.
- Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase.

* refactor(converters): rename and update dataset frame conversion functions

- Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming.
- Updated references in `record.py`, `pipeline.py`, and tests to use the new function name.
- Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability.
- Adjusted related tests to ensure correct functionality with the new naming conventions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(processor): solve conflict artefacts

* refactor(converters): remove unused identity function and update type hints for merge_transitions

* refactor(processor): remove unused identity import and clean up gym_manipulator.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-09-02 15:33:38 +02:00
committed by GitHub
parent 2c802ac134
commit 645c87e3a9
10 changed files with 421 additions and 341 deletions
+119 -3
View File
@@ -3,11 +3,13 @@ import pytest
import torch
from lerobot.processor.converters import (
to_dataset_frame,
batch_to_transition,
to_output_robot_action,
to_tensor,
to_transition_robot_observation,
to_transition_teleop_action,
transition_to_batch,
transition_to_dataset_frame,
)
from lerobot.processor.pipeline import TransitionKey
@@ -107,7 +109,7 @@ def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
assert out["gripper.pos"] == pytest.approx(33.0)
def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata():
# Fabricate dataset features (as stored in dataset.meta["features"])
features = {
# Action vector: 3 elements in specific order
@@ -160,7 +162,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
}
# Directly call the refactored function
batch = to_dataset_frame([teleop_transition, robot_transition], features)
batch = transition_to_dataset_frame([teleop_transition, robot_transition], features)
# Images passthrough
assert "observation.images.front" in batch
@@ -377,3 +379,117 @@ def test_to_tensor_unsupported_type():
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
to_tensor(object())
def create_transition(
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def test_batch_to_transition_with_index_fields():
"""Test that batch_to_transition handles index and task_index fields correctly."""
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
}
transition = batch_to_transition(batch)
# Check basic transition structure
assert TransitionKey.OBSERVATION in transition
assert TransitionKey.ACTION in transition
assert TransitionKey.COMPLEMENTARY_DATA in transition
# Check that index and task_index are in complementary_data
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
assert "index" in comp_data
assert "task_index" in comp_data
assert "task" in comp_data
# Verify values
assert torch.equal(comp_data["index"], batch["index"])
assert torch.equal(comp_data["task_index"], batch["task_index"])
assert comp_data["task"] == batch["task"]
def testtransition_to_batch_with_index_fields():
"""Test that transition_to_batch handles index and task_index fields correctly."""
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
complementary_data={
"task": ["navigate"],
"index": torch.tensor([100], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
},
)
batch = transition_to_batch(transition)
# Check that index and task_index are in the batch
assert "index" in batch
assert "task_index" in batch
assert "task" in batch
# Verify values
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
def test_batch_to_transition_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
transition = batch_to_transition(batch)
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Should have task but not index/task_index
assert "task" in comp_data
assert "index" not in comp_data
assert "task_index" not in comp_data
def test_transition_to_batch_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)
batch = transition_to_batch(transition)
# Should have task but not index/task_index
assert "task" in batch
assert "index" not in batch
assert "task_index" not in batch