docs(processor): enhance tutorial on implementing custom processors

- Updated the tutorial to use `NormalizerProcessorStep` as the primary example, clarifying its role in normalizing observations and actions.
- Improved explanations of the need for custom processors, emphasizing data compatibility and processing requirements.
- Added code snippets demonstrating the normalization process and the configuration of processor pipelines.
- Enhanced the introduction to processors, detailing their function as translators between raw robot data and model inputs.
- Included examples of real-world processor configurations for both training and inference scenarios.
This commit is contained in:
AdilZouitine
2025-09-15 18:20:28 +02:00
parent 8fb18109ef
commit cee5a3fec5
2 changed files with 174 additions and 434 deletions
+116 -331
View File
@@ -1,56 +1,47 @@
# Implement your own Robot Processor
In this tutorial, you'll learn how to implement your own Robot Processor.
It begins by exploring the need for a custom processor, then uses the Normalization processors as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot.
It begins by exploring the need for a custom processor, then uses the `NormalizerProcessorStep` as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot.
## Why would you need a custom processor?
In most cases, when reading raw data from a sensor like the camera and robot motor encoders,
you will need to process this data to transform it into a format that is compatible to use with the policies in LeRobot.
For example, raw images are encoded with `uint8` and the values are in the range `[0, 255]`.
To use these images with the policies, you will need to cast them to `float32` and normalize them to the range `[0, 1]`.
In most cases, when reading raw data from sensors or when models output actions, you need to process this data to make it compatible with your target system. For example, a common need is normalizing data ranges to make them suitable for neural networks.
For example, in LeRobot's `VanillaObservationProcessor`, raw images come from the environment as numpy arrays with `uint8` values in range `[0, 255]` and in channel-last format `(H, W, C)`. The processor transforms them into PyTorch tensors with `float32` values in range `[0, 1]` and channel-first format `(C, H, W)`:
LeRobot's `NormalizerProcessorStep` handles this crucial task:
```python
# Input: numpy array with shape (480, 640, 3) and dtype uint8
raw_image = env_observation["pixels"] # Values in [0, 255]
# Input: raw joint positions in [0, 180] degrees
raw_action = torch.tensor([90.0, 45.0, 135.0])
# After processing: torch tensor with shape (1, 3, 480, 640) and dtype float32
processed_image = processor(transition)["observation"]["observation.image"] # Values in [0, 1]
# After processing: normalized to [-1, 1] range for model training
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=dataset_stats)
normalized_result = normalizer(transition)
# ...
```
On the other hand, when a model returns a certain action to be executed on the robot, it is often that one has to post-process this action to make it compatible to run on the robot.
For example, the model might return joint positions values that range from `[-1, 1]` and one would need to scale them to the ranges of the minimum and maximum joint angle positions of the robot.
Other common processing needs include:
In LeRobot, this normalization workflow is handled by the `NormalizerProcessor` (for inputs) and the `UnnormalizerProcessor` (for outputs). These processors are heavily used by policies (e.g., Pi0, SmolVLA) and integrate tightly with the `RobotProcessor`'s `get_config`, `state_dict`, and `load_state_dict` APIs.
For instance, `UnnormalizerProcessor` converts model outputs in `[-1, 1]` back to actual robot joint ranges:
- **Device placement**: Moving tensors between CPU/GPU and converting data types
- **Format conversion**: Transforming between different data structures
- **Batching**: Adding/removing batch dimensions for model compatibility
- **Safety constraints**: Applying limits to robot commands
```python
# Input: model action with normalized values in [-1, 1]
normalized_action = torch.tensor([-0.5, 0.8, -1.0, 0.2]) # Model output
# After post-processing: real joint positions in robot's native ranges
# Example: joints range from [-180.0, 180.0]
real_action = unnormalizer(transition)["action"]
# real action after post-processing: [ -90., 144., -180., 36.]
# Example pipeline combining multiple processors
pipeline = PolicyProcessorPipeline([
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(features=features, stats=stats),
DeviceProcessorStep(device="cuda"),
# ...
])
```
The unnormalizer uses the dataset statistics to convert back:
```python
# For MIN_MAX normalization: action = (normalized + 1) * (max - min) / 2 + min
real_action = (normalized_action + 1) * (max_val - min_val) / 2 + min_val
```
All these situations point us towards the need for a mechanism to preprocess the data before being passed to the policies and then post-process the action that are returned to be executed on the robot.
To that end, LeRobot provides a pipeline mechanism to implement a sequence of processing steps for the input data and the output action.
LeRobot provides a pipeline mechanism to implement sequences of processing steps for both input data and output actions, making it easy to compose these transformations in the right order for optimal performance.
## How to implement your own processor?
We'll use the `DeviceProcessorStep` as our main example because it demonstrates essential processor patterns and device/dtype awareness that's crucial for modern multi-GPU setups.
We'll use the `NormalizerProcessorStep` as our main example because it demonstrates essential processor patterns including state management, configuration serialization, and tensor handling that you'll commonly need.
Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods:
@@ -63,150 +54,107 @@ Prepare the sequence of processing steps necessary for your problem. A processor
### Implement the `__call__` method
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `DeviceProcessorStep` works:
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessorStep` works:
```python
from dataclasses import dataclass
import torch
from lerobot.processor import ProcessorStep, ProcessorStepRegistry
from lerobot.processor.core import EnvTransition, TransitionKey
@dataclass
@ProcessorStepRegistry.register("device_processor")
class DeviceProcessorStep(ProcessorStep):
"""Move tensors to specified device with optional dtype conversion."""
@ProcessorStepRegistry.register("normalizer_processor")
class NormalizerProcessorStep(ProcessorStep):
"""Normalize observations/actions using dataset statistics."""
device: str = "cpu"
float_dtype: str | None = None
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
eps: float = 1e-8
_tensor_stats: dict = field(default_factory=dict, init=False, repr=False)
def __post_init__(self):
"""Initialize device and dtype mappings."""
self.tensor_device = torch.device(self.device)
self.non_blocking = "cuda" in str(self.device)
# Map string dtype to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16, "half": torch.float16,
"float32": torch.float32, "float": torch.float32,
"bfloat16": torch.bfloat16
}
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
"""Convert stats to tensors for efficient computation."""
self.stats = self.stats or {}
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=torch.float32)
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
# Process simple tensor keys
for key in [TransitionKey.ACTION, TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)
# Process nested tensor dicts
for key in [TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA]:
data_dict = transition.get(key)
if data_dict is not None:
new_data_dict = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in data_dict.items()
}
new_transition[key] = new_data_dict
# Normalize observations
# ...
# Normalize action
# ...
return new_transition
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Move tensor to target device and convert dtype if needed."""
# Smart device handling for multi-GPU compatibility
if tensor.is_cuda and self.tensor_device.type == "cuda":
# Both on GPU: preserve original GPU (Accelerate compatibility)
target_device = tensor.device
else:
# CPU or different device types: use configured device
target_device = self.tensor_device
# Move if necessary
if tensor.device != target_device:
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
# Convert float dtype if specified
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def get_config(self) -> dict:
return {"device": self.device, "float_dtype": self.float_dtype}
```
See the full implementation in `src/lerobot/processor/device_processor.py` for complete details.
See the full implementation in `src/lerobot/processor/normalize_processor.py` for complete details.
**Key principles:**
- **Always use `transition.copy()`** to avoid side effects
- **Handle both simple and nested tensors** systematically
- **Smart device handling**: Preserve GPU placement for Accelerate compatibility
- **Validate configurations** in `__post_init__()`
- **Handle both observations and actions** consistently
- **Separate config from state**: `get_config()` returns JSON-serializable params, `state_dict()` returns tensors
- **Convert stats to tensors** in `__post_init__()` for efficient computation
### Configuration and State Management
Processors support serialization through three methods that separate configuration from tensor state. This is especially important for normalization processors, which carry dataset statistics (tensors) in their state, and hyperparameters in their config:
Processors support serialization through three methods that separate configuration from tensor state. The `NormalizerProcessorStep` demonstrates this perfectly - it carries dataset statistics (tensors) in its state, and hyperparameters in its config:
```python
from dataclasses import dataclass, field
from typing import Any
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Continuing the NormalizerProcessorStep example...
@dataclass
class NormalizerProcessor:
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, torch.Tensor]] = field(default_factory=dict, init=False, repr=False)
def get_config(self) -> dict[str, Any]:
"""JSON-serializable configuration (no tensors)."""
return {
"eps": self.eps,
"features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
"norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()},
# ...
}
def get_config(self) -> dict[str, Any]:
"""JSON-serializable configuration (no tensors)."""
return {
"eps": self.eps,
"features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
"norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()},
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Tensor state only (e.g., dataset statistics)."""
flat: dict[str, torch.Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
return flat
def state_dict(self) -> dict[str, torch.Tensor]:
"""Tensor state only (e.g., dataset statistics)."""
flat: dict[str, torch.Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Restore tensor state at runtime."""
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Restore tensor state at runtime."""
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
# Load to processor's configured device
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
# ...
```
**Usage:**
```python
# Save (e.g., inside a policy)
config = processor.get_config()
tensors = processor.state_dict()
config = normalizer.get_config()
tensors = normalizer.state_dict()
# Restore (e.g., loading a pretrained policy)
new_processor = NormalizerProcessor(**config)
new_processor.load_state_dict(tensors)
new_normalizer = NormalizerProcessorStep(**config)
new_normalizer.load_state_dict(tensors)
# Now new_normalizer has the same stats and configuration
```
### Transform features
The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.
Normalization typically preserves the feature keys and shapes, so `NormalizerProcessor.transform_features` returns the input features unchanged. When your processor renames or reshapes, implement this method to reflect the mapping for downstream components. For example, a simple rename processor:
For `NormalizerProcessorStep`, features are typically preserved unchanged since normalization doesn't alter keys or shapes:
```python
def transform_features(self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Normalization preserves all feature definitions."""
return features # No changes to feature structure
# ...
```
When your processor renames or reshapes data, implement this method to reflect the mapping for downstream components. For example, a simple rename processor:
```python
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
@@ -219,6 +167,7 @@ def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, Po
if key.startswith("env_state."):
suffix = key[len("env_state."):]
features[f"observation.{suffix}"] = features.pop(key)
# ...
return features
```
@@ -230,98 +179,30 @@ def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, Po
- Always return the modified features dictionary
- Document transformations clearly in the docstring
### Example of usage from the codebase
`transform_features` is used by `RobotProcessor` to derive the dataset/policy feature contract from an initial feature set by applying each step's transformation. You can see concrete examples in the codebase:
- Phone teleoperation record pipeline (`examples/phone_so100_record.py`): processors like `ForwardKinematicsJointsToEE`, `GripperVelocityToJoint`, and `EEBoundsAndSafety` implement `transform_features` to declare which action/observation keys should be materialized in the dataset.
- SO100 follower kinematics (`src/lerobot/robots/so100_follower/robot_kinematic_processor.py`): each processor's `transform_features` method adds or refines feature keys such as `observation.state.ee.{x,y,z,wx,wy,wz}` or `action.gripper.pos`.
- Rename and tokenizer processors (`src/lerobot/processor/rename_processor.py`, `src/lerobot/processor/tokenizer_processor.py`): demonstrate key renaming and adding language token features to the contract.
In practice, you will often aggregate features by running `DataProcessorPipeline.transform_features(...)` with your initial features to compute the final contract before recording or training.
## Helper Classes
LeRobot provides pre-built processor classes for common transformations. Below is a comprehensive list of registered processors in the codebase.
### Core processors (observations, actions, normalization)
- **`VanillaObservationProcessorStep`** (`observation_processor`): Images and state processing to LeRobot format.
- **`NormalizerProcessorStep`** (`normalizer_processor`): Normalize observations/actions (mean/std or min/max to [-1, 1]).
- **`UnnormalizerProcessorStep`** (`unnormalizer_processor`): Inverse of the normalizer for model outputs.
- **`DeviceProcessorStep`** (`device_processor`): Move tensors to a specific device (CPU/GPU) and optional float dtype.
- **`AddBatchDimensionProcessorStep`** (`to_batch_processor`): Add batch dimension to observations/actions when missing.
- **`RenameObservationsProcessorStep`** (`rename_observations_processor`): Rename observation keys using a mapping dictionary.
- **`TokenizerProcessorStep`** (`tokenizer_processor`): Tokenize language tasks into `observation.language.*` tensors.
### Teleoperation mapping processors
- **`MapDeltaActionToRobotAction`** (`map_delta_action_to_robot_action`): Map teleop deltas (e.g., gamepad) to `action.target_*` fields.
- **`MapPhoneActionToRobotAction`** (`map_phone_action_to_robot_action`): Map calibrated phone pose/buttons to `action.target_*` and gripper.
### Robot kinematics processors (SO100 follower example)
- **`EEReferenceAndDelta`** (`ee_reference_and_delta`): Compute desired EE pose from target deltas and current pose.
- **`EEBoundsAndSafety`** (`ee_bounds_and_safety`): Clip EE pose to bounds and check for jumps.
- **`InverseKinematicsEEToJoints`** (`inverse_kinematics_ee_to_joints`): Convert EE pose to joint targets via IK.
- **`GripperVelocityToJoint`** (`gripper_velocity_to_joint`): Convert gripper velocity input to joint position command.
- **`ForwardKinematicsJointsToEE`** (`forward_kinematics_joints_to_ee`): Compute EE pose features from joint positions via FK.
- **`AddRobotObservationAsComplimentaryData`** (`add_robot_observation`): Read robot observation and insert `raw_joint_positions` into complementary data.
### Policy-specific utility processors
- **`Pi0NewLineProcessor`** (`pi0_new_line_processor`): Ensure text tasks end with a newline (Pi0 tokenizer compatibility).
- **`SmolVLANewLineProcessor`** (`smolvla_new_line_processor`): Ensure text tasks end with a newline (SmolVLA tokenizer compatibility).
### Usage Example
```python
from lerobot.processor import (
NormalizerProcessorStep, DeviceProcessorStep,
RobotProcessorPipeline, AddBatchDimensionProcessorStep
)
# Create a processing pipeline (typical policy preprocessor)
steps = [
NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device="cuda"),
]
# Use in RobotProcessorPipeline
processor = RobotProcessorPipeline[dict, dict](steps=steps)
processed_transition = processor(raw_transition)
```
### Using overrides
You can override step parameters at load-time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `DataProcessorPipeline.from_pretrained(...)`.
**Foundational model adaptation**: This is particularly useful when working with foundational pretrained policies where you rarely have access to the original training statistics. You can inject your own dataset statistics to adapt the normalizer to your specific robot or environment data.
Example: during policy evaluation on the robot, override the device and rename map.
Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset.
```437:445:src/lerobot/record.py
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
)
```
Direct usage with `from_pretrained`:
```python
from lerobot.processor import RobotProcessorPipeline
# Load a foundational policy trained on diverse robot data
# but adapt normalization to your specific robot/environment
new_stats = LeRobotDataset(repo_id="username/my-dataset").meta.stats
processor = RobotProcessorPipeline.from_pretrained(
"username/my-processor",
"huggingface/foundational-robot-policy", # Pretrained foundation model
overrides={
"device_processor": {"device": "cuda:0"}, # registry name for registered steps
"CustomStep": {"param": 42}, # class name for non-registered steps
"normalizer_processor": {"stats": new_stats}, # Inject your robot's statistics
"device_processor": {"device": "cuda:0"}, # registry name for registered steps
"rename_processor": {"rename_map": robot_key_map}, # Map your robot's observation keys
# ...
},
)
```
@@ -332,139 +213,43 @@ Based on analysis of all LeRobot processor implementations, here are the key pat
### 1. **Safe Data Handling**
```python
# ✅ Always copy data to avoid side effects
new_action = action.copy()
new_obs = observation.copy()
Always create copies of input data to avoid unintended side effects. Use `transition.copy()` and `observation.copy()` rather than modifying data in-place. This prevents your processor from accidentally affecting other components in the pipeline.
# ✅ Check for required data before processing
if "pixels" not in observation:
return observation # Pass through unchanged
Check for required data before processing and handle missing data gracefully. If your processor expects certain keys (like `"pixels"` for image processing), validate their presence first. For optional data, use safe access patterns like `transition.get()` and handle `None` values appropriately.
# ✅ Handle None gracefully
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
if comp is None:
raise ValueError("Required complementary data missing")
```
When data validation fails, provide clear, actionable error messages that help users understand what went wrong and how to fix it.
### 2. **Robust Input Validation**
### 2. **Choose Appropriate Base Classes**
```python
# ✅ Validate data types and shapes
if not isinstance(action, dict):
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
LeRobot provides specialized base classes that reduce boilerplate code and ensure consistency. Use `ObservationProcessorStep` when you only need to modify observations, `ActionProcessorStep` for action-only processing, and `RobotActionProcessorStep` specifically for dictionary-based robot actions.
# ✅ Check tensor properties before processing
if img_tensor.dtype != torch.uint8:
raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
Only inherit directly from `ProcessorStep` when you need full control over the entire transition or when processing multiple transition components simultaneously. The specialized base classes handle the transition management for you and provide type safety.
# ✅ Validate required keys exist
if None in (x, y, z, wx, wy, wz):
raise ValueError("Missing required end-effector pose components")
```
### 3. **Registration and Naming**
### 3. **Use Appropriate Base Classes**
Register your processors with descriptive, namespaced names using `@ProcessorStepRegistry.register()`. Use organization prefixes like `"robotics_lab/safety_clipper"` or `"acme_corp/vision_enhancer"` to avoid naming conflicts. Avoid generic names like `"processor"` or `"step"` that could clash with other implementations.
```python
# ✅ Observation-only processors
class MyObsProcessor(ObservationProcessorStep):
def observation(self, observation): ...
Good registration makes your processors discoverable and enables clean serialization/deserialization when saving and loading pipelines.
# ✅ Action-only processors
class MyActionProcessor(ActionProcessorStep):
def action(self, action): ...
### 4. **State Management Patterns**
# ✅ Robot action processors (dict actions only)
class MyRobotActionProcessor(RobotActionProcessorStep):
def action(self, action: dict[str, Any]): ...
Distinguish between configuration parameters (JSON-serializable values) and internal state (tensors, buffers). Use dataclass fields with `init=False, repr=False` for internal state that shouldn't appear in the constructor or string representation.
# ✅ Full control processors
class MyFullProcessor(ProcessorStep):
def __call__(self, transition: EnvTransition): ...
```
Implement the `reset()` method to clear internal state between episodes. This is crucial for stateful processors that accumulate data over time, like moving averages or temporal filters.
### 4. **Registration and Naming**
Remember that `get_config()` should only return JSON-serializable configuration, while `state_dict()` handles tensor state separately.
```python
# ✅ Always register with namespaced names
@ProcessorStepRegistry.register("my_company/image_processor")
@dataclass
class ImageProcessor(ObservationProcessorStep):
...
### 5. **Input Validation and Error Handling**
# ✅ Use descriptive, unique names
# Good: "robotics_lab/safety_clipper", "acme_corp/vision_enhancer"
# Bad: "processor", "step", "my_processor"
```
Validate input types and shapes before processing. Check tensor properties like `dtype` and dimensions to ensure compatibility with your algorithms. For robot actions, verify that required pose components or joint values are present and within expected ranges.
### 5. **State Management Patterns**
Use early returns for edge cases where no processing is needed. Provide clear, descriptive error messages that include the expected vs. actual data types or shapes. This makes debugging much easier for users.
```python
# ✅ Use dataclass fields for internal state
@dataclass
class StatefulProcessor(ProcessorStep):
# Public config
window_size: int = 10
### 6. **Device and Dtype Awareness**
# Internal state (not in config)
_buffer: list = field(default_factory=list, init=False, repr=False)
_last_value: float | None = field(default=None, init=False, repr=False)
Design your processors to automatically adapt to the device and dtype of input tensors. Internal tensors (like normalization statistics) should match the input tensor's device and dtype to ensure compatibility with multi-GPU training, mixed precision, and distributed setups.
def reset(self):
"""Reset internal state between episodes."""
self._buffer.clear()
self._last_value = None
```
### 6. **Error Handling**
```python
# ✅ Early returns for edge cases
if not self.enabled or action is None:
return action
# ✅ Clear error messages for invalid inputs
if not isinstance(action, dict):
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
# ✅ Validate required keys exist
if "required_key" not in action:
raise ValueError("Required key 'required_key' not found in action")
```
### 7. **Device and Dtype Awareness**
The key principle: **tensors stored in your processor should mimic the dtype and device of input tensors**. This enables seamless operation in multi-GPU setups, Accelerate, and data parallel configurations.
```python
# ✅ Adapt internal state to match input tensors
def _apply_transform(self, tensor: torch.Tensor, key: str) -> torch.Tensor:
# Check if our internal stats match the input tensor
if key in self._tensor_stats:
first_stat = next(iter(self._tensor_stats[key].values()))
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
# Automatically adapt to input tensor's device/dtype
self.to(device=tensor.device, dtype=tensor.dtype)
# Now process with matching device/dtype
return self._process_with_stats(tensor, key)
# ✅ Implement to() method for device/dtype migration
def to(self, device=None, dtype=None):
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype
# Update internal tensor stats to match
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
return self
# ✅ This pattern enables:
# - Multi-GPU training (data on different GPUs)
# - Mixed precision (float16, bfloat16)
# - Accelerate compatibility (automatic device placement)
# - Data parallel setups (distributed training)
```
Implement a `to()` method that moves your processor's internal state to the specified device. Check device/dtype compatibility at runtime and automatically migrate internal state when needed. This pattern enables seamless operation across different hardware configurations without manual intervention.
## Conclusion