mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
@@ -746,6 +746,441 @@ processor = RobotProcessor([
|
|||||||
])
|
])
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Loading Processors with Overrides: Handling Non-Serializable Objects
|
||||||
|
|
||||||
|
One of the most powerful features of RobotProcessor is the ability to override step configurations when loading from saved checkpoints. This is particularly useful for handling non-serializable objects like environment instances, database connections, or hardware interfaces that can't be saved to JSON.
|
||||||
|
|
||||||
|
### The Problem: Non-Serializable Parameters
|
||||||
|
|
||||||
|
Imagine you have a processor step that needs a live gym environment:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import gym
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("action_repeat_step")
|
||||||
|
@dataclass
|
||||||
|
class ActionRepeatStep:
|
||||||
|
"""Step that repeats actions using environment feedback."""
|
||||||
|
|
||||||
|
repeat_count: int = 3
|
||||||
|
env: gym.Env = None # This can't be serialized to JSON!
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
obs, action, reward, done, truncated, info, comp_data = transition
|
||||||
|
|
||||||
|
if self.env is not None and action is not None:
|
||||||
|
# Repeat action multiple times in environment
|
||||||
|
total_reward = 0
|
||||||
|
for _ in range(self.repeat_count):
|
||||||
|
_, r, d, t, _ = self.env.step(action)
|
||||||
|
total_reward += r
|
||||||
|
if d or t:
|
||||||
|
break
|
||||||
|
reward = total_reward
|
||||||
|
|
||||||
|
return (obs, action, reward, done, truncated, info, comp_data)
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
# Note: env is NOT included because it's not serializable
|
||||||
|
return {"repeat_count": self.repeat_count}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
If you save and load this processor normally, the loaded step's `env` falls back to its default of `None`: `get_config()` deliberately leaves `env` out because it can't be serialized to JSON.
|
||||||
|
|
||||||
|
### The Solution: Override Parameters
|
||||||
|
|
||||||
|
The `overrides` parameter in `from_pretrained()` allows you to provide non-serializable objects when loading:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.processor.pipeline import RobotProcessor
|
||||||
|
|
||||||
|
# Create processor with environment
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
action_step = ActionRepeatStep(repeat_count=2, env=env)
|
||||||
|
processor = RobotProcessor([action_step], name="CartPoleProcessor")
|
||||||
|
|
||||||
|
# Save the processor (env won't be saved)
|
||||||
|
processor.save_pretrained("./cartpole_processor")
|
||||||
|
|
||||||
|
# Later, load with environment override
|
||||||
|
new_env = gym.make("CartPole-v1")
|
||||||
|
loaded_processor = RobotProcessor.from_pretrained(
|
||||||
|
"./cartpole_processor",
|
||||||
|
overrides={
|
||||||
|
"action_repeat_step": {"env": new_env} # Provide the environment
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### How Overrides Work
|
||||||
|
|
||||||
|
The `overrides` parameter is a dictionary where:
|
||||||
|
- **Keys** are step identifiers (class names for unregistered steps, registry names for registered steps)
|
||||||
|
- **Values** are dictionaries of parameter overrides that get merged with saved configurations
|
||||||
|
|
||||||
|
```python
|
||||||
|
overrides = {
|
||||||
|
"StepClassName": {"param1": "new_value", "param2": 42},
|
||||||
|
"registered_step_name": {"device": "cuda", "learning_rate": 0.01}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The loading process:
|
||||||
|
1. Load saved configuration from JSON
|
||||||
|
2. For each step, check if overrides exist for that step
|
||||||
|
3. Merge override parameters with saved parameters (overrides take precedence)
|
||||||
|
4. Instantiate the step with merged configuration
|
||||||
|
5. Load any saved tensor state
|
||||||
|
|
||||||
|
### Real-World Examples
|
||||||
|
|
||||||
|
#### Example 1: Environment-Dependent Steps
|
||||||
|
|
||||||
|
```python
|
||||||
|
@ProcessorStepRegistry.register("ik_solver_step")
|
||||||
|
@dataclass
|
||||||
|
class InverseKinematicsStep:
|
||||||
|
"""Convert Cartesian positions to joint angles."""
|
||||||
|
|
||||||
|
robot_model: str = "ur5"
|
||||||
|
solver_timeout: float = 0.1
|
||||||
|
kinematics_solver: Any = None # Non-serializable solver instance
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
obs, action, reward, done, truncated, info, comp_data = transition
|
||||||
|
|
||||||
|
if self.kinematics_solver is not None and action is not None:
|
||||||
|
# Convert Cartesian action to joint angles
|
||||||
|
joint_angles = self.kinematics_solver.solve(action, timeout=self.solver_timeout)
|
||||||
|
action = joint_angles
|
||||||
|
|
||||||
|
return (obs, action, reward, done, truncated, info, comp_data)
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"robot_model": self.robot_model,
|
||||||
|
"solver_timeout": self.solver_timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save processor without solver
|
||||||
|
processor = RobotProcessor([InverseKinematicsStep(robot_model="ur5")])
|
||||||
|
processor.save_pretrained("./robot_processor")
|
||||||
|
|
||||||
|
# Load with solver instance
|
||||||
|
from robotics_toolbox import URKinematics
|
||||||
|
solver = URKinematics("ur5")
|
||||||
|
|
||||||
|
loaded_processor = RobotProcessor.from_pretrained(
|
||||||
|
"./robot_processor",
|
||||||
|
overrides={
|
||||||
|
"ik_solver_step": {
|
||||||
|
"kinematics_solver": solver,
|
||||||
|
"solver_timeout": 0.05 # Also override timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example 2: Device and Hardware Configuration
|
||||||
|
|
||||||
|
```python
|
||||||
|
@ProcessorStepRegistry.register("camera_capture_step")
|
||||||
|
@dataclass
|
||||||
|
class CameraCaptureStep:
|
||||||
|
"""Capture images from physical camera."""
|
||||||
|
|
||||||
|
camera_id: int = 0
|
||||||
|
resolution: tuple = (640, 480)
|
||||||
|
camera_interface: Any = None # Hardware interface
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"resolution": self.resolution
|
||||||
|
}
|
||||||
|
|
||||||
|
# Deploy on different robots with different camera setups
|
||||||
|
# Robot A
|
||||||
|
camera_a = CameraInterface("/dev/video0")
|
||||||
|
processor_a = RobotProcessor.from_pretrained(
|
||||||
|
"shared/vision_processor",
|
||||||
|
overrides={
|
||||||
|
"camera_capture_step": {
|
||||||
|
"camera_interface": camera_a,
|
||||||
|
"resolution": (1920, 1080) # High-res camera
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Robot B
|
||||||
|
camera_b = CameraInterface("/dev/video1")
|
||||||
|
processor_b = RobotProcessor.from_pretrained(
|
||||||
|
"shared/vision_processor",
|
||||||
|
overrides={
|
||||||
|
"camera_capture_step": {
|
||||||
|
"camera_interface": camera_b,
|
||||||
|
"resolution": (640, 480) # Lower-res camera
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example 3: Multiple Environment Deployment
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Training processor that works with simulation
|
||||||
|
@ProcessorStepRegistry.register("physics_validator")
|
||||||
|
@dataclass
|
||||||
|
class PhysicsValidatorStep:
|
||||||
|
"""Validate actions against physics constraints."""
|
||||||
|
|
||||||
|
max_force: float = 100.0
|
||||||
|
physics_engine: Any = None
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {"max_force": self.max_force}
|
||||||
|
|
||||||
|
# Different physics engines for different environments
|
||||||
|
import pybullet as pb
|
||||||
|
import mujoco
|
||||||
|
|
||||||
|
# Simulation deployment
|
||||||
|
sim_engine = pb.connect(pb.DIRECT)
|
||||||
|
sim_processor = RobotProcessor.from_pretrained(
|
||||||
|
"shared/control_processor",
|
||||||
|
overrides={
|
||||||
|
"physics_validator": {
|
||||||
|
"physics_engine": sim_engine,
|
||||||
|
"max_force": 150.0 # Higher limits in sim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Real robot deployment
|
||||||
|
real_engine = RealRobotInterface()
|
||||||
|
real_processor = RobotProcessor.from_pretrained(
|
||||||
|
"shared/control_processor",
|
||||||
|
overrides={
|
||||||
|
"physics_validator": {
|
||||||
|
"physics_engine": real_engine,
|
||||||
|
"max_force": 50.0 # Conservative limits on real robot
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Override Key Matching Rules
|
||||||
|
|
||||||
|
The override system uses exact string matching:
|
||||||
|
|
||||||
|
#### For Registered Steps
|
||||||
|
Use the registry name (the string passed to `@ProcessorStepRegistry.register()`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@ProcessorStepRegistry.register("my_custom_step")
|
||||||
|
class MyStep:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Use registry name in overrides
|
||||||
|
overrides = {"my_custom_step": {"param": "value"}}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For Unregistered Steps
|
||||||
|
Use the exact class name:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyUnregisteredStep:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Use class name in overrides
|
||||||
|
overrides = {"MyUnregisteredStep": {"param": "value"}}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error Handling and Validation
|
||||||
|
|
||||||
|
The override system provides helpful error messages:
|
||||||
|
|
||||||
|
#### Invalid Override Keys
|
||||||
|
```python
|
||||||
|
# This will raise KeyError with helpful message
|
||||||
|
overrides = {"NonExistentStep": {"param": "value"}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||||
|
except KeyError as e:
|
||||||
|
print(e)
|
||||||
|
# Output: Override keys ['NonExistentStep'] do not match any step in the saved configuration.
|
||||||
|
# Available step keys: ['ActualStepName', 'AnotherStepName']
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Instantiation Errors
|
||||||
|
```python
|
||||||
|
# Invalid parameter types are caught
|
||||||
|
overrides = {"MyStep": {"numeric_param": "not_a_number"}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
# Output: Failed to instantiate processor step 'MyStep' with config: {...}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Steps with Same Class Name
|
||||||
|
|
||||||
|
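Override keys identify a step by its class name or registry name, not by its position in the pipeline, so when you have multiple steps of the same class, all instances get the same override: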
|
||||||
|
|
||||||
|
```python
|
||||||
|
step1 = MyStep(param=1)
|
||||||
|
step2 = MyStep(param=2)
|
||||||
|
processor = RobotProcessor([step1, step2])
|
||||||
|
|
||||||
|
# Both steps will get the override
|
||||||
|
overrides = {"MyStep": {"param": 999}}
|
||||||
|
loaded = RobotProcessor.from_pretrained("./processor", overrides=overrides)
|
||||||
|
# Both steps now have param=999
|
||||||
|
```
|
||||||
|
|
||||||
|
To override steps individually, use different classes or register with different names:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyStep:
    pass

# Give each variant its own class so that each one carries its own registry name.
# (Writing `class MyStep` twice would not register one class under two names;
# the second statement simply rebinds `MyStep` to a new, unrelated class.)
@ProcessorStepRegistry.register("step_1")
class MyStepA(MyStep):
    pass

@ProcessorStepRegistry.register("step_2")
class MyStepB(MyStep):
    pass

# Now you can override them separately
overrides = {
    "step_1": {"param": 1},
    "step_2": {"param": 2}
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Best Practices for Overrides
|
||||||
|
|
||||||
|
#### 1. Design Steps for Overrides
|
||||||
|
When creating steps that need non-serializable objects, design them with overrides in mind:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class WellDesignedStep:
|
||||||
|
# Serializable configuration
|
||||||
|
timeout: float = 1.0
|
||||||
|
retry_count: int = 3
|
||||||
|
|
||||||
|
# Non-serializable objects with default None
|
||||||
|
database: Any = None
|
||||||
|
api_client: Any = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Validate that required non-serializable objects are provided
|
||||||
|
if self.database is None:
|
||||||
|
raise ValueError("database must be provided via overrides")
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
# Only include serializable parameters
|
||||||
|
return {
|
||||||
|
"timeout": self.timeout,
|
||||||
|
"retry_count": self.retry_count
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Use Registry Names for Clarity
|
||||||
|
Register steps with descriptive names to make overrides clearer:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@ProcessorStepRegistry.register("robot_arm_controller")
|
||||||
|
class ArmControlStep:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("gripper_controller")
|
||||||
|
class GripperControlStep:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Clear override keys
|
||||||
|
overrides = {
|
||||||
|
"robot_arm_controller": {"joint_limits": arm_limits},
|
||||||
|
"gripper_controller": {"force_limit": gripper_force}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Document Override Requirements
|
||||||
|
Include clear documentation about what overrides are needed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@ProcessorStepRegistry.register("vision_processor")
|
||||||
|
class VisionProcessor:
|
||||||
|
"""Process camera images for robot vision.
|
||||||
|
|
||||||
|
Required overrides when loading:
|
||||||
|
camera_interface: Hardware camera interface object
|
||||||
|
|
||||||
|
Optional overrides:
|
||||||
|
resolution: Camera resolution tuple (default: (640, 480))
|
||||||
|
fps: Camera frame rate (default: 30)
|
||||||
|
"""
|
||||||
|
camera_interface: Any = None
|
||||||
|
resolution: tuple = (640, 480)
|
||||||
|
fps: int = 30
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Environment-Specific Configuration Files
|
||||||
|
Create configuration helpers for different deployment environments:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# config/simulation.py
|
||||||
|
def get_simulation_overrides():
|
||||||
|
return {
|
||||||
|
"camera_step": {"camera_interface": SimCamera()},
|
||||||
|
"physics_step": {"engine": SimPhysics()},
|
||||||
|
"control_step": {"safety_limits": False}
|
||||||
|
}
|
||||||
|
|
||||||
|
# config/production.py
|
||||||
|
def get_production_overrides():
|
||||||
|
return {
|
||||||
|
"camera_step": {"camera_interface": RealCamera("/dev/video0")},
|
||||||
|
"physics_step": {"engine": RealPhysics()},
|
||||||
|
"control_step": {"safety_limits": True}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
from config.production import get_production_overrides
|
||||||
|
processor = RobotProcessor.from_pretrained(
|
||||||
|
"shared/processor",
|
||||||
|
overrides=get_production_overrides()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration with Hub Sharing
|
||||||
|
|
||||||
|
Overrides work seamlessly with Hugging Face Hub sharing:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Save processor without non-serializable objects
|
||||||
|
processor.push_to_hub("my-lab/robot-processor")
|
||||||
|
|
||||||
|
# Anyone can load and provide their own environment
|
||||||
|
import gym
|
||||||
|
local_env = gym.make("MyRobotEnv-v1")
|
||||||
|
|
||||||
|
processor = RobotProcessor.from_pretrained(
|
||||||
|
"my-lab/robot-processor",
|
||||||
|
overrides={"env_step": {"env": local_env}}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This enables sharing of preprocessing logic while allowing each user to provide their own environment-specific dependencies.
|
||||||
|
|
||||||
## Complete Example: Device-Aware Processing Pipeline
|
## Complete Example: Device-Aware Processing Pipeline
|
||||||
|
|
||||||
Here's a complete example showing proper device management and all features:
|
Here's a complete example showing proper device management and all features:
|
||||||
|
|||||||
@@ -433,8 +433,53 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, source: str) -> RobotProcessor:
|
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
|
||||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier)."""
|
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||||
|
(e.g., "username/processor-name").
|
||||||
|
overrides: Optional dictionary mapping step names to configuration overrides.
|
||||||
|
Keys must match exact step class names (for unregistered steps) or registry names
|
||||||
|
(for registered steps). Values are dictionaries containing parameter overrides
|
||||||
|
that will be merged with the saved configuration. This is useful for providing
|
||||||
|
non-serializable objects like environment instances.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A RobotProcessor instance loaded from the saved configuration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If a processor step class cannot be loaded or imported.
|
||||||
|
ValueError: If a step cannot be instantiated with the provided configuration.
|
||||||
|
KeyError: If an override key doesn't match any step in the saved configuration.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic loading:
|
||||||
|
```python
|
||||||
|
processor = RobotProcessor.from_pretrained("path/to/processor")
|
||||||
|
```
|
||||||
|
|
||||||
|
Loading with overrides for non-serializable objects:
|
||||||
|
```python
|
||||||
|
import gym
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
processor = RobotProcessor.from_pretrained(
|
||||||
|
"username/cartpole-processor",
|
||||||
|
overrides={"ActionRepeatStep": {"env": env}}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Multiple overrides:
|
||||||
|
```python
|
||||||
|
processor = RobotProcessor.from_pretrained(
|
||||||
|
"path/to/processor",
|
||||||
|
overrides={
|
||||||
|
"CustomStep": {"param1": "new_value"},
|
||||||
|
"device_processor": {"device": "cuda:1"} # For registered steps
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
if Path(source).is_dir():
|
if Path(source).is_dir():
|
||||||
# Local path - use it directly
|
# Local path - use it directly
|
||||||
base_path = Path(source)
|
base_path = Path(source)
|
||||||
@@ -450,6 +495,13 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
# Store downloaded files in the same directory as the config
|
# Store downloaded files in the same directory as the config
|
||||||
base_path = Path(config_path).parent
|
base_path = Path(config_path).parent
|
||||||
|
|
||||||
|
# Handle None overrides
|
||||||
|
if overrides is None:
|
||||||
|
overrides = {}
|
||||||
|
|
||||||
|
# Validate that all override keys will be matched
|
||||||
|
override_keys = set(overrides.keys())
|
||||||
|
|
||||||
steps: list[ProcessorStep] = []
|
steps: list[ProcessorStep] = []
|
||||||
for step_entry in config["steps"]:
|
for step_entry in config["steps"]:
|
||||||
# Check if step uses registry name or module path
|
# Check if step uses registry name or module path
|
||||||
@@ -457,6 +509,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
# Load from registry
|
# Load from registry
|
||||||
try:
|
try:
|
||||||
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
||||||
|
step_key = step_entry["registry_name"]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
||||||
else:
|
else:
|
||||||
@@ -468,6 +521,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
try:
|
try:
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
step_class = getattr(module, class_name)
|
step_class = getattr(module, class_name)
|
||||||
|
step_key = class_name
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Failed to load processor step '{full_class_path}'. "
|
f"Failed to load processor step '{full_class_path}'. "
|
||||||
@@ -478,7 +532,15 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
# Instantiate the step with its config
|
# Instantiate the step with its config
|
||||||
try:
|
try:
|
||||||
step_instance: ProcessorStep = step_class(**step_entry.get("config", {}))
|
saved_cfg = step_entry.get("config", {})
|
||||||
|
step_overrides = overrides.get(step_key, {})
|
||||||
|
merged_cfg = {**saved_cfg, **step_overrides}
|
||||||
|
step_instance: ProcessorStep = step_class(**merged_cfg)
|
||||||
|
|
||||||
|
# Track which override keys were used
|
||||||
|
if step_key in override_keys:
|
||||||
|
override_keys.discard(step_key)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -499,6 +561,23 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
steps.append(step_instance)
|
steps.append(step_instance)
|
||||||
|
|
||||||
|
# Check for unused override keys
|
||||||
|
if override_keys:
|
||||||
|
available_keys = []
|
||||||
|
for step_entry in config["steps"]:
|
||||||
|
if "registry_name" in step_entry:
|
||||||
|
available_keys.append(step_entry["registry_name"])
|
||||||
|
else:
|
||||||
|
full_class_path = step_entry["class"]
|
||||||
|
class_name = full_class_path.rsplit(".", 1)[1]
|
||||||
|
available_keys.append(class_name)
|
||||||
|
|
||||||
|
raise KeyError(
|
||||||
|
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
|
||||||
|
f"Available step keys: {available_keys}. "
|
||||||
|
f"Make sure override keys match exact step class names or registry names."
|
||||||
|
)
|
||||||
|
|
||||||
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|||||||
@@ -18,14 +18,14 @@ import json
|
|||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -839,6 +839,369 @@ def test_to_device_module_vs_non_module():
|
|||||||
assert non_module_step.weights.device.type == "cpu"
|
assert non_module_step.weights.device.type == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for overrides functionality
|
||||||
|
class MockStepWithNonSerializableParam:
    """Mock step that requires a non-serializable parameter.

    This is deliberately a plain class rather than a ``@dataclass``: it defines
    its own ``__init__`` and declares no fields, so the decorator would only add
    a field-less ``__eq__`` (every instance would compare equal to every other)
    and set ``__hash__`` to ``None``.

    Args:
        name: Prefix of the complementary-data key written by ``__call__``.
        multiplier: Factor applied to the transition reward; stored as ``float``.
        env: Arbitrary object standing in for something like a ``gym.Env``.
            It is left out of ``get_config`` on purpose.

    Raises:
        ValueError: If ``multiplier`` is a string.
        TypeError: If ``multiplier`` is any other non-numeric type.
    """

    def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
        self.name = name
        # Strings get their own error type so a bad override value is reported
        # as a ValueError rather than a TypeError.
        if isinstance(multiplier, str):
            raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
        if not isinstance(multiplier, (int, float)):
            raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
        self.multiplier = float(multiplier)
        self.env = env  # Non-serializable parameter (like gym.Env)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Scale the reward and, when an env is set, record ``str(env)`` in complementary data."""
        obs, action, reward, done, truncated, info, comp_data = transition

        # Only touch complementary data when an env was supplied; copy it first
        # so the caller's mapping is not mutated.
        if self.env is not None:
            comp_data = {} if comp_data is None else dict(comp_data)
            comp_data[f"{self.name}_env_info"] = str(self.env)

        # A missing reward is passed through untouched.
        if reward is not None:
            reward = reward * self.multiplier

        return (obs, action, reward, done, truncated, info, comp_data)

    def get_config(self) -> Dict[str, Any]:
        """Return the JSON-serializable settings; ``env`` is intentionally excluded."""
        return {
            "name": self.name,
            "multiplier": self.multiplier,
        }

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: this step holds no tensor state."""
        return {}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Accept and ignore ``state``; there is nothing to restore."""
        pass

    def reset(self) -> None:
        """No-op: the step keeps no per-episode state."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
class RegisteredMockStep:
    """Registry-backed mock step that stamps its settings onto each transition."""

    value: int = 42
    device: str = "cpu"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return the transition with ``value`` and ``device`` recorded in complementary data."""
        observation, act, rew, terminated, truncated, info, extras = transition

        # Work on a copy so the incoming complementary data is left untouched.
        if extras is None:
            stamped = {}
        else:
            stamped = dict(extras)
        stamped["registered_step_value"] = self.value
        stamped["registered_step_device"] = self.device

        return (observation, act, rew, terminated, truncated, info, stamped)

    def get_config(self) -> Dict[str, Any]:
        """Return both fields as the serializable configuration."""
        return {"value": self.value, "device": self.device}

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: the step has no tensor state."""
        return {}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; nothing needs restoring."""
        pass

    def reset(self) -> None:
        """No-op."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockEnvironment:
    """Minimal stand-in for an environment object used as a non-serializable parameter."""

    def __init__(self, name: str):
        self.name = name

    def __str__(self):
        # The string form is what the tests compare against.
        return "MockEnvironment({})".format(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_with_overrides():
    """Round-trip a two-step processor and check that load-time overrides are applied."""
    original = RobotProcessor(
        [
            MockStepWithNonSerializableParam(name="env_step", multiplier=2.0),
            RegisteredMockStep(value=100, device="cpu"),
        ],
        name="TestOverrides",
    )

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # The env is not part of the saved config, so it is supplied at load time,
        # alongside new values for parameters that were saved.
        replacement_env = MockEnvironment("test_env")
        reloaded = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                "MockStepWithNonSerializableParam": {
                    "env": replacement_env,
                    "multiplier": 3.0,
                },
                "registered_mock_step": {"device": "cuda", "value": 200},
            },
        )

        # Structure survived the round trip.
        assert len(reloaded) == 2
        assert reloaded.name == "TestOverrides"

        outcome = reloaded((None, None, 1.0, False, False, {}, {}))
        extras = outcome[6]

        # Both steps report the overridden values.
        assert "env_step_env_info" in extras
        assert extras["env_step_env_info"] == "MockEnvironment(test_env)"
        assert extras["registered_step_value"] == 200
        assert extras["registered_step_device"] == "cuda"

        # Reward reflects the overridden multiplier: 1.0 * 3.0.
        assert outcome[2] == 3.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_with_partial_overrides():
    """Test that a single class-name override key reaches every step of that class.

    The override is "partial" only in that it sets ``multiplier`` and nothing
    else. It cannot single out one of the two steps: both are instances of the
    same unregistered class and are therefore addressed by the same key.
    """
    step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
    step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)

    pipeline = RobotProcessor([step1, step2])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # One key, keyed by class name: it applies to step1 and step2 alike.
        overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}}

        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)

        transition = (None, None, 1.0, False, False, {}, {})
        result = loaded_pipeline(transition)

        # Both saved multipliers (1.0 and 2.0) are replaced by 5.0:
        #   first step:  1.0 * 5.0 = 5.0
        #   second step: 5.0 * 5.0 = 25.0
        assert result[2] == 25.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_invalid_override_key():
    """An override key that names no saved step must be rejected with KeyError."""
    processor = RobotProcessor([MockStepWithNonSerializableParam()])

    with tempfile.TemporaryDirectory() as save_dir:
        processor.save_pretrained(save_dir)

        # "NonExistentStep" matches neither a class name nor a registry name.
        unknown_overrides = {"NonExistentStep": {"param": "value"}}

        with pytest.raises(KeyError, match="Override keys.*do not match any step"):
            RobotProcessor.from_pretrained(save_dir, overrides=unknown_overrides)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_multiple_invalid_override_keys():
    """Every unknown override key should be named in the raised ``KeyError``."""
    pipeline = RobotProcessor([MockStepWithNonSerializableParam()])

    # Two keys, neither of which corresponds to a saved step.
    bad_overrides = {
        "NonExistentStep1": {"param": "value1"},
        "NonExistentStep2": {"param": "value2"},
    }

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        with pytest.raises(KeyError) as raised:
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)

        message = str(raised.value)
        # Both offending keys and the list of valid keys must be reported.
        for fragment in ("NonExistentStep1", "NonExistentStep2", "Available step keys"):
            assert fragment in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_registered_step_override():
    """A registered step can be overridden through its registry name."""
    pipeline = RobotProcessor([RegisteredMockStep(value=50, device="cpu")])

    # The key is the registry name, not the class name.
    overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)
        restored = RobotProcessor.from_pretrained(save_dir, overrides=overrides)

        result = restored((None, None, 0.0, False, False, {}, {}))

        # The step reports its configuration in the last transition slot.
        reported = result[6]
        assert reported["registered_step_value"] == 999
        assert reported["registered_step_device"] == "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_mixed_registered_and_unregistered():
    """Overrides keyed by class name and by registry name can be combined."""
    pipeline = RobotProcessor(
        [
            MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0),
            RegisteredMockStep(value=10, device="cpu"),
        ]
    )

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        env = MockEnvironment("mixed_test")
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                # Unregistered step: addressed by its class name.
                "MockStepWithNonSerializableParam": {"env": env, "multiplier": 4.0},
                # Registered step: addressed by its registry name.
                "registered_mock_step": {"value": 777},
            },
        )

        result = restored((None, None, 2.0, False, False, {}, {}))

        # Both steps must reflect their respective overrides.
        reported = result[6]
        assert reported["unregistered_env_info"] == "MockEnvironment(mixed_test)"
        assert reported["registered_step_value"] == 777
        assert result[2] == 8.0  # 2.0 * 4.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_no_overrides():
    """Loading without the ``overrides`` argument keeps working (backward compatibility)."""
    original = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0)
    pipeline = RobotProcessor([original])

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        # No overrides argument at all.
        restored = RobotProcessor.from_pretrained(save_dir)
        assert len(restored) == 1

        # The step still runs (env will be None) and uses the saved multiplier.
        result = restored((None, None, 1.0, False, False, {}, {}))
        assert result[2] == 3.0  # 1.0 * 3.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_empty_overrides():
    """Passing an empty ``overrides`` dict loads the pipeline with saved values."""
    pipeline = RobotProcessor([MockStepWithNonSerializableParam(multiplier=2.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        # Explicitly empty overrides mapping.
        restored = RobotProcessor.from_pretrained(save_dir, overrides={})
        assert len(restored) == 1

        # The step behaves normally with its saved multiplier.
        result = restored((None, None, 1.0, False, False, {}, {}))
        assert result[2] == 2.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_override_instantiation_error():
    """A step that fails to instantiate with its overrides surfaces a ``ValueError``."""
    pipeline = RobotProcessor([MockStepWithNonSerializableParam(multiplier=1.0)])

    # ``multiplier`` should be a float; a string is supplied on purpose.
    invalid_params = {"multiplier": "invalid_type"}
    overrides = {"MockStepWithNonSerializableParam": invalid_params}

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        with pytest.raises(ValueError, match="Failed to instantiate processor step"):
            RobotProcessor.from_pretrained(save_dir, overrides=overrides)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_with_state_and_overrides():
    """Config overrides and saved tensor state coexist after reloading a step."""
    pipeline = RobotProcessor(
        [MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5)]
    )

    # Run ten transitions through the pipeline so the step accumulates state.
    for reward in map(float, range(10)):
        pipeline((None, None, reward, False, False, {}, {}))

    # New configuration values to apply on load.
    overrides = {
        "MockStepWithTensorState": {
            "learning_rate": 0.05,
            "window_size": 3,
        }
    }

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        restored = RobotProcessor.from_pretrained(save_dir, overrides=overrides)
        restored_step = restored.steps[0]

        # Configuration comes from the overrides...
        assert restored_step.learning_rate == 0.05
        assert restored_step.window_size == 3

        # ...while the tensor state comes from the checkpoint.
        assert restored_step.running_count.item() == 10

        # running_mean keeps the window_size (5) it was saved with, even though
        # the reloaded step will use window_size=3 for future operations.
        assert restored_step.running_mean.shape[0] == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_override_error_messages():
    """The ``KeyError`` for a bad override key lists the bad key and the valid ones."""
    pipeline = RobotProcessor(
        [MockStepWithNonSerializableParam(name="step1"), RegisteredMockStep()]
    )

    # Key that matches neither step of the saved pipeline.
    bad_overrides = {"WrongStepName": {"param": "value"}}

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        with pytest.raises(KeyError) as raised:
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)

        message = str(raised.value)
        expected_fragments = (
            "WrongStepName",  # the offending key
            "Available step keys",  # header of the hint
            "MockStepWithNonSerializableParam",  # unregistered step: class name
            "registered_mock_step",  # registered step: registry name
        )
        for fragment in expected_fragments:
            assert fragment in message
|
||||||
|
|
||||||
|
|
||||||
class MockStepWithMixedState:
|
class MockStepWithMixedState:
|
||||||
"""Mock step demonstrating proper separation of tensor and non-tensor state.
|
"""Mock step demonstrating proper separation of tensor and non-tensor state.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user