mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00

docs(processor): enhance tutorial on implementing custom processors

- Updated the tutorial to use `NormalizerProcessorStep` as the primary example, clarifying its role in normalizing observations and actions.
- Improved explanations of the need for custom processors, emphasizing data compatibility and processing requirements.
- Added code snippets demonstrating the normalization process and the configuration of processor pipelines.
- Enhanced the introduction to processors, detailing their function as translators between raw robot data and model inputs.
- Included examples of real-world processor configurations for both training and inference scenarios.
# Implement your own Robot Processor

In this tutorial, you'll learn how to implement your own Robot Processor.

It begins by exploring the need for a custom processor, then uses the `NormalizerProcessorStep` as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot.

## Why would you need a custom processor?

In most cases, when reading raw data from sensors or when models output actions, you need to process this data to make it compatible with your target system. A common need is normalizing data ranges to make them suitable for neural networks: raw images arrive as `uint8` values in `[0, 255]`, and raw joint positions come in the robot's native degree ranges, while policies expect `float32` values in a normalized range.

LeRobot's `NormalizerProcessorStep` handles this crucial task:

```python
# Input: raw joint positions in [0, 180] degrees
raw_action = torch.tensor([90.0, 45.0, 135.0])

# After processing: normalized to [-1, 1] range for model training
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=dataset_stats)
normalized_result = normalizer(transition)
# ...
```
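For completeness, the image-side conversion mentioned above can be sketched in plain PyTorch. This is a hand-rolled illustration of the described transform (`uint8` HWC image to batched `float32` CHW tensor in `[0, 1]`), not the actual LeRobot processor implementation:

```python
import torch

# Illustration only: the uint8 -> float32 image conversion described above
raw_image = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)  # (H, W, C)

# Channel-first, add batch dimension, scale to [0, 1]
processed = raw_image.permute(2, 0, 1).unsqueeze(0).float() / 255.0  # (1, C, H, W)
print(processed.shape, processed.dtype)
```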

When a model returns an action to be executed on the robot, you often need the inverse operation: the model might return joint position values in `[-1, 1]` that have to be scaled back to the robot's minimum and maximum joint angles. In LeRobot, this normalization workflow is handled by `NormalizerProcessorStep` (for inputs) and `UnnormalizerProcessorStep` (for outputs); both are heavily used by policies (e.g., Pi0, SmolVLA).

Other common processing needs include:

- **Device placement**: Moving tensors between CPU/GPU and converting data types
- **Format conversion**: Transforming between different data structures
- **Batching**: Adding/removing batch dimensions for model compatibility
- **Safety constraints**: Applying limits to robot commands

Going the other way, `UnnormalizerProcessorStep` converts model outputs in `[-1, 1]` back to actual robot joint ranges:

```python
# Input: model action with normalized values in [-1, 1]
normalized_action = torch.tensor([-0.5, 0.8, -1.0, 0.2])  # Model output

# After post-processing: real joint positions in the robot's native ranges
# Example: joints range over [-180.0, 180.0]
real_action = unnormalizer(transition)["action"]
# Real action after post-processing: [ -90., 144., -180., 36.]
```

A typical preprocessing pipeline combines several such steps:

```python
# Example pipeline combining multiple processors
pipeline = PolicyProcessorPipeline([
    RenameObservationsProcessorStep(rename_map={}),
    AddBatchDimensionProcessorStep(),
    NormalizerProcessorStep(features=features, stats=stats),
    DeviceProcessorStep(device="cuda"),
    # ...
])
```

The unnormalizer uses the dataset statistics to convert back:

```python
# For MIN_MAX normalization: action = (normalized + 1) * (max - min) / 2 + min
real_action = (normalized_action + 1) * (max_val - min_val) / 2 + min_val
```
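The MIN_MAX formulas can be checked end to end with a self-contained sketch in plain Python, independent of LeRobot (hand-rolled helpers for illustration; the real step also adds a small `eps` to avoid division by zero):

```python
def min_max_normalize(values, min_val, max_val, eps=1e-8):
    # Map values from [min_val, max_val] to [-1, 1]
    return [(v - min_val) / (max_val - min_val + eps) * 2 - 1 for v in values]

def min_max_unnormalize(values, min_val, max_val):
    # Map values from [-1, 1] back to [min_val, max_val]
    return [(v + 1) * (max_val - min_val) / 2 + min_val for v in values]

action = [-90.0, 144.0, -180.0, 36.0]  # joints in [-180, 180] degrees
normalized = min_max_normalize(action, -180.0, 180.0)
restored = min_max_unnormalize(normalized, -180.0, 180.0)
print([round(v, 3) for v in normalized])  # [-0.5, 0.8, -1.0, 0.2]
print([round(v, 3) for v in restored])    # [-90.0, 144.0, -180.0, 36.0]
```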

LeRobot provides a pipeline mechanism to implement sequences of processing steps for both input data and output actions, making it easy to compose these transformations in the right order.

## How to implement your own processor?

We'll use the `NormalizerProcessorStep` as our main example because it demonstrates essential processor patterns, including state management, configuration serialization, and tensor handling, that you'll commonly need.

Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods: `__call__`, `get_config`, `state_dict`, `load_state_dict`, and `transform_features` (plus an optional `reset()` for stateful steps).

### Implement the `__call__` method

The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessorStep` works:

```python
from dataclasses import dataclass, field
from typing import Any

import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import ProcessorStep, ProcessorStepRegistry
from lerobot.processor.core import EnvTransition


@dataclass
@ProcessorStepRegistry.register("normalizer_processor")
class NormalizerProcessorStep(ProcessorStep):
    """Normalize observations/actions using dataset statistics."""

    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    stats: dict[str, dict[str, Any]] | None = None
    eps: float = 1e-8
    _tensor_stats: dict = field(default_factory=dict, init=False, repr=False)
    # ...

    def __post_init__(self):
        """Convert stats to tensors for efficient computation."""
        self.stats = self.stats or {}
        self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=torch.float32)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        new_transition = transition.copy()

        # Normalize observations
        # ...
        # Normalize action
        # ...
        return new_transition
```
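The elided normalization body boils down to simple per-feature arithmetic. A hand-rolled sketch of the MEAN_STD mode (illustration only, not the actual LeRobot implementation):

```python
def mean_std_normalize(values, mean, std, eps=1e-8):
    # Standardize each element: (x - mean) / (std + eps)
    return [(v - m) / (s + eps) for v, m, s in zip(values, mean, std)]

obs = [0.5, 1.0, 1.5]
mean = [1.0, 1.0, 1.0]
std = [0.5, 0.5, 0.5]
print([round(v, 6) for v in mean_std_normalize(obs, mean, std)])  # [-1.0, 0.0, 1.0]
```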

See the full implementation in `src/lerobot/processor/normalize_processor.py` for complete details.

**Key principles:**

- **Always use `transition.copy()`** to avoid side effects
- **Handle both simple and nested tensors** systematically
- **Handle both observations and actions** consistently
- **Separate config from state**: `get_config()` returns JSON-serializable params, `state_dict()` returns tensors
- **Convert stats to tensors** in `__post_init__()` for efficient computation

### Configuration and State Management

Processors support serialization through three methods that separate configuration from tensor state. The `NormalizerProcessorStep` demonstrates this well: it carries dataset statistics (tensors) in its state and hyperparameters in its config:

```python
from dataclasses import dataclass, field
from typing import Any

import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Continuing the NormalizerProcessorStep example...

@dataclass
class NormalizerProcessorStep:
    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    eps: float = 1e-8
    _tensor_stats: dict[str, dict[str, torch.Tensor]] = field(default_factory=dict, init=False, repr=False)

    def get_config(self) -> dict[str, Any]:
        """JSON-serializable configuration (no tensors)."""
        return {
            "eps": self.eps,
            "features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
            "norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()},
            # ...
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Tensor state only (e.g., dataset statistics)."""
        flat: dict[str, torch.Tensor] = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor.cpu()  # Always save to CPU
        return flat

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Restore tensor state at runtime."""
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            # Load to the processor's configured device
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
                dtype=torch.float32, device=self.device
            )
        # ...
```

**Usage:**

```python
# Save (e.g., inside a policy)
config = normalizer.get_config()
tensors = normalizer.state_dict()

# Restore (e.g., loading a pretrained policy)
new_normalizer = NormalizerProcessorStep(**config)
new_normalizer.load_state_dict(tensors)
# Now new_normalizer has the same stats and configuration
```

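The save/restore contract can be exercised with a toy step. This sketch uses a hypothetical `ScaleStep` with plain Python lists standing in for tensor statistics, so it runs without LeRobot or PyTorch; the point is that config survives a JSON round trip while state travels separately:

```python
import json
from dataclasses import dataclass, field

@dataclass
class ScaleStep:
    gain: float = 2.0  # config: JSON-serializable
    _stats: dict = field(default_factory=dict, init=False, repr=False)  # state

    def get_config(self):
        return {"gain": self.gain}

    def state_dict(self):
        return dict(self._stats)

    def load_state_dict(self, state):
        self._stats = dict(state)

step = ScaleStep(gain=3.0)
step._stats = {"action.mean": [1.0, 2.0]}

# Config goes through JSON; state is restored through load_state_dict
restored = ScaleStep(**json.loads(json.dumps(step.get_config())))
restored.load_state_dict(step.state_dict())
print(restored.gain, restored._stats["action.mean"])  # 3.0 [1.0, 2.0]
```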
### Transform features

The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.

For `NormalizerProcessorStep`, features are typically preserved unchanged, since normalization doesn't alter keys or shapes:

```python
def transform_features(self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
    """Normalization preserves all feature definitions."""
    return features  # No changes to feature structure
```

When your processor renames or reshapes data, implement this method to reflect the mapping for downstream components. For example, a simple rename processor:

```python
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    # Rename "env_state.*" keys to "observation.*" (iterate over a copy
    # of the keys because the dict is mutated in the loop)
    for key in list(features):
        if key.startswith("env_state."):
            suffix = key[len("env_state."):]
            features[f"observation.{suffix}"] = features.pop(key)
        # ...
    return features
```
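Run against a toy feature dictionary, the rename logic behaves like this. The `Feature` record is a stand-in so the snippet is self-contained; the real `PolicyFeature` lives in `lerobot.configs.types`:

```python
from dataclasses import dataclass

@dataclass
class Feature:  # stand-in for PolicyFeature
    shape: tuple

def rename_env_state(features):
    # Move every "env_state.*" key under the "observation.*" namespace
    for key in list(features):
        if key.startswith("env_state."):
            suffix = key[len("env_state."):]
            features[f"observation.{suffix}"] = features.pop(key)
    return features

features = {"env_state.joints": Feature(shape=(6,)), "action": Feature(shape=(6,))}
out = rename_env_state(features)
print(sorted(out))  # ['action', 'observation.joints']
```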

When implementing `transform_features`:

- Always return the modified features dictionary
- Document transformations clearly in the docstring

### Example of usage from the codebase

`transform_features` is used by `RobotProcessor` to derive the dataset/policy feature contract from an initial feature set by applying each step's transformation. You can see concrete examples in the codebase:

- Phone teleoperation record pipeline (`examples/phone_so100_record.py`): processors like `ForwardKinematicsJointsToEE`, `GripperVelocityToJoint`, and `EEBoundsAndSafety` implement `transform_features` to declare which action/observation keys should be materialized in the dataset.
- SO100 follower kinematics (`src/lerobot/robots/so100_follower/robot_kinematic_processor.py`): each processor's `transform_features` method adds or refines feature keys such as `observation.state.ee.{x,y,z,wx,wy,wz}` or `action.gripper.pos`.
- Rename and tokenizer processors (`src/lerobot/processor/rename_processor.py`, `src/lerobot/processor/tokenizer_processor.py`): demonstrate key renaming and adding language token features to the contract.

In practice, you will often aggregate features by running `DataProcessorPipeline.transform_features(...)` with your initial features to compute the final contract before recording or training.

## Helper Classes

LeRobot provides pre-built processor classes for common transformations. Below is a comprehensive list of registered processors in the codebase.

### Core processors (observations, actions, normalization)

- **`VanillaObservationProcessorStep`** (`observation_processor`): Images and state processing to LeRobot format.
- **`NormalizerProcessorStep`** (`normalizer_processor`): Normalize observations/actions (mean/std or min/max to [-1, 1]).
- **`UnnormalizerProcessorStep`** (`unnormalizer_processor`): Inverse of the normalizer for model outputs.
- **`DeviceProcessorStep`** (`device_processor`): Move tensors to a specific device (CPU/GPU) and optionally convert float dtype.
- **`AddBatchDimensionProcessorStep`** (`to_batch_processor`): Add batch dimension to observations/actions when missing.
- **`RenameObservationsProcessorStep`** (`rename_observations_processor`): Rename observation keys using a mapping dictionary.
- **`TokenizerProcessorStep`** (`tokenizer_processor`): Tokenize language tasks into `observation.language.*` tensors.

### Teleoperation mapping processors

- **`MapDeltaActionToRobotAction`** (`map_delta_action_to_robot_action`): Map teleop deltas (e.g., gamepad) to `action.target_*` fields.
- **`MapPhoneActionToRobotAction`** (`map_phone_action_to_robot_action`): Map calibrated phone pose/buttons to `action.target_*` and gripper.

### Robot kinematics processors (SO100 follower example)

- **`EEReferenceAndDelta`** (`ee_reference_and_delta`): Compute desired EE pose from target deltas and current pose.
- **`EEBoundsAndSafety`** (`ee_bounds_and_safety`): Clip EE pose to bounds and check for jumps.
- **`InverseKinematicsEEToJoints`** (`inverse_kinematics_ee_to_joints`): Convert EE pose to joint targets via IK.
- **`GripperVelocityToJoint`** (`gripper_velocity_to_joint`): Convert gripper velocity input to joint position command.
- **`ForwardKinematicsJointsToEE`** (`forward_kinematics_joints_to_ee`): Compute EE pose features from joint positions via FK.
- **`AddRobotObservationAsComplimentaryData`** (`add_robot_observation`): Read robot observation and insert `raw_joint_positions` into complementary data.

### Policy-specific utility processors

- **`Pi0NewLineProcessor`** (`pi0_new_line_processor`): Ensure text tasks end with a newline (Pi0 tokenizer compatibility).
- **`SmolVLANewLineProcessor`** (`smolvla_new_line_processor`): Ensure text tasks end with a newline (SmolVLA tokenizer compatibility).

### Usage Example

```python
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RobotProcessorPipeline,
)

# Create a processing pipeline (typical policy preprocessor)
steps = [
    NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats),
    AddBatchDimensionProcessorStep(),
    DeviceProcessorStep(device="cuda"),
]

# Use in RobotProcessorPipeline
processor = RobotProcessorPipeline[dict, dict](steps=steps)
processed_transition = processor(raw_transition)
```

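The pipeline itself is conceptually just ordered function composition. A minimal sketch with a hypothetical `TinyPipeline` (the real `RobotProcessorPipeline` adds registration, serialization, and feature-contract handling on top of this idea):

```python
class TinyPipeline:
    """Apply processor steps in order, feeding each step's output to the next."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, transition):
        for step in self.steps:
            transition = step(transition)
        return transition

# Two toy steps: cast actions to float, then mark the transition as batched
to_float = lambda t: {**t, "action": [float(a) for a in t["action"]]}
add_batch = lambda t: {**t, "batched": True}

pipeline = TinyPipeline([to_float, add_batch])
out = pipeline({"action": [1, 2, 3]})
print(out)  # {'action': [1.0, 2.0, 3.0], 'batched': True}
```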
### Using overrides

You can override step parameters at load time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `DataProcessorPipeline.from_pretrained(...)`.

**Foundation model adaptation**: This is particularly useful when working with pretrained foundation policies, where you rarely have access to the original training statistics. You can inject your own dataset statistics to adapt the normalizer to your specific robot or environment data.

Example: during policy evaluation on the robot, override the device and rename map. Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset.

```python
# src/lerobot/record.py, lines 437-445
preprocessor, postprocessor = make_processor(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
    preprocessor_overrides={
        "device_processor": {"device": cfg.policy.device},
        "rename_processor": {"rename_map": cfg.dataset.rename_map},
    },
)
```

Direct usage with `from_pretrained`:

```python
from lerobot.processor import RobotProcessorPipeline

# Load a foundational policy trained on diverse robot data
# but adapt normalization to your specific robot/environment
new_stats = LeRobotDataset(repo_id="username/my-dataset").meta.stats

processor = RobotProcessorPipeline.from_pretrained(
    "huggingface/foundational-robot-policy",  # Pretrained foundation model
    overrides={
        "normalizer_processor": {"stats": new_stats},  # Inject your robot's statistics
        "device_processor": {"device": "cuda:0"},  # Registry name for registered steps
        "rename_processor": {"rename_map": robot_key_map},  # Map your robot's observation keys
        # For non-registered steps, key by class name, e.g. "CustomStep": {"param": 42}
        # ...
    },
)
```

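Conceptually, overrides are a per-step dictionary merge applied on top of the saved configs before the steps are re-instantiated. A hypothetical sketch of that merge (not the actual loader code):

```python
# Saved step configs as they might come from a pretrained processor
saved_configs = {
    "device_processor": {"device": "cuda"},
    "normalizer_processor": {"eps": 1e-8},
}
overrides = {"device_processor": {"device": "cpu"}}

# Override values win over saved values, per step
merged = {name: {**cfg, **overrides.get(name, {})} for name, cfg in saved_configs.items()}
print(merged["device_processor"]["device"])  # cpu
```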

Based on analysis of all LeRobot processor implementations, here are the key patterns:

### 1. **Safe Data Handling**

Always create copies of input data to avoid unintended side effects. Use `transition.copy()` and `observation.copy()` rather than modifying data in-place. This prevents your processor from accidentally affecting other components in the pipeline.

Check for required data before processing and handle missing data gracefully. If your processor expects certain keys (like `"pixels"` for image processing), validate their presence first. For optional data, use safe access patterns like `transition.get()` and handle `None` values appropriately.

When data validation fails, provide clear, actionable error messages that help users understand what went wrong and how to fix it.

```python
# ✅ Always copy data to avoid side effects
new_action = action.copy()
new_obs = observation.copy()

# ✅ Check for required data before processing
if "pixels" not in observation:
    return observation  # Pass through unchanged

# ✅ Handle None gracefully
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if comp is None:
    raise ValueError("Required complementary data missing")
```

### 2. **Choose Appropriate Base Classes**

LeRobot provides specialized base classes that reduce boilerplate code and ensure consistency. Use `ObservationProcessorStep` when you only need to modify observations, `ActionProcessorStep` for action-only processing, and `RobotActionProcessorStep` specifically for dictionary-based robot actions.

Only inherit directly from `ProcessorStep` when you need full control over the entire transition or when processing multiple transition components simultaneously. The specialized base classes handle the transition management for you and provide type safety.

```python
# ✅ Observation-only processors
class MyObsProcessor(ObservationProcessorStep):
    def observation(self, observation): ...

# ✅ Action-only processors
class MyActionProcessor(ActionProcessorStep):
    def action(self, action): ...

# ✅ Robot action processors (dict actions only)
class MyRobotActionProcessor(RobotActionProcessorStep):
    def action(self, action: dict[str, Any]): ...

# ✅ Full control processors
class MyFullProcessor(ProcessorStep):
    def __call__(self, transition: EnvTransition): ...
```

### 3. **Registration and Naming**

Register your processors with descriptive, namespaced names using `@ProcessorStepRegistry.register()`. Use organization prefixes like `"robotics_lab/safety_clipper"` or `"acme_corp/vision_enhancer"` to avoid naming conflicts. Avoid generic names like `"processor"` or `"step"` that could clash with other implementations.

Good registration makes your processors discoverable and enables clean serialization/deserialization when saving and loading pipelines.

```python
# ✅ Always register with namespaced names
@ProcessorStepRegistry.register("my_company/image_processor")
@dataclass
class ImageProcessor(ObservationProcessorStep):
    ...

# ✅ Use descriptive, unique names
# Good: "robotics_lab/safety_clipper", "acme_corp/vision_enhancer"
# Bad: "processor", "step", "my_processor"
```

### 4. **State Management Patterns**

Distinguish between configuration parameters (JSON-serializable values) and internal state (tensors, buffers). Use dataclass fields with `init=False, repr=False` for internal state that shouldn't appear in the constructor or string representation.

Implement the `reset()` method to clear internal state between episodes. This is crucial for stateful processors that accumulate data over time, like moving averages or temporal filters.

Remember that `get_config()` should only return JSON-serializable configuration, while `state_dict()` handles tensor state separately.

```python
# ✅ Use dataclass fields for internal state
@dataclass
class StatefulProcessor(ProcessorStep):
    # Public config
    window_size: int = 10

    # Internal state (not in config)
    _buffer: list = field(default_factory=list, init=False, repr=False)
    _last_value: float | None = field(default=None, init=False, repr=False)

    def reset(self):
        """Reset internal state between episodes."""
        self._buffer.clear()
        self._last_value = None
```

### 5. **Input Validation and Error Handling**

Validate input types and shapes before processing. Check tensor properties like `dtype` and dimensions to ensure compatibility with your algorithms. For robot actions, verify that required pose components or joint values are present and within expected ranges.

Use early returns for edge cases where no processing is needed. Provide clear, descriptive error messages that include the expected vs. actual data types or shapes. This makes debugging much easier for users.

```python
# ✅ Early returns for edge cases
if not self.enabled or action is None:
    return action

# ✅ Validate data types and shapes
if not isinstance(action, dict):
    raise ValueError(f"Action should be a RobotAction type, got {type(action)}")

# ✅ Check tensor properties before processing
if img_tensor.dtype != torch.uint8:
    raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

# ✅ Validate required keys exist
if None in (x, y, z, wx, wy, wz):
    raise ValueError("Missing required end-effector pose components")
```

### 6. **Device and Dtype Awareness**

Design your processors to automatically adapt to the device and dtype of input tensors. Internal tensors (like normalization statistics) should match the input tensor's device and dtype to ensure compatibility with multi-GPU training, mixed precision, and distributed setups.

The key principle: **tensors stored in your processor should mimic the dtype and device of input tensors**. This enables seamless operation in multi-GPU setups, Accelerate, and data parallel configurations.

```python
# ✅ Adapt internal state to match input tensors
def _apply_transform(self, tensor: torch.Tensor, key: str) -> torch.Tensor:
    # Check if our internal stats match the input tensor
    if key in self._tensor_stats:
        first_stat = next(iter(self._tensor_stats[key].values()))
        if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
            # Automatically adapt to the input tensor's device/dtype
            self.to(device=tensor.device, dtype=tensor.dtype)

    # Now process with matching device/dtype
    return self._process_with_stats(tensor, key)

# ✅ Implement to() for device/dtype migration
def to(self, device=None, dtype=None):
    if device is not None:
        self.device = device
    if dtype is not None:
        self.dtype = dtype
    # Update internal tensor stats to match
    self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
    return self

# ✅ This pattern enables:
# - Multi-GPU training (data on different GPUs)
# - Mixed precision (float16, bfloat16)
# - Accelerate compatibility (automatic device placement)
# - Data parallel setups (distributed training)
```

Implement a `to()` method that moves your processor's internal state to the specified device. Check device/dtype compatibility at runtime and automatically migrate internal state when needed. This pattern enables seamless operation across different hardware configurations without manual intervention.

## Conclusion

You now know how to implement a custom processor step, serialize its configuration and state, declare its feature transformations, and compose it with LeRobot's built-in processors into pipelines for training and inference.