feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
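
For context while reading the diff, here is a minimal sketch of the batching rule the new tests pin down. The _process_observation / _process_action split mirrors the commit message; everything else (the class name and the exact dimension rules) is inferred from the test assertions below and is an assumption, not the actual lerobot.processor.batch_processor implementation.

# Sketch only: dimension rules inferred from the test assertions, not copied
# from the library.
import torch

from lerobot.processor.pipeline import TransitionKey


class ToBatchProcessorSketch:
    def __call__(self, transition: dict) -> dict:
        # Mutates the transition in place and returns the same object
        # (see test_action_in_place_mutation below).
        self._process_observation(transition)
        self._process_action(transition)
        return transition

    def _process_observation(self, transition: dict) -> None:
        observation = transition.get(TransitionKey.OBSERVATION)
        if not isinstance(observation, dict):
            return
        for key, value in observation.items():
            # 1D state vectors and 3D (H, W, C) images gain a leading batch
            # dim; 0D scalars and already-batched tensors pass through.
            if isinstance(value, torch.Tensor) and value.dim() in (1, 3):
                observation[key] = value.unsqueeze(0)

    def _process_action(self, transition: dict) -> None:
        action = transition.get(TransitionKey.ACTION)
        if isinstance(action, torch.Tensor) and action.dim() == 1:
            # Only unbatched 1D actions are expanded; 0D scalars, >=2D
            # tensors, and non-tensor actions (lists, dicts, strings, None)
            # pass through unchanged.
            transition[TransitionKey.ACTION] = action.unsqueeze(0)

Because Tensor.unsqueeze(0) returns a view with the same dtype and device, a shape-only rule like this is consistent with the dtype- and device-preservation tests below.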
Author: Adil Zouitine
Date: 2025-07-24 17:20:57 +02:00
Committed by: Steven Palma
Parent: 21baa8fa02
Commit: 99de7567e6
3 changed files with 245 additions and 10 deletions
@@ -17,11 +17,14 @@
 import tempfile
 from pathlib import Path
 import numpy as np
 import pytest
 import torch
 from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
-from lerobot.processor import ProcessorStepRegistry, RobotProcessor, ToBatchProcessor, TransitionKey
+from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
+from lerobot.processor.batch_processor import ToBatchProcessor
+from lerobot.processor.pipeline import TransitionKey
def create_transition(
@@ -34,8 +37,8 @@ def create_transition(
         TransitionKey.REWARD: reward,
         TransitionKey.DONE: done,
         TransitionKey.TRUNCATED: truncated,
-        TransitionKey.INFO: info if info is not None else {},
-        TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
+        TransitionKey.INFO: info,
+        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
     }
@@ -421,3 +424,219 @@ def test_edge_case_zero_dimensional_tensors():
    # 0D tensors should remain unchanged
    assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor)
    assert torch.allclose(processed_obs["scalar_value"], scalar_tensor)


# Action-specific tests
def test_action_1d_to_2d():
    """Test that 1D action tensors get batch dimension added."""
    processor = ToBatchProcessor()

    # Create 1D action tensor
    action_1d = torch.randn(4)
    transition = create_transition(action=action_1d)

    result = processor(transition)

    # Should add batch dimension
    assert result[TransitionKey.ACTION].shape == (1, 4)
    assert torch.equal(result[TransitionKey.ACTION][0], action_1d)


def test_action_already_batched():
    """Test that already batched action tensors remain unchanged."""
    processor = ToBatchProcessor()

    # Test various batch sizes
    action_batched_1 = torch.randn(1, 4)
    action_batched_5 = torch.randn(5, 4)

    # Single batch
    transition = create_transition(action=action_batched_1)
    result = processor(transition)
    assert torch.equal(result[TransitionKey.ACTION], action_batched_1)

    # Multiple batch
    transition = create_transition(action=action_batched_5)
    result = processor(transition)
    assert torch.equal(result[TransitionKey.ACTION], action_batched_5)


def test_action_higher_dimensional():
    """Test that higher dimensional action tensors remain unchanged."""
    processor = ToBatchProcessor()

    # 3D action tensor (e.g., sequence of actions)
    action_3d = torch.randn(2, 4, 3)
    transition = create_transition(action=action_3d)
    result = processor(transition)
    assert torch.equal(result[TransitionKey.ACTION], action_3d)

    # 4D action tensor
    action_4d = torch.randn(2, 10, 4, 3)
    transition = create_transition(action=action_4d)
    result = processor(transition)
    assert torch.equal(result[TransitionKey.ACTION], action_4d)


def test_action_scalar_tensor():
    """Test that scalar (0D) action tensors remain unchanged."""
    processor = ToBatchProcessor()

    action_scalar = torch.tensor(1.5)
    transition = create_transition(action=action_scalar)
    result = processor(transition)

    # Should remain scalar
    assert result[TransitionKey.ACTION].dim() == 0
    assert torch.equal(result[TransitionKey.ACTION], action_scalar)


def test_action_non_tensor():
    """Test that non-tensor actions remain unchanged."""
    processor = ToBatchProcessor()

    # List action
    action_list = [0.1, 0.2, 0.3, 0.4]
    transition = create_transition(action=action_list)
    result = processor(transition)
    assert result[TransitionKey.ACTION] == action_list

    # Numpy array action (as Python object, not converted)
    action_numpy = np.array([1, 2, 3, 4])
    transition = create_transition(action=action_numpy)
    result = processor(transition)
    assert np.array_equal(result[TransitionKey.ACTION], action_numpy)

    # String action (edge case)
    action_string = "forward"
    transition = create_transition(action=action_string)
    result = processor(transition)
    assert result[TransitionKey.ACTION] == action_string

    # Dict action (structured action)
    action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
    transition = create_transition(action=action_dict)
    result = processor(transition)
    assert result[TransitionKey.ACTION] == action_dict


def test_action_none():
    """Test that None action is handled correctly."""
    processor = ToBatchProcessor()

    transition = create_transition(action=None)
    result = processor(transition)
    assert result[TransitionKey.ACTION] is None


def test_action_with_observation():
    """Test action processing together with observation processing."""
    processor = ToBatchProcessor()

    # Both need batching
    observation = {
        OBS_STATE: torch.randn(7),
        OBS_IMAGE: torch.randn(64, 64, 3),
    }
    action = torch.randn(4)
    transition = create_transition(observation=observation, action=action)

    result = processor(transition)

    # Both should be batched
    assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3)
    assert result[TransitionKey.ACTION].shape == (1, 4)


def test_action_different_sizes():
    """Test action processing with various action dimensions."""
    processor = ToBatchProcessor()

    # Different action sizes (robot with different DOF)
    action_sizes = [1, 2, 4, 7, 10, 20]

    for size in action_sizes:
        action = torch.randn(size)
        transition = create_transition(action=action)
        result = processor(transition)

        assert result[TransitionKey.ACTION].shape == (1, size)
        assert torch.equal(result[TransitionKey.ACTION][0], action)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_action_device_compatibility():
    """Test action processing on different devices."""
    processor = ToBatchProcessor()

    # CUDA action
    action_cuda = torch.randn(4, device="cuda")
    transition = create_transition(action=action_cuda)
    result = processor(transition)
    assert result[TransitionKey.ACTION].shape == (1, 4)
    assert result[TransitionKey.ACTION].device.type == "cuda"

    # CPU action
    action_cpu = torch.randn(4, device="cpu")
    transition = create_transition(action=action_cpu)
    result = processor(transition)
    assert result[TransitionKey.ACTION].shape == (1, 4)
    assert result[TransitionKey.ACTION].device.type == "cpu"


def test_action_dtype_preservation():
    """Test that action dtype is preserved during processing."""
    processor = ToBatchProcessor()

    # Different dtypes
    dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]

    for dtype in dtypes:
        action = torch.randn(4).to(dtype)
        transition = create_transition(action=action)
        result = processor(transition)

        assert result[TransitionKey.ACTION].dtype == dtype
        assert result[TransitionKey.ACTION].shape == (1, 4)


def test_action_in_place_mutation():
    """Test that the processor mutates the transition in place for actions."""
    processor = ToBatchProcessor()

    action = torch.randn(4)
    transition = create_transition(action=action)

    # Store reference to original transition
    original_transition = transition

    # Process
    result = processor(transition)

    # Should be the same object (in-place mutation)
    assert result is original_transition
    assert result[TransitionKey.ACTION].shape == (1, 4)


def test_empty_action_tensor():
    """Test handling of empty action tensors."""
    processor = ToBatchProcessor()

    # Empty 1D tensor
    action_empty = torch.tensor([])
    transition = create_transition(action=action_empty)
    result = processor(transition)

    # Should add batch dimension even to empty tensor
    assert result[TransitionKey.ACTION].shape == (1, 0)

    # Empty 2D tensor (already batched)
    action_empty_2d = torch.randn(1, 0)
    transition = create_transition(action=action_empty_2d)
    result = processor(transition)

    # Should remain unchanged
    assert result[TransitionKey.ACTION].shape == (1, 0)