feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
This commit is contained in:
Adil Zouitine
2025-07-24 17:20:57 +02:00
committed by Steven Palma
parent 21baa8fa02
commit 99de7567e6
3 changed files with 245 additions and 10 deletions
+18 -5
View File
@@ -24,9 +24,9 @@ from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, Tra
@dataclass @dataclass
@ProcessorStepRegistry.register(name="to_batch_processor") @ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor: class ToBatchProcessor:
"""Processor that adds batch dimensions to observations when needed. """Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations have proper batch dimensions for model processing: This processor ensures that observations and actions have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state): - For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
@@ -34,6 +34,9 @@ class ToBatchProcessor:
- For image observations (observation.image, observation.images.*): - For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C) Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
This is useful when processing single transitions that need to be batched for This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to model inference or when converting from unbatched environment outputs to
batched model inputs. batched model inputs.
@@ -45,15 +48,21 @@ class ToBatchProcessor:
```python ```python
# State: (7,) -> (1, 7) # State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3) # Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Already batched: (1, 7) -> (1, 7) [unchanged] # Already batched: (1, 7) -> (1, 7) [unchanged]
``` ```
""" """
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION) self._process_observation(transition)
self._process_action(transition)
return transition
def _process_observation(self, transition: EnvTransition) -> None:
"""Process observation component in-place, adding batch dimensions where needed."""
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None: if observation is None:
return transition return
# Process state observations - add batch dim if 1D # Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]: for state_key in [OBS_STATE, OBS_ENV_STATE]:
@@ -73,7 +82,11 @@ class ToBatchProcessor:
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0) observation[key] = value.unsqueeze(0)
return transition def _process_action(self, transition: EnvTransition) -> None:
"""Process action component in-place, adding batch dimension if needed."""
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
transition[TransitionKey.ACTION] = action.unsqueeze(0)
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.""" """Return configuration for serialization."""
@@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor from lerobot.processor.pipeline import RobotProcessor
@@ -403,14 +404,16 @@ def main():
preprocessor_steps = [ preprocessor_steps = [
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats), NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
] ]
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor") preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
# Create postprocessor with unnormalizer for outputs only # Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [ postprocessor_steps = [
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
] ]
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor") postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
# Determine hub repo ID if pushing to hub # Determine hub repo ID if pushing to hub
if args.push_to_hub: if args.push_to_hub:
+222 -3
View File
@@ -17,11 +17,14 @@
import tempfile import tempfile
from pathlib import Path from pathlib import Path
import numpy as np
import pytest import pytest
import torch import torch
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import ProcessorStepRegistry, RobotProcessor, ToBatchProcessor, TransitionKey from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.pipeline import TransitionKey
def create_transition( def create_transition(
@@ -34,8 +37,8 @@ def create_transition(
TransitionKey.REWARD: reward, TransitionKey.REWARD: reward,
TransitionKey.DONE: done, TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated, TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {}, TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, TransitionKey.COMPLEMENTARY_DATA: complementary_data,
} }
@@ -421,3 +424,219 @@ def test_edge_case_zero_dimensional_tensors():
# 0D tensors should remain unchanged # 0D tensors should remain unchanged
assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor) assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor)
assert torch.allclose(processed_obs["scalar_value"], scalar_tensor) assert torch.allclose(processed_obs["scalar_value"], scalar_tensor)
# Action-specific tests
def test_action_1d_to_2d():
"""Test that 1D action tensors get batch dimension added."""
processor = ToBatchProcessor()
# Create 1D action tensor
action_1d = torch.randn(4)
transition = create_transition(action=action_1d)
result = processor(transition)
# Should add batch dimension
assert result[TransitionKey.ACTION].shape == (1, 4)
assert torch.equal(result[TransitionKey.ACTION][0], action_1d)
def test_action_already_batched():
"""Test that already batched action tensors remain unchanged."""
processor = ToBatchProcessor()
# Test various batch sizes
action_batched_1 = torch.randn(1, 4)
action_batched_5 = torch.randn(5, 4)
# Single batch
transition = create_transition(action=action_batched_1)
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_batched_1)
# Multiple batch
transition = create_transition(action=action_batched_5)
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_batched_5)
def test_action_higher_dimensional():
"""Test that higher dimensional action tensors remain unchanged."""
processor = ToBatchProcessor()
# 3D action tensor (e.g., sequence of actions)
action_3d = torch.randn(2, 4, 3)
transition = create_transition(action=action_3d)
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_3d)
# 4D action tensor
action_4d = torch.randn(2, 10, 4, 3)
transition = create_transition(action=action_4d)
result = processor(transition)
assert torch.equal(result[TransitionKey.ACTION], action_4d)
def test_action_scalar_tensor():
"""Test that scalar (0D) action tensors remain unchanged."""
processor = ToBatchProcessor()
action_scalar = torch.tensor(1.5)
transition = create_transition(action=action_scalar)
result = processor(transition)
# Should remain scalar
assert result[TransitionKey.ACTION].dim() == 0
assert torch.equal(result[TransitionKey.ACTION], action_scalar)
def test_action_non_tensor():
"""Test that non-tensor actions remain unchanged."""
processor = ToBatchProcessor()
# List action
action_list = [0.1, 0.2, 0.3, 0.4]
transition = create_transition(action=action_list)
result = processor(transition)
assert result[TransitionKey.ACTION] == action_list
# Numpy array action (as Python object, not converted)
action_numpy = np.array([1, 2, 3, 4])
transition = create_transition(action=action_numpy)
result = processor(transition)
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
# String action (edge case)
action_string = "forward"
transition = create_transition(action=action_string)
result = processor(transition)
assert result[TransitionKey.ACTION] == action_string
# Dict action (structured action)
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
transition = create_transition(action=action_dict)
result = processor(transition)
assert result[TransitionKey.ACTION] == action_dict
def test_action_none():
"""Test that None action is handled correctly."""
processor = ToBatchProcessor()
transition = create_transition(action=None)
result = processor(transition)
assert result[TransitionKey.ACTION] is None
def test_action_with_observation():
"""Test action processing together with observation processing."""
processor = ToBatchProcessor()
# Both need batching
observation = {
OBS_STATE: torch.randn(7),
OBS_IMAGE: torch.randn(64, 64, 3),
}
action = torch.randn(4)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Both should be batched
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3)
assert result[TransitionKey.ACTION].shape == (1, 4)
def test_action_different_sizes():
"""Test action processing with various action dimensions."""
processor = ToBatchProcessor()
# Different action sizes (robot with different DOF)
action_sizes = [1, 2, 4, 7, 10, 20]
for size in action_sizes:
action = torch.randn(size)
transition = create_transition(action=action)
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, size)
assert torch.equal(result[TransitionKey.ACTION][0], action)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_action_device_compatibility():
"""Test action processing on different devices."""
processor = ToBatchProcessor()
# CUDA action
action_cuda = torch.randn(4, device="cuda")
transition = create_transition(action=action_cuda)
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, 4)
assert result[TransitionKey.ACTION].device.type == "cuda"
# CPU action
action_cpu = torch.randn(4, device="cpu")
transition = create_transition(action=action_cpu)
result = processor(transition)
assert result[TransitionKey.ACTION].shape == (1, 4)
assert result[TransitionKey.ACTION].device.type == "cpu"
def test_action_dtype_preservation():
"""Test that action dtype is preserved during processing."""
processor = ToBatchProcessor()
# Different dtypes
dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]
for dtype in dtypes:
action = torch.randn(4).to(dtype)
transition = create_transition(action=action)
result = processor(transition)
assert result[TransitionKey.ACTION].dtype == dtype
assert result[TransitionKey.ACTION].shape == (1, 4)
def test_action_in_place_mutation():
"""Test that the processor mutates the transition in place for actions."""
processor = ToBatchProcessor()
action = torch.randn(4)
transition = create_transition(action=action)
# Store reference to original transition
original_transition = transition
# Process
result = processor(transition)
# Should be the same object (in-place mutation)
assert result is original_transition
assert result[TransitionKey.ACTION].shape == (1, 4)
def test_empty_action_tensor():
"""Test handling of empty action tensors."""
processor = ToBatchProcessor()
# Empty 1D tensor
action_empty = torch.tensor([])
transition = create_transition(action=action_empty)
result = processor(transition)
# Should add batch dimension even to empty tensor
assert result[TransitionKey.ACTION].shape == (1, 0)
# Empty 2D tensor (already batched)
action_empty_2d = torch.randn(1, 0)
transition = create_transition(action=action_empty_2d)
result = processor(transition)
# Should remain unchanged
assert result[TransitionKey.ACTION].shape == (1, 0)