mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
feat(batch_processor): Enhance ToBatchProcessor to handle action batching
- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
This commit is contained in:
committed by
Steven Palma
parent
21baa8fa02
commit
99de7567e6
@@ -24,9 +24,9 @@ from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, Tra
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class ToBatchProcessor:
|
||||
"""Processor that adds batch dimensions to observations when needed.
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
|
||||
This processor ensures that observations have proper batch dimensions for model processing:
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
|
||||
- For state observations (observation.state, observation.environment_state):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
@@ -34,6 +34,9 @@ class ToBatchProcessor:
|
||||
- For image observations (observation.image, observation.images.*):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||
|
||||
- For actions:
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
This is useful when processing single transitions that need to be batched for
|
||||
model inference or when converting from unbatched environment outputs to
|
||||
batched model inputs.
|
||||
@@ -45,15 +48,21 @@ class ToBatchProcessor:
|
||||
```python
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
```
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
self._process_observation(transition)
|
||||
self._process_action(transition)
|
||||
return transition
|
||||
|
||||
def _process_observation(self, transition: EnvTransition) -> None:
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return transition
|
||||
return
|
||||
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
@@ -73,7 +82,11 @@ class ToBatchProcessor:
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
return transition
|
||||
def _process_action(self, transition: EnvTransition) -> None:
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
||||
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
@@ -46,6 +46,7 @@ from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
@@ -403,14 +404,16 @@ def main():
|
||||
preprocessor_steps = [
|
||||
NormalizerProcessor(features=input_features, norm_map=norm_map, stats=stats),
|
||||
NormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||
ToBatchProcessor(),
|
||||
]
|
||||
preprocessor = RobotProcessor(preprocessor_steps, name=f"{policy_type}_preprocessor")
|
||||
preprocessor = RobotProcessor(preprocessor_steps, name="preprocessor")
|
||||
|
||||
# Create postprocessor with unnormalizer for outputs only
|
||||
postprocessor_steps = [
|
||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||
ToBatchProcessor(),
|
||||
]
|
||||
postprocessor = RobotProcessor(postprocessor_steps, name=f"{policy_type}_postprocessor")
|
||||
postprocessor = RobotProcessor(postprocessor_steps, name="postprocessor")
|
||||
|
||||
# Determine hub repo ID if pushing to hub
|
||||
if args.push_to_hub:
|
||||
|
||||
Reference in New Issue
Block a user