lerobot/docs/source/implement_your_own_processor.mdx
2025-08-07 18:13:34 +02:00


# Implement your own Robot Processor
In this tutorial, you'll learn how to implement your own Robot Processor.
It begins by exploring the need for a custom processor, followed by an explanation of how to implement one.
The tutorial also covers the set of helper classes that are available in LeRobot to support the implementation.
## Why would you need a custom processor?
In most cases, raw data read from sensors such as cameras and robot motor encoders has to be processed into a format that the policies in LeRobot can consume.
For example, raw images are encoded as `uint8` with values in the range `[0, 255]`, whereas the policies expect `float32` values normalized to the range `[0, 1]`.

In LeRobot's `VanillaObservationProcessor`, for instance, raw images come from the environment as NumPy arrays with `uint8` values in the range `[0, 255]` and in channel-last format `(H, W, C)`. The processor transforms them into PyTorch tensors with `float32` values in the range `[0, 1]` and channel-first format `(C, H, W)`, adding a batch dimension:
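```python
from lerobot.processor import VanillaObservationProcessor

processor = VanillaObservationProcessor()
# `transition` is assumed to hold `env_observation` as its observation entry
# Input: numpy array with shape (480, 640, 3) and dtype uint8
raw_image = env_observation["pixels"]  # Values in [0, 255]
# After processing: torch tensor with shape (1, 3, 480, 640) and dtype float32
processed_image = processor(transition)["observation"]["observation.image"]  # Values in [0, 1]
```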
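
For intuition, here is a dependency-free sketch of the same conversion on nested Python lists. It only mirrors the behavior described above (channel-last to channel-first, `[0, 255]` to `[0, 1]`, plus a leading batch dimension); the function name `hwc_uint8_to_chw_float` is hypothetical, and LeRobot's own implementation works on NumPy arrays and PyTorch tensors rather than lists.

```python
def hwc_uint8_to_chw_float(image):
    """Convert an (H, W, C) nested list of uint8 values into a
    (1, C, H, W) nested list of floats in [0, 1]. Illustration only."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    # Move the channel axis first and rescale each value from [0, 255] to [0, 1].
    chw = [
        [[image[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]
    # Prepend a batch dimension of size 1.
    return [chw]


# A 1x2 RGB image: one red pixel and one white pixel.
raw = [[[255, 0, 0], [255, 255, 255]]]
batch = hwc_uint8_to_chw_float(raw)
print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))  # 1 3 1 2
print(batch[0][0])  # red channel: [[1.0, 1.0]]
print(batch[0][1])  # green channel: [[0.0, 1.0]]
```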
Conversely, when a model returns an action to be executed on the robot, that action often has to be post-processed before the robot can run it.
For example, the model might return joint position values in the range `[-1, 1]` that need to be rescaled to the robot's minimum and maximum joint angles.
In LeRobot's `UnnormalizerProcessor`, for instance, model outputs in the normalized range `[-1, 1]` are converted back to the actual joint ranges of the robot:
```python
# Input: model action with normalized values in [-1, 1]
normalized_action = torch.tensor([-0.5, 0.8, -1.0, 0.2]) # Model output
# After post-processing: real joint positions in robot's native ranges
# Example: joints range from [-180.0, 180.0]
real_action = unnormalizer(transition)["action"]
# real action after post-processing: [ -90., 144., -180., 36.]
```
The unnormalizer uses the dataset statistics to convert back:
```python
# For MIN_MAX normalization, min_val and max_val are the dataset statistics:
# action = (normalized + 1) * (max - min) / 2 + min
real_action = (normalized_action + 1) * (max_val - min_val) / 2 + min_val
```
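
The formula above is the inverse of min-max normalization to `[-1, 1]`. The following pure-Python sketch (hypothetical helper names, scalars instead of tensors) pairs the two mappings and checks that they undo each other on the example values from the previous snippet. A production version would also need to guard against `max_val == min_val`, which this sketch does not do.

```python
import math


def minmax_normalize(value, min_val, max_val):
    """Map value from [min_val, max_val] to [-1, 1]."""
    return 2 * (value - min_val) / (max_val - min_val) - 1


def minmax_unnormalize(normalized, min_val, max_val):
    """Inverse mapping: [-1, 1] back to [min_val, max_val]."""
    return (normalized + 1) * (max_val - min_val) / 2 + min_val


normalized_action = [-0.5, 0.8, -1.0, 0.2]
real_action = [minmax_unnormalize(a, -180.0, 180.0) for a in normalized_action]
print(real_action)  # approximately [-90.0, 144.0, -180.0, 36.0]

# Normalizing again recovers the model output (up to floating-point error).
roundtrip = [minmax_normalize(a, -180.0, 180.0) for a in real_action]
assert all(math.isclose(x, y, abs_tol=1e-9) for x, y in zip(roundtrip, normalized_action))
```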
These situations call for a mechanism that preprocesses the data before it is passed to the policies and post-processes the actions they return before those actions are executed on the robot.
To that end, LeRobot provides a pipeline mechanism for chaining processing steps over the input data and the output actions.
## How to implement your own processor?
Start by identifying the sequence of processing steps your problem requires. Each processor step is a class that implements the following methods:
- `__call__`: applies the processing step to the input transition.
- `get_config`: returns the JSON-serializable configuration of the step.
- `state_dict`: returns the tensor state of the step.
- `load_state_dict`: restores the tensor state of the step.
- `reset`: resets the internal state of the step.
- `feature_contract`: describes how the step modifies the feature space (feature names and shapes).
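
As a rough map of how these six methods fit together, here is a self-contained skeleton. It is a sketch under simplifying assumptions, not LeRobot code: transitions are plain `dict`s with an `"observation"` key, the class name `ScaleObservationStep` is hypothetical, and a plain integer counter stands in for the tensor state a real step would keep.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ScaleObservationStep:
    """Hypothetical step that multiplies every observation value by `scale`.

    Transitions are plain dicts here; "observation" is an assumed key.
    """

    scale: float = 1.0
    # Number of transitions processed; stands in for accumulated state.
    _calls: int = field(default=0, init=False)

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        observation = transition.get("observation")
        if observation is None:
            return transition
        self._calls += 1
        new_transition = dict(transition)
        new_transition["observation"] = {k: v * self.scale for k, v in observation.items()}
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {"scale": self.scale}

    def state_dict(self) -> dict[str, Any]:
        return {"calls": self._calls}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._calls = state.get("calls", 0)

    def reset(self) -> None:
        self._calls = 0

    def feature_contract(self, features: dict[str, Any]) -> dict[str, Any]:
        # Scaling changes values, not names or shapes, so features pass through.
        return features


step = ScaleObservationStep(scale=2.0)
out = step({"observation": {"state": 1.5}, "action": None})
print(out["observation"])  # {'state': 3.0}
```

Each of these methods is discussed in more detail in the sections that follow.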
### Implement the `__call__` method
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's a minimal example:
```python
from dataclasses import dataclass

from lerobot.processor.pipeline import EnvTransition, TransitionKey


@dataclass
class MyProcessor:
    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Check if the required data exists
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Process the data
        processed_obs = self._process_observation(observation)

        # Return a new transition with the processed data
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition

    def _process_observation(self, obs):
        # Your custom processing logic here
        return obs
```
**Key principles:**
- Always check if required data exists before processing
- Return unchanged transition if no processing is needed
- Use `transition.copy()` to avoid side effects
- Only modify the specific keys your processor handles
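
One subtlety behind the `transition.copy()` principle: if the transition is a regular dictionary, `dict.copy()` is shallow, so nested containers such as the observation dictionary are still shared with the input. The sketch below uses plain dicts as a stand-in for `EnvTransition` (an assumption for illustration; both function names are hypothetical) to show why the processing logic should build a new observation rather than mutate the one it received.

```python
from typing import Any


def process_in_place(transition: dict[str, Any]) -> dict[str, Any]:
    # Anti-pattern: the shallow copy still shares the nested observation dict.
    new_transition = transition.copy()
    new_transition["observation"]["state"] = 0.0
    return new_transition


def process_with_new_dict(transition: dict[str, Any]) -> dict[str, Any]:
    # Safe: build a fresh observation dict and assign it to the copied transition.
    new_transition = transition.copy()
    observation = transition["observation"]
    new_transition["observation"] = {**observation, "state": 0.0}
    return new_transition


original = {"observation": {"state": 1.0}}
process_with_new_dict(original)
print(original["observation"]["state"])  # 1.0, input untouched

process_in_place(original)
print(original["observation"]["state"])  # 0.0, input was mutated
```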

**Tip**: For observation-only processors, inherit from `ObservationProcessor` to avoid writing the `__call__` boilerplate:
```python
from dataclasses import dataclass

from lerobot.processor.pipeline import ObservationProcessor


@dataclass
class MyObservationProcessor(ObservationProcessor):
    def observation(self, observation):
        # Only implement this method; __call__ is handled by the base class
        return self._process_observation(observation)

    def _process_observation(self, observation):
        # Your custom processing logic here
        return observation
```
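
Conceptually, such a base class owns the `__call__` boilerplate shown earlier and only delegates the observation itself. The sketch below is a simplified, hypothetical re-creation of that pattern (plain dicts, a class named `ObservationStepBase`) meant to show the control flow; LeRobot's actual `ObservationProcessor` may differ in its details.

```python
from dataclasses import dataclass
from typing import Any


class ObservationStepBase:
    """Hypothetical stand-in for an observation-only base class:
    __call__ routes the observation through self.observation()."""

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        observation = transition.get("observation")
        if observation is None:
            return transition
        new_transition = transition.copy()
        new_transition["observation"] = self.observation(observation)
        return new_transition

    def observation(self, observation: Any) -> Any:
        raise NotImplementedError


@dataclass
class AddPrefix(ObservationStepBase):
    prefix: str = "observation."

    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
        # Only this method is written by the subclass.
        return {f"{self.prefix}{key}": value for key, value in observation.items()}


step = AddPrefix()
print(step({"observation": {"state": 1.0}, "reward": 0.0}))
# {'observation': {'observation.state': 1.0}, 'reward': 0.0}
```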
### Configuration and State Management
Processors support serialization through three methods that separate configuration from tensor state:
```python
from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class MyProcessor:
    threshold: float = 0.5
    _running_mean: torch.Tensor | None = field(default=None, init=False)

    def get_config(self) -> dict[str, Any]:
        """Return JSON-serializable configuration."""
        return {"threshold": self.threshold}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return tensor state only."""
        if self._running_mean is not None:
            return {"running_mean": self._running_mean}
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Restore tensor state."""
        if "running_mean" in state:
            self._running_mean = state["running_mean"]
```
**Usage:**
```python
# Save
config = processor.get_config()
tensors = processor.state_dict()
# Restore
new_processor = MyProcessor(**config)
new_processor.load_state_dict(tensors)
```
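
The `MyProcessor` example above does not show `reset`, the remaining state-related method from the list of required methods. The following sketch adds it alongside the three serialization methods and walks through a save/restore round trip with the config passed through JSON. It is an illustration under stated assumptions: the class `RunningMeanStep` and its `update` method are hypothetical, and plain floats stand in for the tensors a real `state_dict` would hold.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class RunningMeanStep:
    """Hypothetical step tracking an exponential running mean of scalars."""

    momentum: float = 0.1
    _running_mean: Optional[float] = field(default=None, init=False)

    def update(self, value: float) -> None:
        if self._running_mean is None:
            self._running_mean = value
        else:
            self._running_mean = (1 - self.momentum) * self._running_mean + self.momentum * value

    def get_config(self) -> dict[str, Any]:
        return {"momentum": self.momentum}

    def state_dict(self) -> dict[str, float]:
        return {} if self._running_mean is None else {"running_mean": self._running_mean}

    def load_state_dict(self, state: dict[str, float]) -> None:
        self._running_mean = state.get("running_mean")

    def reset(self) -> None:
        # Clear accumulated state, e.g. at the start of a new episode.
        self._running_mean = None


step = RunningMeanStep(momentum=0.5)
step.update(2.0)
step.update(4.0)  # 0.5 * 2.0 + 0.5 * 4.0 = 3.0

# Config goes through JSON; state is restored separately because in a real
# step it holds tensors, which are not JSON-serializable.
config_json = json.dumps(step.get_config())
restored = RunningMeanStep(**json.loads(config_json))
restored.load_state_dict(step.state_dict())
print(restored.state_dict())  # {'running_mean': 3.0}

restored.reset()
print(restored.state_dict())  # {}
```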
### Feature Contract
The `feature_contract` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.
```python
from lerobot.configs.types import PolicyFeature


# Method of your processor class
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Transform feature keys: old_key -> new_key"""
    # Simple renaming
    if "pixels" in features:
        features["observation.image"] = features.pop("pixels")

    # Pattern-based renaming
    for key in list(features.keys()):
        if key.startswith("env_state."):
            suffix = key[len("env_state."):]
            features[f"observation.{suffix}"] = features.pop(key)

    return features
```
**Key principles:**
- Use `features.pop(old_key)` to remove and get the old feature
- Use `features[new_key] = old_feature` to add the renamed feature
- Always return the modified features dictionary
- Document transformations clearly in the docstring
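
To check a contract like this without building a full policy configuration, one option is to run it on a small dictionary. In this sketch, tuples of shapes stand in for `PolicyFeature` objects and the method is written as a free function; both are simplifications for illustration. It also copies its input first: the snippet above mutates the dictionary it receives, and copying is a safer default unless your pipeline relies on in-place updates.

```python
from typing import Any


def feature_contract(features: dict[str, Any]) -> dict[str, Any]:
    """Same renaming rules as above, applied to a copy of the input."""
    features = dict(features)  # avoid mutating the caller's dictionary
    if "pixels" in features:
        features["observation.image"] = features.pop("pixels")
    # Iterate over a snapshot of the keys so renaming is safe mid-loop.
    for key in list(features.keys()):
        if key.startswith("env_state."):
            suffix = key[len("env_state."):]
            features[f"observation.{suffix}"] = features.pop(key)
    return features


# Tuples of shapes stand in for PolicyFeature objects.
features = {"pixels": (3, 480, 640), "env_state.joints": (6,), "action": (6,)}
print(feature_contract(features))
# {'action': (6,), 'observation.image': (3, 480, 640), 'observation.joints': (6,)}
```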
## Helper Classes
LeRobot provides pre-built processor classes for common transformations:
### Core Classes
- **`VanillaObservationProcessor`** - Handles images and state observations
- **`NormalizerProcessor`** - Normalizes data using dataset statistics (mean/std or min/max)
- **`UnnormalizerProcessor`** - Converts normalized values back to original ranges
### Utility Classes
- **`DeviceProcessor`** - Moves tensors to the specified device (CPU/GPU)
- **`ToBatchProcessor`** - Adds batch dimensions
- **`RenameProcessor`** - Renames keys using a mapping dictionary
- **`TokenizerProcessor`** - Handles text tokenization for language-conditioned policies
### Usage Example
```python
from lerobot.processor import (
    DeviceProcessor,
    NormalizerProcessor,
    RobotProcessor,
    VanillaObservationProcessor,
)

# `features`, `norm_map`, and `stats` describe your dataset (its features,
# normalization modes, and statistics); `raw_transition` is an unprocessed
# transition from the environment.
steps = [
    VanillaObservationProcessor(),  # Process images and states
    NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
    DeviceProcessor(device="cuda"),
]

# Use in RobotProcessor
processor = RobotProcessor(steps=steps)
processed_transition = processor(raw_transition)
```
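
The essential behavior of such a pipeline is that each step receives the previous step's output. The sketch below illustrates only that composition idea with a hypothetical `MiniPipeline` class and two toy steps on plain dicts; it is not how `RobotProcessor` is implemented, which also handles configuration, state, and other concerns.

```python
from typing import Any, Callable

Transition = dict[str, Any]
Step = Callable[[Transition], Transition]


class MiniPipeline:
    """Hypothetical, simplified stand-in for a processor pipeline."""

    def __init__(self, steps: list[Step]):
        self.steps = list(steps)

    def __call__(self, transition: Transition) -> Transition:
        # Feed the output of each step into the next one, in order.
        for step in self.steps:
            transition = step(transition)
        return transition


def scale_pixels(transition: Transition) -> Transition:
    obs = transition["observation"]
    return {**transition, "observation": {**obs, "pixels": [p / 255.0 for p in obs["pixels"]]}}


def rename_pixels(transition: Transition) -> Transition:
    obs = dict(transition["observation"])
    obs["observation.image"] = obs.pop("pixels")
    return {**transition, "observation": obs}


pipeline = MiniPipeline([scale_pixels, rename_pixels])
print(pipeline({"observation": {"pixels": [0, 255]}}))
# {'observation': {'observation.image': [0.0, 1.0]}}
```

Because each step depends on the keys produced by the one before it, order matters: swapping the two steps here would raise a `KeyError`, since `scale_pixels` would no longer find `"pixels"`.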
## Best Practices
- **Keep processors atomic** - One transformation per processor for reusability and debugging
- **Use dataclasses** - Clean initialization with `@dataclass`
- **Always register processors** - Use `@ProcessorStepRegistry.register("name")` for discoverability
- **Check for None** - Always validate required data exists before processing
- **Use copy() for safety** - Avoid side effects with `transition.copy()`
- **Separate config and state** - JSON-serializable config vs tensor state_dict
- **Use base classes** - Inherit from `ObservationProcessor` for observation-only processing
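```python
from dataclasses import dataclass

from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry


@ProcessorStepRegistry.register("my_processor")
@dataclass
class MyProcessor(ObservationProcessor):
    threshold: float = 0.5

    def observation(self, observation):
        if observation is None:
            return observation
        # Your processing logic here
        processed_observation = observation
        return processed_observation
```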
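
A name-based registry lets a configuration refer to a step by a short string instead of an import path. The sketch below shows that mechanism in isolation with a hypothetical `SimpleRegistry`; it illustrates the general decorator-registry pattern and is not the source of `ProcessorStepRegistry`.

```python
from typing import Callable, TypeVar

T = TypeVar("T", bound=type)


class SimpleRegistry:
    """Hypothetical name -> class registry using a decorator."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[T], T]:
        def decorator(step_cls: T) -> T:
            if name in cls._registry:
                raise ValueError(f"Step {name!r} is already registered")
            cls._registry[name] = step_cls
            return step_cls  # return the class unchanged

        return decorator

    @classmethod
    def get(cls, name: str) -> type:
        return cls._registry[name]


@SimpleRegistry.register("my_processor")
class MyStep:
    pass


print(SimpleRegistry.get("my_processor") is MyStep)  # True
```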
## Conclusion
You now have all the tools to implement custom processors in LeRobot! The key steps are:
1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `feature_contract`)
2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability
3. **Integrate it** into a `RobotProcessor` pipeline with other processing steps
4. **Use base classes** like `ObservationProcessor` when possible to reduce boilerplate

The processor system is designed to be modular and composable, so complex data processing pipelines can be built from simple, focused components. Whether you are preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors let you plug your own transformations into the same pipeline.

Start simple, test thoroughly, and build on the existing helper classes to create robust data processing pipelines for your robot learning workflows.