feat(processor): Introduce ToBatchProcessor for handling observation batching

- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
This commit is contained in:
Adil Zouitine
2025-07-23 18:41:53 +02:00
committed by Steven Palma
parent e7be2fd113
commit f5c6b03b61
3 changed files with 517 additions and 0 deletions
+2
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_processor import ToBatchProcessor
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
@@ -48,6 +49,7 @@ __all__ = [
"RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
+92
View File
@@ -0,0 +1,92 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor:
"""Processor that adds batch dimensions to observations when needed.
This processor ensures that observations have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
return transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass