mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
refactor(converters): gather converters and refactor the logic (#1833)
* refactor(converters): move batch transition functions to converters module - Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns. - Updated references in `RobotProcessor` to use the new location of these functions. - Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields. - Removed redundant tests from `pipeline.py` to streamline the test suite. * refactor(processor): reorganize EnvTransition and TransitionKey definitions - Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability. - Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase. * refactor(converters): rename and update dataset frame conversion functions - Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming. - Updated references in `record.py`, `pipeline.py`, and tests to use the new function name. - Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability. - Adjusted related tests to ensure correct functionality with the new naming conventions. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(processor): solve conflict artefacts * refactor(converters): remove unused identity function and update type hints for merge_transitions * refactor(processor): remove unused identity import and clean up gym_manipulator.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -3,11 +3,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
batch_to_transition,
|
||||
to_output_robot_action,
|
||||
to_tensor,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
transition_to_batch,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
@@ -107,7 +109,7 @@ def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
|
||||
assert out["gripper.pos"] == pytest.approx(33.0)
|
||||
|
||||
|
||||
def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
# Fabricate dataset features (as stored in dataset.meta["features"])
|
||||
features = {
|
||||
# Action vector: 3 elements in specific order
|
||||
@@ -160,7 +162,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
}
|
||||
|
||||
# Directly call the refactored function
|
||||
batch = to_dataset_frame([teleop_transition, robot_transition], features)
|
||||
batch = transition_to_dataset_frame([teleop_transition, robot_transition], features)
|
||||
|
||||
# Images passthrough
|
||||
assert "observation.images.front" in batch
|
||||
@@ -377,3 +379,117 @@ def test_to_tensor_unsupported_type():
|
||||
|
||||
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
|
||||
to_tensor(object())
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def test_batch_to_transition_with_index_fields():
|
||||
"""Test that batch_to_transition handles index and task_index fields correctly."""
|
||||
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
"task": ["pick_cube"],
|
||||
"index": torch.tensor([42], dtype=torch.int64),
|
||||
"task_index": torch.tensor([3], dtype=torch.int64),
|
||||
}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check basic transition structure
|
||||
assert TransitionKey.OBSERVATION in transition
|
||||
assert TransitionKey.ACTION in transition
|
||||
assert TransitionKey.COMPLEMENTARY_DATA in transition
|
||||
|
||||
# Check that index and task_index are in complementary_data
|
||||
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert "index" in comp_data
|
||||
assert "task_index" in comp_data
|
||||
assert "task" in comp_data
|
||||
|
||||
# Verify values
|
||||
assert torch.equal(comp_data["index"], batch["index"])
|
||||
assert torch.equal(comp_data["task_index"], batch["task_index"])
|
||||
assert comp_data["task"] == batch["task"]
|
||||
|
||||
|
||||
def testtransition_to_batch_with_index_fields():
|
||||
"""Test that transition_to_batch handles index and task_index fields correctly."""
|
||||
|
||||
# Create transition with index and task_index in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
reward=1.5,
|
||||
done=False,
|
||||
complementary_data={
|
||||
"task": ["navigate"],
|
||||
"index": torch.tensor([100], dtype=torch.int64),
|
||||
"task_index": torch.tensor([5], dtype=torch.int64),
|
||||
},
|
||||
)
|
||||
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that index and task_index are in the batch
|
||||
assert "index" in batch
|
||||
assert "task_index" in batch
|
||||
assert "task" in batch
|
||||
|
||||
# Verify values
|
||||
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
|
||||
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
|
||||
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
|
||||
|
||||
|
||||
def test_batch_to_transition_without_index_fields():
|
||||
"""Test that conversion works without index and task_index fields."""
|
||||
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Should have task but not index/task_index
|
||||
assert "task" in comp_data
|
||||
assert "index" not in comp_data
|
||||
assert "task_index" not in comp_data
|
||||
|
||||
|
||||
def test_transition_to_batch_without_index_fields():
|
||||
"""Test that conversion works without index and task_index fields."""
|
||||
|
||||
# Transition without index/task_index
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
complementary_data={"task": ["navigate"]},
|
||||
)
|
||||
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Should have task but not index/task_index
|
||||
assert "task" in batch
|
||||
assert "index" not in batch
|
||||
assert "task_index" not in batch
|
||||
|
||||
Reference in New Issue
Block a user