feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
This commit is contained in:
Adil Zouitine
2025-07-24 17:20:57 +02:00
committed by Steven Palma
parent 21baa8fa02
commit 99de7567e6
3 changed files with 245 additions and 10 deletions
+18 -5
View File
@@ -24,9 +24,9 @@ from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, Tra
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor:
"""Processor that adds batch dimensions to observations when needed.
"""Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations have proper batch dimensions for model processing:
This processor ensures that observations and actions have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
@@ -34,6 +34,9 @@ class ToBatchProcessor:
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
@@ -45,15 +48,21 @@ class ToBatchProcessor:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Add missing batch dimensions to the transition's observation and action.

    The work is delegated to ``_process_observation`` and ``_process_action``,
    which both modify the transition in place.

    Args:
        transition: Environment transition to batch.

    Returns:
        The same transition object that was passed in.
    """
    # The observation lookup is done inside _process_observation; repeating it
    # here only left an unused local behind.
    self._process_observation(transition)
    self._process_action(transition)
    return transition
def _process_observation(self, transition: EnvTransition) -> None:
"""Process observation component in-place, adding batch dimensions where needed."""
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
return
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
@@ -73,7 +82,11 @@ class ToBatchProcessor:
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
return transition
def _process_action(self, transition: EnvTransition) -> None:
    """Give a 1-dimensional tensor action a leading batch dimension, in place.

    Any other action (missing, not a ``Tensor``, or not 1-dimensional) leaves
    the transition untouched.
    """
    candidate = transition.get(TransitionKey.ACTION)
    # isinstance() is already False for None, so no separate None check is needed.
    if not isinstance(candidate, Tensor):
        return
    if candidate.dim() != 1:
        return
    transition[TransitionKey.ACTION] = candidate.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
@@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
@@ -403,14 +404,16 @@ def main():
preprocessor_steps = [
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
]
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
]
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub: