mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(batch_processor): Enhance ToBatchProcessor to handle action batching
- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
This commit is contained in:
committed by
Steven Palma
parent
21baa8fa02
commit
99de7567e6
@@ -24,9 +24,9 @@ from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, Tra
|
|||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||||
class ToBatchProcessor:
|
class ToBatchProcessor:
|
||||||
"""Processor that adds batch dimensions to observations when needed.
|
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||||
|
|
||||||
This processor ensures that observations have proper batch dimensions for model processing:
|
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||||
|
|
||||||
- For state observations (observation.state, observation.environment_state):
|
- For state observations (observation.state, observation.environment_state):
|
||||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||||
@@ -34,6 +34,9 @@ class ToBatchProcessor:
|
|||||||
- For image observations (observation.image, observation.images.*):
|
- For image observations (observation.image, observation.images.*):
|
||||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||||
|
|
||||||
|
- For actions:
|
||||||
|
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||||
|
|
||||||
This is useful when processing single transitions that need to be batched for
|
This is useful when processing single transitions that need to be batched for
|
||||||
model inference or when converting from unbatched environment outputs to
|
model inference or when converting from unbatched environment outputs to
|
||||||
batched model inputs.
|
batched model inputs.
|
||||||
@@ -45,15 +48,21 @@ class ToBatchProcessor:
|
|||||||
```python
|
```python
|
||||||
# State: (7,) -> (1, 7)
|
# State: (7,) -> (1, 7)
|
||||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||||
|
# Action: (4,) -> (1, 4)
|
||||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
self._process_observation(transition)
|
||||||
|
self._process_action(transition)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def _process_observation(self, transition: EnvTransition) -> None:
|
||||||
|
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||||
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return transition
|
return
|
||||||
|
|
||||||
# Process state observations - add batch dim if 1D
|
# Process state observations - add batch dim if 1D
|
||||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||||
@@ -73,7 +82,11 @@ class ToBatchProcessor:
|
|||||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||||
observation[key] = value.unsqueeze(0)
|
observation[key] = value.unsqueeze(0)
|
||||||
|
|
||||||
return transition
|
def _process_action(self, transition: EnvTransition) -> None:
|
||||||
|
"""Process action component in-place, adding batch dimension if needed."""
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
||||||
|
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""Return configuration for serialization."""
|
"""Return configuration for serialization."""
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download
|
|||||||
from safetensors.torch import load_file as load_safetensors
|
from safetensors.torch import load_file as load_safetensors
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
from lerobot.processor.pipeline import RobotProcessor
|
||||||
|
|
||||||
@@ -403,14 +404,16 @@ def main():
|
|||||||
preprocessor_steps = [
|
preprocessor_steps = [
|
||||||
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
||||||
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
|
ToBatchProcessor(),
|
||||||
]
|
]
|
||||||
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
|
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
|
||||||
|
|
||||||
# Create postprocessor with unnormalizer for outputs only
|
# Create postprocessor with unnormalizer for outputs only
|
||||||
postprocessor_steps = [
|
postprocessor_steps = [
|
||||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||||
|
ToBatchProcessor(),
|
||||||
]
|
]
|
||||||
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
|
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
|
||||||
|
|
||||||
# Determine hub repo ID if pushing to hub
|
# Determine hub repo ID if pushing to hub
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
|
|||||||
@@ -17,11 +17,14 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor import ProcessorStepRegistry, RobotProcessor, ToBatchProcessor, TransitionKey
|
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||||
|
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||||
|
from lerobot.processor.pipeline import TransitionKey
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
@@ -34,8 +37,8 @@ def create_transition(
|
|||||||
TransitionKey.REWARD: reward,
|
TransitionKey.REWARD: reward,
|
||||||
TransitionKey.DONE: done,
|
TransitionKey.DONE: done,
|
||||||
TransitionKey.TRUNCATED: truncated,
|
TransitionKey.TRUNCATED: truncated,
|
||||||
TransitionKey.INFO: info if info is not None else {},
|
TransitionKey.INFO: info,
|
||||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -421,3 +424,219 @@ def test_edge_case_zero_dimensional_tensors():
|
|||||||
# 0D tensors should remain unchanged
|
# 0D tensors should remain unchanged
|
||||||
assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor)
|
assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor)
|
||||||
assert torch.allclose(processed_obs["scalar_value"], scalar_tensor)
|
assert torch.allclose(processed_obs["scalar_value"], scalar_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Action-specific tests
|
||||||
|
def test_action_1d_to_2d():
|
||||||
|
"""Test that 1D action tensors get batch dimension added."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Create 1D action tensor
|
||||||
|
action_1d = torch.randn(4)
|
||||||
|
transition = create_transition(action=action_1d)
|
||||||
|
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should add batch dimension
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION][0], action_1d)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_already_batched():
|
||||||
|
"""Test that already batched action tensors remain unchanged."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Test various batch sizes
|
||||||
|
action_batched_1 = torch.randn(1, 4)
|
||||||
|
action_batched_5 = torch.randn(5, 4)
|
||||||
|
|
||||||
|
# Single batch
|
||||||
|
transition = create_transition(action=action_batched_1)
|
||||||
|
result = processor(transition)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION], action_batched_1)
|
||||||
|
|
||||||
|
# Multiple batch
|
||||||
|
transition = create_transition(action=action_batched_5)
|
||||||
|
result = processor(transition)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION], action_batched_5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_higher_dimensional():
|
||||||
|
"""Test that higher dimensional action tensors remain unchanged."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# 3D action tensor (e.g., sequence of actions)
|
||||||
|
action_3d = torch.randn(2, 4, 3)
|
||||||
|
transition = create_transition(action=action_3d)
|
||||||
|
result = processor(transition)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION], action_3d)
|
||||||
|
|
||||||
|
# 4D action tensor
|
||||||
|
action_4d = torch.randn(2, 10, 4, 3)
|
||||||
|
transition = create_transition(action=action_4d)
|
||||||
|
result = processor(transition)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION], action_4d)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_scalar_tensor():
|
||||||
|
"""Test that scalar (0D) action tensors remain unchanged."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
action_scalar = torch.tensor(1.5)
|
||||||
|
transition = create_transition(action=action_scalar)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should remain scalar
|
||||||
|
assert result[TransitionKey.ACTION].dim() == 0
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION], action_scalar)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_non_tensor():
|
||||||
|
"""Test that non-tensor actions remain unchanged."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# List action
|
||||||
|
action_list = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
transition = create_transition(action=action_list)
|
||||||
|
result = processor(transition)
|
||||||
|
assert result[TransitionKey.ACTION] == action_list
|
||||||
|
|
||||||
|
# Numpy array action (as Python object, not converted)
|
||||||
|
action_numpy = np.array([1, 2, 3, 4])
|
||||||
|
transition = create_transition(action=action_numpy)
|
||||||
|
result = processor(transition)
|
||||||
|
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
||||||
|
|
||||||
|
# String action (edge case)
|
||||||
|
action_string = "forward"
|
||||||
|
transition = create_transition(action=action_string)
|
||||||
|
result = processor(transition)
|
||||||
|
assert result[TransitionKey.ACTION] == action_string
|
||||||
|
|
||||||
|
# Dict action (structured action)
|
||||||
|
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
|
||||||
|
transition = create_transition(action=action_dict)
|
||||||
|
result = processor(transition)
|
||||||
|
assert result[TransitionKey.ACTION] == action_dict
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_none():
|
||||||
|
"""Test that None action is handled correctly."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
transition = create_transition(action=None)
|
||||||
|
result = processor(transition)
|
||||||
|
assert result[TransitionKey.ACTION] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_with_observation():
|
||||||
|
"""Test action processing together with observation processing."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Both need batching
|
||||||
|
observation = {
|
||||||
|
OBS_STATE: torch.randn(7),
|
||||||
|
OBS_IMAGE: torch.randn(64, 64, 3),
|
||||||
|
}
|
||||||
|
action = torch.randn(4)
|
||||||
|
|
||||||
|
transition = create_transition(observation=observation, action=action)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Both should be batched
|
||||||
|
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||||
|
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3)
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_different_sizes():
|
||||||
|
"""Test action processing with various action dimensions."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Different action sizes (robot with different DOF)
|
||||||
|
action_sizes = [1, 2, 4, 7, 10, 20]
|
||||||
|
|
||||||
|
for size in action_sizes:
|
||||||
|
action = torch.randn(size)
|
||||||
|
transition = create_transition(action=action)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, size)
|
||||||
|
assert torch.equal(result[TransitionKey.ACTION][0], action)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
|
def test_action_device_compatibility():
|
||||||
|
"""Test action processing on different devices."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# CUDA action
|
||||||
|
action_cuda = torch.randn(4, device="cuda")
|
||||||
|
transition = create_transition(action=action_cuda)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||||
|
|
||||||
|
# CPU action
|
||||||
|
action_cpu = torch.randn(4, device="cpu")
|
||||||
|
transition = create_transition(action=action_cpu)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_dtype_preservation():
|
||||||
|
"""Test that action dtype is preserved during processing."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Different dtypes
|
||||||
|
dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
action = torch.randn(4).to(dtype)
|
||||||
|
transition = create_transition(action=action)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
assert result[TransitionKey.ACTION].dtype == dtype
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_in_place_mutation():
|
||||||
|
"""Test that the processor mutates the transition in place for actions."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
action = torch.randn(4)
|
||||||
|
transition = create_transition(action=action)
|
||||||
|
|
||||||
|
# Store reference to original transition
|
||||||
|
original_transition = transition
|
||||||
|
|
||||||
|
# Process
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should be the same object (in-place mutation)
|
||||||
|
assert result is original_transition
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_action_tensor():
|
||||||
|
"""Test handling of empty action tensors."""
|
||||||
|
processor = ToBatchProcessor()
|
||||||
|
|
||||||
|
# Empty 1D tensor
|
||||||
|
action_empty = torch.tensor([])
|
||||||
|
transition = create_transition(action=action_empty)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should add batch dimension even to empty tensor
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||||
|
|
||||||
|
# Empty 2D tensor (already batched)
|
||||||
|
action_empty_2d = torch.randn(1, 0)
|
||||||
|
transition = create_transition(action=action_empty_2d)
|
||||||
|
result = processor(transition)
|
||||||
|
|
||||||
|
# Should remain unchanged
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user