mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
@@ -746,6 +746,441 @@ processor = RobotProcessor([
|
||||
])
|
||||
```
|
||||
|
||||
## Loading Processors with Overrides: Handling Non-Serializable Objects
|
||||
|
||||
One of the most powerful features of RobotProcessor is the ability to override step configurations when loading from saved checkpoints. This is particularly useful for handling non-serializable objects like environment instances, database connections, or hardware interfaces that can't be saved to JSON.
|
||||
|
||||
### The Problem: Non-Serializable Parameters
|
||||
|
||||
Imagine you have a processor step that needs a live gym environment:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
import gym
|
||||
|
||||
@ProcessorStepRegistry.register("action_repeat_step")
|
||||
@dataclass
|
||||
class ActionRepeatStep:
|
||||
"""Step that repeats actions using environment feedback."""
|
||||
|
||||
repeat_count: int = 3
|
||||
env: gym.Env = None # This can't be serialized to JSON!
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
if self.env is not None and action is not None:
|
||||
# Repeat action multiple times in environment
|
||||
total_reward = 0
|
||||
for _ in range(self.repeat_count):
|
||||
_, r, d, t, _ = self.env.step(action)
|
||||
total_reward += r
|
||||
if d or t:
|
||||
break
|
||||
reward = total_reward
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Note: env is NOT included because it's not serializable
|
||||
return {"repeat_count": self.repeat_count}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
```
|
||||
|
||||
If you try to save and load this processor normally, the `env` parameter will be lost because it can't be serialized to JSON.
|
||||
|
||||
### The Solution: Override Parameters
|
||||
|
||||
The `overrides` parameter in `from_pretrained()` allows you to provide non-serializable objects when loading:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
# Create processor with environment
|
||||
env = gym.make("CartPole-v1")
|
||||
action_step = ActionRepeatStep(repeat_count=2, env=env)
|
||||
processor = RobotProcessor([action_step], name="CartPoleProcessor")
|
||||
|
||||
# Save the processor (env won't be saved)
|
||||
processor.save_pretrained("./cartpole_processor")
|
||||
|
||||
# Later, load with environment override
|
||||
new_env = gym.make("CartPole-v1")
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
"./cartpole_processor",
|
||||
overrides={
|
||||
"action_repeat_step": {"env": new_env} # Provide the environment
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### How Overrides Work
|
||||
|
||||
The `overrides` parameter is a dictionary where:
|
||||
- **Keys** are step identifiers (class names for unregistered steps, registry names for registered steps)
|
||||
- **Values** are dictionaries of parameter overrides that get merged with saved configurations
|
||||
|
||||
```python
|
||||
overrides = {
|
||||
"StepClassName": {"param1": "new_value", "param2": 42},
|
||||
"registered_step_name": {"device": "cuda", "learning_rate": 0.01}
|
||||
}
|
||||
```
|
||||
|
||||
The loading process:
|
||||
1. Load saved configuration from JSON
|
||||
2. For each step, check if overrides exist for that step
|
||||
3. Merge override parameters with saved parameters (overrides take precedence)
|
||||
4. Instantiate the step with merged configuration
|
||||
5. Load any saved tensor state
|
||||
|
||||
### Real-World Examples
|
||||
|
||||
#### Example 1: Environment-Dependent Steps
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("ik_solver_step")
|
||||
@dataclass
|
||||
class InverseKinematicsStep:
|
||||
"""Convert Cartesian positions to joint angles."""
|
||||
|
||||
robot_model: str = "ur5"
|
||||
solver_timeout: float = 0.1
|
||||
kinematics_solver: Any = None # Non-serializable solver instance
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
if self.kinematics_solver is not None and action is not None:
|
||||
# Convert Cartesian action to joint angles
|
||||
joint_angles = self.kinematics_solver.solve(action, timeout=self.solver_timeout)
|
||||
action = joint_angles
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"robot_model": self.robot_model,
|
||||
"solver_timeout": self.solver_timeout
|
||||
}
|
||||
|
||||
# Save processor without solver
|
||||
processor = RobotProcessor([InverseKinematicsStep(robot_model="ur5")])
|
||||
processor.save_pretrained("./robot_processor")
|
||||
|
||||
# Load with solver instance
|
||||
from robotics_toolbox import URKinematics
|
||||
solver = URKinematics("ur5")
|
||||
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
"./robot_processor",
|
||||
overrides={
|
||||
"ik_solver_step": {
|
||||
"kinematics_solver": solver,
|
||||
"solver_timeout": 0.05 # Also override timeout
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
#### Example 2: Device and Hardware Configuration
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("camera_capture_step")
|
||||
@dataclass
|
||||
class CameraCaptureStep:
|
||||
"""Capture images from physical camera."""
|
||||
|
||||
camera_id: int = 0
|
||||
resolution: tuple = (640, 480)
|
||||
camera_interface: Any = None # Hardware interface
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"camera_id": self.camera_id,
|
||||
"resolution": self.resolution
|
||||
}
|
||||
|
||||
# Deploy on different robots with different camera setups
|
||||
# Robot A
|
||||
camera_a = CameraInterface("/dev/video0")
|
||||
processor_a = RobotProcessor.from_pretrained(
|
||||
"shared/vision_processor",
|
||||
overrides={
|
||||
"camera_capture_step": {
|
||||
"camera_interface": camera_a,
|
||||
"resolution": (1920, 1080) # High-res camera
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Robot B
|
||||
camera_b = CameraInterface("/dev/video1")
|
||||
processor_b = RobotProcessor.from_pretrained(
|
||||
"shared/vision_processor",
|
||||
overrides={
|
||||
"camera_capture_step": {
|
||||
"camera_interface": camera_b,
|
||||
"resolution": (640, 480) # Lower-res camera
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
#### Example 3: Multiple Environment Deployment
|
||||
|
||||
```python
|
||||
# Training processor that works with simulation
|
||||
@ProcessorStepRegistry.register("physics_validator")
|
||||
@dataclass
|
||||
class PhysicsValidatorStep:
|
||||
"""Validate actions against physics constraints."""
|
||||
|
||||
max_force: float = 100.0
|
||||
physics_engine: Any = None
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"max_force": self.max_force}
|
||||
|
||||
# Different physics engines for different environments
|
||||
import pybullet as pb
|
||||
import mujoco
|
||||
|
||||
# Simulation deployment
|
||||
sim_engine = pb.connect(pb.DIRECT)
|
||||
sim_processor = RobotProcessor.from_pretrained(
|
||||
"shared/control_processor",
|
||||
overrides={
|
||||
"physics_validator": {
|
||||
"physics_engine": sim_engine,
|
||||
"max_force": 150.0 # Higher limits in sim
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Real robot deployment
|
||||
real_engine = RealRobotInterface()
|
||||
real_processor = RobotProcessor.from_pretrained(
|
||||
"shared/control_processor",
|
||||
overrides={
|
||||
"physics_validator": {
|
||||
"physics_engine": real_engine,
|
||||
"max_force": 50.0 # Conservative limits on real robot
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Override Key Matching Rules
|
||||
|
||||
The override system uses exact string matching:
|
||||
|
||||
#### For Registered Steps
|
||||
Use the registry name (the string passed to `@ProcessorStepRegistry.register()`):
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("my_custom_step")
|
||||
class MyStep:
|
||||
pass
|
||||
|
||||
# Use registry name in overrides
|
||||
overrides = {"my_custom_step": {"param": "value"}}
|
||||
```
|
||||
|
||||
#### For Unregistered Steps
|
||||
Use the exact class name:
|
||||
|
||||
```python
|
||||
class MyUnregisteredStep:
|
||||
pass
|
||||
|
||||
# Use class name in overrides
|
||||
overrides = {"MyUnregisteredStep": {"param": "value"}}
|
||||
```
|
||||
|
||||
### Error Handling and Validation
|
||||
|
||||
The override system provides helpful error messages:
|
||||
|
||||
#### Invalid Override Keys
|
||||
```python
|
||||
# This will raise a KeyError with a helpful message
|
||||
overrides = {"NonExistentStep": {"param": "value"}}
|
||||
|
||||
try:
|
||||
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||
except KeyError as e:
|
||||
print(e)
|
||||
# Output: Override keys ['NonExistentStep'] do not match any step in the saved configuration.
|
||||
# Available step keys: ['ActualStepName', 'AnotherStepName']
|
||||
```
|
||||
|
||||
#### Instantiation Errors
|
||||
```python
|
||||
# Invalid parameter types are caught
|
||||
overrides = {"MyStep": {"numeric_param": "not_a_number"}}
|
||||
|
||||
try:
|
||||
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
# Output: Failed to instantiate processor step 'MyStep' with config: {...}
|
||||
```
|
||||
|
||||
### Multiple Steps with Same Class Name
|
||||
|
||||
When you have multiple steps of the same class, all instances get the same override:
|
||||
|
||||
```python
|
||||
step1 = MyStep(param=1)
|
||||
step2 = MyStep(param=2)
|
||||
processor = RobotProcessor([step1, step2])
|
||||
|
||||
# Both steps will get the override
|
||||
overrides = {"MyStep": {"param": 999}}
|
||||
loaded = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||
# Both steps now have param=999
|
||||
```
|
||||
|
||||
To override steps individually, use different classes or register them under different names:
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("step_1")
|
||||
class MyStep:
|
||||
pass
|
||||
|
||||
@ProcessorStepRegistry.register("step_2")
|
||||
class MyStep: # Same class name, but a separate definition registered under a different name
|
||||
pass
|
||||
|
||||
# Now you can override them separately
|
||||
overrides = {
|
||||
"step_1": {"param": 1},
|
||||
"step_2": {"param": 2}
|
||||
}
|
||||
```
|
||||
|
||||
### Best Practices for Overrides
|
||||
|
||||
#### 1. Design Steps for Overrides
|
||||
When creating steps that need non-serializable objects, design them with overrides in mind:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class WellDesignedStep:
|
||||
# Serializable configuration
|
||||
timeout: float = 1.0
|
||||
retry_count: int = 3
|
||||
|
||||
# Non-serializable objects with default None
|
||||
database: Any = None
|
||||
api_client: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate that required non-serializable objects are provided
|
||||
if self.database is None:
|
||||
raise ValueError("database must be provided via overrides")
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Only include serializable parameters
|
||||
return {
|
||||
"timeout": self.timeout,
|
||||
"retry_count": self.retry_count
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Use Registry Names for Clarity
|
||||
Register steps with descriptive names to make overrides clearer:
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("robot_arm_controller")
|
||||
class ArmControlStep:
|
||||
pass
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_controller")
|
||||
class GripperControlStep:
|
||||
pass
|
||||
|
||||
# Clear override keys
|
||||
overrides = {
|
||||
"robot_arm_controller": {"joint_limits": arm_limits},
|
||||
"gripper_controller": {"force_limit": gripper_force}
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. Document Override Requirements
|
||||
Include clear documentation about what overrides are needed:
|
||||
|
||||
```python
|
||||
@ProcessorStepRegistry.register("vision_processor")
|
||||
class VisionProcessor:
|
||||
"""Process camera images for robot vision.
|
||||
|
||||
Required overrides when loading:
|
||||
camera_interface: Hardware camera interface object
|
||||
|
||||
Optional overrides:
|
||||
resolution: Camera resolution tuple (default: (640, 480))
|
||||
fps: Camera frame rate (default: 30)
|
||||
"""
|
||||
camera_interface: Any = None
|
||||
resolution: tuple = (640, 480)
|
||||
fps: int = 30
|
||||
```
|
||||
|
||||
#### 4. Environment-Specific Configuration Files
|
||||
Create configuration helpers for different deployment environments:
|
||||
|
||||
```python
|
||||
# config/simulation.py
|
||||
def get_simulation_overrides():
|
||||
return {
|
||||
"camera_step": {"camera_interface": SimCamera()},
|
||||
"physics_step": {"engine": SimPhysics()},
|
||||
"control_step": {"safety_limits": False}
|
||||
}
|
||||
|
||||
# config/production.py
|
||||
def get_production_overrides():
|
||||
return {
|
||||
"camera_step": {"camera_interface": RealCamera("/dev/video0")},
|
||||
"physics_step": {"engine": RealPhysics()},
|
||||
"control_step": {"safety_limits": True}
|
||||
}
|
||||
|
||||
# Usage
|
||||
from config.production import get_production_overrides
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"shared/processor",
|
||||
overrides=get_production_overrides()
|
||||
)
|
||||
```
|
||||
|
||||
### Integration with Hub Sharing
|
||||
|
||||
Overrides work seamlessly with Hugging Face Hub sharing:
|
||||
|
||||
```python
|
||||
# Save processor without non-serializable objects
|
||||
processor.push_to_hub("my-lab/robot-processor")
|
||||
|
||||
# Anyone can load and provide their own environment
|
||||
import gym
|
||||
local_env = gym.make("MyRobotEnv-v1")
|
||||
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"my-lab/robot-processor",
|
||||
overrides={"env_step": {"env": local_env}}
|
||||
)
|
||||
```
|
||||
|
||||
This enables sharing of preprocessing logic while allowing each user to provide their own environment-specific dependencies.
|
||||
|
||||
## Complete Example: Device-Aware Processing Pipeline
|
||||
|
||||
Here's a complete example showing proper device management and all features:
|
||||
|
||||
Reference in New Issue
Block a user