mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
improved part 2 of processor guide
This commit is contained in:
committed by
Adil Zouitine
parent
feb3fed5e8
commit
ac80f1f081
@@ -1,7 +1,8 @@
|
|||||||
# Implement your processor
|
# Implement your own Robot Processor
|
||||||
|
|
||||||
In this tutorial, we will explain how to implement your own processor. We will start by motivating
|
In this tutorial, you'll learn how to implement your own Robot Processor.
|
||||||
the need for a custom processor and then we will explain the helper classes that can help you implement your own processor.
|
It begins by exploring the need for a custom processor, followed by an explanation of how to implement one.
|
||||||
|
The tutorial also covers the set of helper classes that are available in LeRobot to support the implementation.
|
||||||
|
|
||||||
## Why would you need a custom processor?
|
## Why would you need a custom processor?
|
||||||
|
|
||||||
@@ -10,7 +11,7 @@ you will need to process this data to transform it into a format that is compati
|
|||||||
For example, raw images are encoded with `uint8` and the values are in the range `[0, 255]`.
|
For example, raw images are encoded with `uint8` and the values are in the range `[0, 255]`.
|
||||||
To use these images with the policies, you will need to cast them to `float32` and normalize them to the range `[0, 1]`.
|
To use these images with the policies, you will need to cast them to `float32` and normalize them to the range `[0, 1]`.
|
||||||
|
|
||||||
For example, in LeRobot's `ImageProcessor`, raw images come from the environment as numpy arrays with `uint8` values in range `[0, 255]` and in channel-last format `(H, W, C)`. The processor transforms them into PyTorch tensors with `float32` values in range `[0, 1]` and channel-first format `(C, H, W)`:
|
For example, in LeRobot's `VanillaObservationProcessor`, raw images come from the environment as numpy arrays with `uint8` values in range `[0, 255]` and in channel-last format `(H, W, C)`. The processor transforms them into PyTorch tensors with `float32` values in range `[0, 1]` and channel-first format `(C, H, W)`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Input: numpy array with shape (480, 640, 3) and dtype uint8
|
# Input: numpy array with shape (480, 640, 3) and dtype uint8
|
||||||
@@ -59,446 +60,188 @@ Prepare the sequence of processing steps necessary for your problem. A processor
|
|||||||
|
|
||||||
### Implement the `__call__` method
|
### Implement the `__call__` method
|
||||||
|
|
||||||
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's a real example from LeRobot's `ImageProcessor`:
|
The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's a minimal example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import torch
|
|
||||||
import einops
|
|
||||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImageProcessor:
|
class MyProcessor:
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
# Check if the required data exists
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
processed_obs = {}
|
# Process the data
|
||||||
|
processed_obs = self._process_observation(observation)
|
||||||
|
|
||||||
# Copy all observations first
|
# Return new transition with processed data
|
||||||
for key, value in observation.items():
|
|
||||||
processed_obs[key] = value
|
|
||||||
|
|
||||||
# Handle pixels key if present
|
|
||||||
pixels = observation.get("pixels")
|
|
||||||
if pixels is not None:
|
|
||||||
# Remove pixels from processed_obs since we'll replace it with processed images
|
|
||||||
processed_obs.pop("pixels", None)
|
|
||||||
|
|
||||||
# Process the image
|
|
||||||
processed_img = self._process_single_image(pixels)
|
|
||||||
processed_obs["observation.image"] = processed_img
|
|
||||||
|
|
||||||
# Return new transition with processed observation
|
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
def _process_single_image(self, img):
|
def _process_observation(self, obs):
|
||||||
# Convert to tensor
|
# Your custom processing logic here
|
||||||
img_tensor = torch.from_numpy(img)
|
return obs
|
||||||
|
|
||||||
# Add batch dimension if needed
|
|
||||||
if img_tensor.ndim == 3:
|
|
||||||
img_tensor = img_tensor.unsqueeze(0)
|
|
||||||
|
|
||||||
# Convert to channel-first format: (B, H, W, C) -> (B, C, H, W)
|
|
||||||
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
|
|
||||||
|
|
||||||
# Convert to float32 and normalize to [0, 1]
|
|
||||||
img_tensor = img_tensor.type(torch.float32) / 255.0
|
|
||||||
|
|
||||||
return img_tensor
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Key principles for implementing `__call__`:
|
**Key principles:**
|
||||||
|
|
||||||
- Always check if the required data exists (observations, actions, etc.)
|
- Always check if required data exists before processing
|
||||||
- Return the original transition unchanged if no processing is needed
|
- Return unchanged transition if no processing is needed
|
||||||
- Create a copy of the transition to avoid side effects
|
- Use `transition.copy()` to avoid side effects
|
||||||
- Only modify the specific keys your processor is responsible for
|
- Only modify the specific keys your processor handles
|
||||||
|
|
||||||
|
**Tip**: For observation-only processors, inherit from `ObservationProcessor` to avoid writing `__call__` boilerplate:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.processor.pipeline import ObservationProcessor
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MyObservationProcessor(ObservationProcessor):
|
||||||
|
def observation(self, observation):
|
||||||
|
# Only implement this method - __call__ is handled automatically
|
||||||
|
return self._process_observation(observation)
|
||||||
|
```
|
||||||
|
|
||||||
### Configuration and State Management
|
### Configuration and State Management
|
||||||
|
|
||||||
LeRobot processors support serialization and deserialization through three key methods. Here's how they work using `NormalizerProcessor` as an example:
|
Processors support serialization through three methods that separate configuration from tensor state:
|
||||||
|
|
||||||
#### `get_config()` - Serializable Configuration
|
|
||||||
|
|
||||||
This method returns all non-tensor configuration that can be saved to JSON:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@dataclass
|
@dataclass
|
||||||
class NormalizerProcessor:
|
class MyProcessor:
|
||||||
features: dict[str, PolicyFeature]
|
threshold: float = 0.5
|
||||||
norm_map: dict[FeatureType, NormalizationMode]
|
_running_mean: torch.Tensor = field(default=None, init=False)
|
||||||
normalize_keys: set[str] | None = None
|
|
||||||
eps: float = 1e-8
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""Return JSON-serializable configuration."""
|
"""Return JSON-serializable configuration."""
|
||||||
return {
|
return {"threshold": self.threshold}
|
||||||
"features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
|
|
||||||
"norm_map": {k.value: v.value for k, v in self.norm_map.items()},
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None,
|
"""Return tensor state only."""
|
||||||
"eps": self.eps,
|
if self._running_mean is not None:
|
||||||
# Note: 'stats' is not included as it contains tensors
|
return {"running_mean": self._running_mean}
|
||||||
}
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
"""Restore tensor state."""
|
||||||
|
if "running_mean" in state:
|
||||||
|
self._running_mean = state["running_mean"]
|
||||||
```
|
```
|
||||||
|
|
||||||
#### `state_dict()` - Tensor State
|
**Usage:**
|
||||||
|
|
||||||
This method returns only PyTorch tensors that need special serialization:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
# Save
|
||||||
"""Return tensor state dictionary."""
|
|
||||||
state = {}
|
|
||||||
for key, stats in self._tensor_stats.items():
|
|
||||||
for stat_name, tensor_val in stats.items():
|
|
||||||
state[f"{key}.{stat_name}"] = tensor_val
|
|
||||||
return state
|
|
||||||
```
|
|
||||||
|
|
||||||
#### `load_state_dict()` - Restore Tensor State
|
|
||||||
|
|
||||||
This method restores the tensor state from a saved state dictionary:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
"""Load tensor state from state dictionary."""
|
|
||||||
# Reconstruct _tensor_stats from flat state dict
|
|
||||||
self._tensor_stats = {}
|
|
||||||
for full_key, tensor_val in state.items():
|
|
||||||
if "." in full_key:
|
|
||||||
key, stat_name = full_key.rsplit(".", 1)
|
|
||||||
if key not in self._tensor_stats:
|
|
||||||
self._tensor_stats[key] = {}
|
|
||||||
self._tensor_stats[key][stat_name] = tensor_val
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Usage Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Save processor
|
|
||||||
config = processor.get_config()
|
config = processor.get_config()
|
||||||
tensors = processor.state_dict()
|
tensors = processor.state_dict()
|
||||||
|
|
||||||
# Later, restore processor
|
# Restore
|
||||||
new_processor = NormalizerProcessor(**config)
|
new_processor = MyProcessor(**config)
|
||||||
new_processor.load_state_dict(tensors)
|
new_processor.load_state_dict(tensors)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Feature Contract
|
### Feature Contract
|
||||||
|
|
||||||
The `feature_contract` method defines how your processor transforms the feature space. It tells the system how input feature names and shapes change after processing. This is crucial for policy configuration and debugging.
|
The `feature_contract` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.
|
||||||
|
|
||||||
Here's an example from `ImageProcessor`:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
"""Transforms:
|
"""Transform feature keys: old_key -> new_key"""
|
||||||
pixels -> observation.image,
|
# Simple renaming
|
||||||
observation.pixels -> observation.image,
|
|
||||||
pixels.<cam> -> observation.images.<cam>,
|
|
||||||
observation.pixels.<cam> -> observation.images.<cam>
|
|
||||||
"""
|
|
||||||
# Handle simple pixel renaming
|
|
||||||
if "pixels" in features:
|
if "pixels" in features:
|
||||||
features["observation.image"] = features.pop("pixels")
|
features["observation.image"] = features.pop("pixels")
|
||||||
if "observation.pixels" in features:
|
|
||||||
features["observation.image"] = features.pop("observation.pixels")
|
|
||||||
|
|
||||||
# Handle camera-specific pixels
|
# Pattern-based renaming
|
||||||
prefixes = ("pixels.", "observation.pixels.")
|
|
||||||
for key in list(features.keys()):
|
for key in list(features.keys()):
|
||||||
for p in prefixes:
|
if key.startswith("env_state."):
|
||||||
if key.startswith(p):
|
suffix = key[len("env_state."):]
|
||||||
suffix = key[len(p):]
|
features[f"observation.{suffix}"] = features.pop(key)
|
||||||
features[f"observation.images.{suffix}"] = features.pop(key)
|
|
||||||
break
|
|
||||||
return features
|
|
||||||
```
|
|
||||||
|
|
||||||
And from `StateProcessor`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
"""Transforms:
|
|
||||||
environment_state -> observation.environment_state,
|
|
||||||
agent_pos -> observation.state
|
|
||||||
"""
|
|
||||||
pairs = (
|
|
||||||
("environment_state", "observation.environment_state"),
|
|
||||||
("agent_pos", "observation.state"),
|
|
||||||
)
|
|
||||||
for old, new in pairs:
|
|
||||||
if old in features:
|
|
||||||
features[new] = features.pop(old)
|
|
||||||
prefixed = f"observation.{old}"
|
|
||||||
if prefixed in features:
|
|
||||||
features[new] = features.pop(prefixed)
|
|
||||||
return features
|
return features
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key principles:**
|
**Key principles:**
|
||||||
|
|
||||||
- Use `features.pop(old_key)` to remove the old feature and get its value
|
- Use `features.pop(old_key)` to remove and get the old feature
|
||||||
- Use `features[new_key] = old_feature` to add the new feature with same properties
|
- Use `features[new_key] = old_feature` to add the renamed feature
|
||||||
- Always return the modified features dictionary
|
- Always return the modified features dictionary
|
||||||
- Document the transformations clearly in the docstring
|
- Document transformations clearly in the docstring
|
||||||
|
|
||||||
## Helper Classes
|
## Helper Classes
|
||||||
|
|
||||||
LeRobot provides several pre-built processor classes that handle common transformations, which you can use directly or as building blocks for your custom processors.
|
LeRobot provides pre-built processor classes for common transformations:
|
||||||
|
|
||||||
### Core Processing Classes
|
### Core Classes
|
||||||
|
|
||||||
- **`ImageProcessor`** - Converts images from numpy arrays (uint8, channel-last) to PyTorch tensors (float32, channel-first)
|
- **`VanillaObservationProcessor`** - Handles images and state observations
|
||||||
- **`StateProcessor`** - Handles state observations, converting numpy arrays to tensors and renaming keys
|
- **`NormalizerProcessor`** - Normalizes data using dataset statistics (mean/std or min/max)
|
||||||
- **`NormalizerProcessor`** - Normalizes observations and actions using dataset statistics (mean/std or min/max)
|
- **`UnnormalizerProcessor`** - Converts normalized values back to original ranges
|
||||||
- **`UnnormalizerProcessor`** - Inverse of NormalizerProcessor, converts normalized values back to original ranges
|
|
||||||
|
|
||||||
### Utility Classes
|
### Utility Classes
|
||||||
|
|
||||||
- **`DeviceProcessor`** - Moves tensors to specified device (CPU/GPU)
|
- **`DeviceProcessor`** - Moves tensors to specified device (CPU/GPU)
|
||||||
- **`ToBatchProcessor`** - Adds batch dimensions to observations and actions
|
- **`ToBatchProcessor`** - Adds batch dimensions
|
||||||
- **`RenameProcessor`** - Renames keys in observations using a mapping dictionary
|
- **`RenameProcessor`** - Renames keys using a mapping dictionary
|
||||||
- **`TokenizerProcessor`** - Handles text tokenization for language-conditioned policies
|
- **`TokenizerProcessor`** - Handles text tokenization for language-conditioned policies
|
||||||
|
|
||||||
### Composition Classes
|
### Usage Example
|
||||||
|
|
||||||
- **`VanillaObservationProcessor`** - Combines ImageProcessor and StateProcessor for complete observation handling
|
|
||||||
- **`RobotProcessor`** - Main container that orchestrates a sequence of processor steps
|
|
||||||
|
|
||||||
### Example Usage
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ImageProcessor,
|
VanillaObservationProcessor,
|
||||||
NormalizerProcessor,
|
NormalizerProcessor,
|
||||||
DeviceProcessor,
|
DeviceProcessor,
|
||||||
RobotProcessor
|
RobotProcessor
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a processing pipeline
|
# Create a processing pipeline
|
||||||
preprocessing_steps = [
|
steps = [
|
||||||
ImageProcessor(), # Convert images
|
VanillaObservationProcessor(), # Process images and states
|
||||||
NormalizerProcessor(features=features, norm_map=norm_map, stats=dataset_stats), # Normalize
|
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
|
||||||
DeviceProcessor(device="cuda"), # Move to GPU
|
DeviceProcessor(device="cuda"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Combine into a robot processor
|
# Use in RobotProcessor
|
||||||
preprocessor = RobotProcessor(steps=preprocessing_steps, name="my_preprocessor")
|
processor = RobotProcessor(steps=steps)
|
||||||
|
processed_transition = processor(raw_transition)
|
||||||
# Use it
|
|
||||||
processed_transition = preprocessor(raw_transition)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
These helper classes implement the full `ProcessorStep` protocol, so they can be easily combined and extended for your specific needs.
|
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
### Design Principles
|
- **Keep processors atomic** - One transformation per processor for reusability and debugging
|
||||||
|
- **Use dataclasses** - Clean initialization with `@dataclass`
|
||||||
- **Keep processors atomic** - Each processor should handle one specific transformation. This makes them more reusable and easier to debug.
|
- **Always register processors** - Use `@ProcessorStepRegistry.register("name")` for discoverability
|
||||||
- **Use dataclasses** - Always implement processors as dataclasses for clean initialization and automatic generation of `__init__`, `__repr__`, etc.
|
- **Check for None** - Always validate required data exists before processing
|
||||||
|
- **Use copy() for safety** - Avoid side effects with `transition.copy()`
|
||||||
|
- **Separate config and state** - JSON-serializable config vs tensor state_dict
|
||||||
|
- **Use base classes** - Inherit from `ObservationProcessor` for observation-only processing
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ProcessorStepRegistry.register("my_processor")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyProcessor:
|
class MyProcessor(ObservationProcessor):
|
||||||
threshold: float = 0.5
|
threshold: float = 0.5
|
||||||
normalize: bool = True
|
|
||||||
|
def observation(self, observation):
|
||||||
|
if observation is None:
|
||||||
|
return observation
|
||||||
|
# Your processing logic here
|
||||||
|
return processed_observation
|
||||||
```
|
```
|
||||||
|
|
||||||
### Registration and Discovery
|
## Conclusion
|
||||||
|
|
||||||
- **Always register your processor** - Use the `@ProcessorStepRegistry.register()` decorator to make your processor discoverable:
|
You now have all the tools to implement custom processors in LeRobot! The key steps are:
|
||||||
|
|
||||||
```python
|
1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `feature_contract`)
|
||||||
@ProcessorStepRegistry.register("my_custom_processor")
|
2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability
|
||||||
@dataclass
|
3. **Integrate it** into a `RobotProcessor` pipeline with other processing steps
|
||||||
class MyCustomProcessor:
|
4. **Use base classes** like `ObservationProcessor` when possible to reduce boilerplate
|
||||||
# Implementation
|
|
||||||
```
|
|
||||||
|
|
||||||
### State Management
|
The processor system is designed to be modular and composable, allowing you to build complex data processing pipelines from simple, focused components. Whether you're preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors give you the flexibility to handle any data transformation your robotics application requires.
|
||||||
|
|
||||||
- **Separate concerns in config vs state_dict** - Keep non-tensor configuration in `get_config()` and only tensors in `state_dict()`:
|
Start simple, test thoroughly, and leverage the existing helper classes to build robust data processing pipelines for your robot learning workflows.
|
||||||
|
|
||||||
```python
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
# JSON-serializable data only
|
|
||||||
return {"threshold": self.threshold, "mode": self.mode}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
# Tensors only
|
|
||||||
return {"running_mean": self.running_mean, "weights": self.weights}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Safety and Robustness
|
|
||||||
|
|
||||||
- **Always check for None values** - Validate that required data exists before processing:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
if observation is None:
|
|
||||||
return transition # Return unchanged if no observation
|
|
||||||
```
|
|
||||||
|
|
||||||
- **Use copy() for safety** - If performance is not critical, copy transitions to avoid side effects:
|
|
||||||
|
|
||||||
```python
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
|
||||||
return new_transition
|
|
||||||
```
|
|
||||||
|
|
||||||
### Debugging and Development
|
|
||||||
|
|
||||||
- **Use hooks for debugging** - RobotProcessor supports hooks for monitoring:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def debug_hook(step_idx: int, transition: EnvTransition) -> None:
|
|
||||||
print(f"Step {step_idx}: {list(transition.keys())}")
|
|
||||||
|
|
||||||
processor = RobotProcessor(steps=[...], before_step_hooks=[debug_hook])
|
|
||||||
```
|
|
||||||
|
|
||||||
- **Use step_through() for development** - Iterate through processing steps one by one:
|
|
||||||
|
|
||||||
```python
|
|
||||||
for step_idx, transition in processor.step_through(data):
|
|
||||||
print(f"After step {step_idx}: {transition}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Performance Considerations
|
|
||||||
|
|
||||||
- **Batch processing** - Design processors to handle batched data efficiently when possible
|
|
||||||
- **Device awareness** - Let `DeviceProcessor` handle device placement rather than hardcoding it
|
|
||||||
- **Memory efficiency** - Reuse tensors when safe to do so for better performance
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
- **Document feature contracts clearly** - Be explicit about how your processor transforms the feature space
|
|
||||||
- **Provide usage examples** - Include docstring examples showing typical usage patterns
|
|
||||||
|
|
||||||
## Complete Example: Custom Smoothing Processor
|
|
||||||
|
|
||||||
Here's a complete example that implements a simple action smoothing processor:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from lerobot.processor.pipeline import (
|
|
||||||
EnvTransition,
|
|
||||||
ProcessorStepRegistry,
|
|
||||||
TransitionKey
|
|
||||||
)
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("action_smoother")
|
|
||||||
@dataclass
|
|
||||||
class ActionSmoothingProcessor:
|
|
||||||
"""Smooths actions using exponential moving average.
|
|
||||||
|
|
||||||
This processor maintains a running average of actions to reduce jitter
|
|
||||||
from the policy predictions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
alpha: float = 0.7 # Smoothing factor (0 = no smoothing, 1 = no memory)
|
|
||||||
|
|
||||||
# State (not in config)
|
|
||||||
_previous_action: Tensor | None = field(default=None, init=False, repr=False)
|
|
||||||
_initialized: bool = field(default=False, init=False, repr=False)
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
|
||||||
|
|
||||||
if action is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Convert to tensor if needed
|
|
||||||
if not isinstance(action, torch.Tensor):
|
|
||||||
action = torch.as_tensor(action, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Initialize on first call
|
|
||||||
if not self._initialized:
|
|
||||||
self._previous_action = action.clone()
|
|
||||||
self._initialized = True
|
|
||||||
smoothed_action = action
|
|
||||||
else:
|
|
||||||
# Exponential moving average: new = alpha * current + (1-alpha) * previous
|
|
||||||
smoothed_action = self.alpha * action + (1 - self.alpha) * self._previous_action
|
|
||||||
self._previous_action = smoothed_action.clone()
|
|
||||||
|
|
||||||
# Return new transition with smoothed action
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.ACTION] = smoothed_action
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
"""Return JSON-serializable configuration."""
|
|
||||||
return {"alpha": self.alpha}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
"""Return tensor state."""
|
|
||||||
if self._previous_action is not None:
|
|
||||||
return {"previous_action": self._previous_action}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
"""Load tensor state."""
|
|
||||||
if "previous_action" in state:
|
|
||||||
self._previous_action = state["previous_action"]
|
|
||||||
self._initialized = True
|
|
||||||
else:
|
|
||||||
self._previous_action = None
|
|
||||||
self._initialized = False
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset processor state at episode boundaries."""
|
|
||||||
self._previous_action = None
|
|
||||||
self._initialized = False
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
"""Action shapes remain unchanged."""
|
|
||||||
return features # No transformation to feature space
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.processor import RobotProcessor, DeviceProcessor
|
|
||||||
|
|
||||||
# Create a postprocessing pipeline with action smoothing
|
|
||||||
postprocessing_steps = [
|
|
||||||
DeviceProcessor(device="cpu"), # Move to CPU first
|
|
||||||
ActionSmoothingProcessor(alpha=0.8), # Apply smoothing
|
|
||||||
]
|
|
||||||
|
|
||||||
postprocessor = RobotProcessor(
|
|
||||||
steps=postprocessing_steps,
|
|
||||||
name="smooth_postprocessor"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use in your policy inference loop
|
|
||||||
for transition in environment_transitions:
|
|
||||||
# Get action from policy
|
|
||||||
action = policy(transition)
|
|
||||||
|
|
||||||
# Post-process (including smoothing)
|
|
||||||
transition_with_action = {"action": action}
|
|
||||||
smoothed_transition = postprocessor(transition_with_action)
|
|
||||||
|
|
||||||
# Execute smoothed action
|
|
||||||
next_obs = env.step(smoothed_transition["action"])
|
|
||||||
```
|
|
||||||
|
|
||||||
This example demonstrates all the key concepts: processor registration, state management, configuration serialization, and proper integration with the LeRobot pipeline system.
|
|
||||||
|
|||||||
Reference in New Issue
Block a user