feat (overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
Adil Zouitine
2025-07-07 12:01:34 +02:00
parent 1c56779dd9
commit 3b8a3a32a0
3 changed files with 882 additions and 5 deletions
+435
View File
@@ -746,6 +746,441 @@ processor = RobotProcessor([
]) ])
``` ```
## Loading Processors with Overrides: Handling Non-Serializable Objects
One of the most powerful features of RobotProcessor is the ability to override step configurations when loading from saved checkpoints. This is particularly useful for handling non-serializable objects like environment instances, database connections, or hardware interfaces that can't be saved to JSON.
### The Problem: Non-Serializable Parameters
Imagine you have a processor step that needs a live gym environment:
```python
from dataclasses import dataclass
import gym
@ProcessorStepRegistry.register("action_repeat_step")
@dataclass
class ActionRepeatStep:
"""Step that repeats actions using environment feedback."""
repeat_count: int = 3
env: gym.Env = None # This can't be serialized to JSON!
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
if self.env is not None and action is not None:
# Repeat action multiple times in environment
total_reward = 0
for _ in range(self.repeat_count):
_, r, d, t, _ = self.env.step(action)
total_reward += r
if d or t:
break
reward = total_reward
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> dict[str, Any]:
# Note: env is NOT included because it's not serializable
return {"repeat_count": self.repeat_count}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
```
If you try to save and load this processor normally, the `env` parameter will be lost because it can't be serialized to JSON.
### The Solution: Override Parameters
The `overrides` parameter in `from_pretrained()` allows you to provide non-serializable objects when loading:
```python
from lerobot.processor.pipeline import RobotProcessor
# Create processor with environment
env = gym.make("CartPole-v1")
action_step = ActionRepeatStep(repeat_count=2, env=env)
processor = RobotProcessor([action_step], name="CartPoleProcessor")
# Save the processor (env won't be saved)
processor.save_pretrained("./cartpole_processor")
# Later, load with environment override
new_env = gym.make("CartPole-v1")
loaded_processor = RobotProcessor.from_pretrained(
"./cartpole_processor",
overrides={
"action_repeat_step": {"env": new_env} # Provide the environment
}
)
```
### How Overrides Work
The `overrides` parameter is a dictionary where:
- **Keys** are step identifiers (class names for unregistered steps, registry names for registered steps)
- **Values** are dictionaries of parameter overrides that get merged with saved configurations
```python
overrides = {
"StepClassName": {"param1": "new_value", "param2": 42},
"registered_step_name": {"device": "cuda", "learning_rate": 0.01}
}
```
The loading process:
1. Load saved configuration from JSON
2. For each step, check if overrides exist for that step
3. Merge override parameters with saved parameters (overrides take precedence)
4. Instantiate the step with merged configuration
5. Load any saved tensor state
### Real-World Examples
#### Example 1: Environment-Dependent Steps
```python
@ProcessorStepRegistry.register("ik_solver_step")
@dataclass
class InverseKinematicsStep:
"""Convert Cartesian positions to joint angles."""
robot_model: str = "ur5"
solver_timeout: float = 0.1
kinematics_solver: Any = None # Non-serializable solver instance
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
if self.kinematics_solver is not None and action is not None:
# Convert Cartesian action to joint angles
joint_angles = self.kinematics_solver.solve(action, timeout=self.solver_timeout)
action = joint_angles
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> dict[str, Any]:
return {
"robot_model": self.robot_model,
"solver_timeout": self.solver_timeout
}
# Save processor without solver
processor = RobotProcessor([InverseKinematicsStep(robot_model="ur5")])
processor.save_pretrained("./robot_processor")
# Load with solver instance
from robotics_toolbox import URKinematics
solver = URKinematics("ur5")
loaded_processor = RobotProcessor.from_pretrained(
"./robot_processor",
overrides={
"ik_solver_step": {
"kinematics_solver": solver,
"solver_timeout": 0.05 # Also override timeout
}
}
)
```
#### Example 2: Device and Hardware Configuration
```python
@ProcessorStepRegistry.register("camera_capture_step")
@dataclass
class CameraCaptureStep:
"""Capture images from physical camera."""
camera_id: int = 0
resolution: tuple = (640, 480)
camera_interface: Any = None # Hardware interface
def get_config(self) -> dict[str, Any]:
return {
"camera_id": self.camera_id,
"resolution": self.resolution
}
# Deploy on different robots with different camera setups
# Robot A
camera_a = CameraInterface("/dev/video0")
processor_a = RobotProcessor.from_pretrained(
"shared/vision_processor",
overrides={
"camera_capture_step": {
"camera_interface": camera_a,
"resolution": (1920, 1080) # High-res camera
}
}
)
# Robot B
camera_b = CameraInterface("/dev/video1")
processor_b = RobotProcessor.from_pretrained(
"shared/vision_processor",
overrides={
"camera_capture_step": {
"camera_interface": camera_b,
"resolution": (640, 480) # Lower-res camera
}
}
)
```
#### Example 3: Multiple Environment Deployment
```python
# Training processor that works with simulation
@ProcessorStepRegistry.register("physics_validator")
@dataclass
class PhysicsValidatorStep:
"""Validate actions against physics constraints."""
max_force: float = 100.0
physics_engine: Any = None
def get_config(self) -> dict[str, Any]:
return {"max_force": self.max_force}
# Different physics engines for different environments
import pybullet as pb
import mujoco
# Simulation deployment
sim_engine = pb.connect(pb.DIRECT)
sim_processor = RobotProcessor.from_pretrained(
"shared/control_processor",
overrides={
"physics_validator": {
"physics_engine": sim_engine,
"max_force": 150.0 # Higher limits in sim
}
}
)
# Real robot deployment
real_engine = RealRobotInterface()
real_processor = RobotProcessor.from_pretrained(
"shared/control_processor",
overrides={
"physics_validator": {
"physics_engine": real_engine,
"max_force": 50.0 # Conservative limits on real robot
}
}
)
```
### Override Key Matching Rules
The override system uses exact string matching:
#### For Registered Steps
Use the registry name (the string passed to `@ProcessorStepRegistry.register()`):
```python
@ProcessorStepRegistry.register("my_custom_step")
class MyStep:
pass
# Use registry name in overrides
overrides = {"my_custom_step": {"param": "value"}}
```
#### For Unregistered Steps
Use the exact class name:
```python
class MyUnregisteredStep:
pass
# Use class name in overrides
overrides = {"MyUnregisteredStep": {"param": "value"}}
```
### Error Handling and Validation
The override system provides helpful error messages:
#### Invalid Override Keys
```python
# This will raise KeyError with helpful message
overrides = {"NonExistentStep": {"param": "value"}}
try:
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
except KeyError as e:
print(e)
# Output: Override keys ['NonExistentStep'] do not match any step in the saved configuration.
# Available step keys: ['ActualStepName', 'AnotherStepName']
```
#### Instantiation Errors
```python
# Invalid parameter types are caught
overrides = {"MyStep": {"numeric_param": "not_a_number"}}
try:
processor = RobotProcessor.from_pretrained("./processor", overrides=overrides)
except ValueError as e:
print(e)
# Output: Failed to instantiate processor step 'MyStep' with config: {...}
```
### Multiple Steps with Same Class Name
When you have multiple steps of the same class, all instances get the same override:
```python
step1 = MyStep(param=1)
step2 = MyStep(param=2)
processor = RobotProcessor([step1, step2])
# Both steps will get the override
overrides = {"MyStep": {"param": 999}}
loaded = RobotProcessor.from_pretrained("./processor", overrides=overrides)
# Both steps now have param=999
```
To override steps individually, give each step its own class (a thin subclass is enough) and register each class under a distinct name:
```python
class MyStep:
    pass

@ProcessorStepRegistry.register("step_1")
class MyStepOne(MyStep):
    pass

@ProcessorStepRegistry.register("step_2")
class MyStepTwo(MyStep):  # Same behavior, but a separate class and registry name
    pass
# Now you can override them separately
overrides = {
"step_1": {"param": 1},
"step_2": {"param": 2}
}
```
### Best Practices for Overrides
#### 1. Design Steps for Overrides
When creating steps that need non-serializable objects, design them with overrides in mind:
```python
@dataclass
class WellDesignedStep:
# Serializable configuration
timeout: float = 1.0
retry_count: int = 3
# Non-serializable objects with default None
database: Any = None
api_client: Any = None
def __post_init__(self):
# Validate that required non-serializable objects are provided
if self.database is None:
raise ValueError("database must be provided via overrides")
def get_config(self) -> dict[str, Any]:
# Only include serializable parameters
return {
"timeout": self.timeout,
"retry_count": self.retry_count
}
```
#### 2. Use Registry Names for Clarity
Register steps with descriptive names to make overrides clearer:
```python
@ProcessorStepRegistry.register("robot_arm_controller")
class ArmControlStep:
pass
@ProcessorStepRegistry.register("gripper_controller")
class GripperControlStep:
pass
# Clear override keys
overrides = {
"robot_arm_controller": {"joint_limits": arm_limits},
"gripper_controller": {"force_limit": gripper_force}
}
```
#### 3. Document Override Requirements
Include clear documentation about what overrides are needed:
```python
@ProcessorStepRegistry.register("vision_processor")
class VisionProcessor:
"""Process camera images for robot vision.
Required overrides when loading:
camera_interface: Hardware camera interface object
Optional overrides:
resolution: Camera resolution tuple (default: (640, 480))
fps: Camera frame rate (default: 30)
"""
camera_interface: Any = None
resolution: tuple = (640, 480)
fps: int = 30
```
#### 4. Environment-Specific Configuration Files
Create configuration helpers for different deployment environments:
```python
# config/simulation.py
def get_simulation_overrides():
return {
"camera_step": {"camera_interface": SimCamera()},
"physics_step": {"engine": SimPhysics()},
"control_step": {"safety_limits": False}
}
# config/production.py
def get_production_overrides():
return {
"camera_step": {"camera_interface": RealCamera("/dev/video0")},
"physics_step": {"engine": RealPhysics()},
"control_step": {"safety_limits": True}
}
# Usage
from config.production import get_production_overrides
processor = RobotProcessor.from_pretrained(
"shared/processor",
overrides=get_production_overrides()
)
```
### Integration with Hub Sharing
Overrides work seamlessly with Hugging Face Hub sharing:
```python
# Save processor without non-serializable objects
processor.push_to_hub("my-lab/robot-processor")
# Anyone can load and provide their own environment
import gym
local_env = gym.make("MyRobotEnv-v1")
processor = RobotProcessor.from_pretrained(
"my-lab/robot-processor",
overrides={"env_step": {"env": local_env}}
)
```
This enables sharing of preprocessing logic while allowing each user to provide their own environment-specific dependencies.
## Complete Example: Device-Aware Processing Pipeline ## Complete Example: Device-Aware Processing Pipeline
Here's a complete example showing proper device management and all features: Here's a complete example showing proper device management and all features:
+82 -3
View File
@@ -433,8 +433,53 @@ class RobotProcessor(ModelHubMixin):
return self return self
@classmethod @classmethod
def from_pretrained(cls, source: str) -> RobotProcessor: def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).""" """Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
source: Local path to a saved processor directory or Hugging Face Hub identifier
(e.g., "username/processor-name").
overrides: Optional dictionary mapping step names to configuration overrides.
Keys must match exact step class names (for unregistered steps) or registry names
(for registered steps). Values are dictionaries containing parameter overrides
that will be merged with the saved configuration. This is useful for providing
non-serializable objects like environment instances.
Returns:
A RobotProcessor instance loaded from the saved configuration.
Raises:
ImportError: If a processor step class cannot be loaded or imported.
ValueError: If a step cannot be instantiated with the provided configuration.
KeyError: If an override key doesn't match any step in the saved configuration.
Examples:
Basic loading:
```python
processor = RobotProcessor.from_pretrained("path/to/processor")
```
Loading with overrides for non-serializable objects:
```python
import gym
env = gym.make("CartPole-v1")
processor = RobotProcessor.from_pretrained(
"username/cartpole-processor",
overrides={"ActionRepeatStep": {"env": env}}
)
```
Multiple overrides:
```python
processor = RobotProcessor.from_pretrained(
"path/to/processor",
overrides={
"CustomStep": {"param1": "new_value"},
"device_processor": {"device": "cuda:1"} # For registered steps
}
)
```
"""
if Path(source).is_dir(): if Path(source).is_dir():
# Local path - use it directly # Local path - use it directly
base_path = Path(source) base_path = Path(source)
@@ -450,6 +495,13 @@ class RobotProcessor(ModelHubMixin):
# Store downloaded files in the same directory as the config # Store downloaded files in the same directory as the config
base_path = Path(config_path).parent base_path = Path(config_path).parent
# Handle None overrides
if overrides is None:
overrides = {}
# Validate that all override keys will be matched
override_keys = set(overrides.keys())
steps: list[ProcessorStep] = [] steps: list[ProcessorStep] = []
for step_entry in config["steps"]: for step_entry in config["steps"]:
# Check if step uses registry name or module path # Check if step uses registry name or module path
@@ -457,6 +509,7 @@ class RobotProcessor(ModelHubMixin):
# Load from registry # Load from registry
try: try:
step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
step_key = step_entry["registry_name"]
except KeyError as e: except KeyError as e:
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
else: else:
@@ -468,6 +521,7 @@ class RobotProcessor(ModelHubMixin):
try: try:
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
step_class = getattr(module, class_name) step_class = getattr(module, class_name)
step_key = class_name
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise ImportError( raise ImportError(
f"Failed to load processor step '{full_class_path}'. " f"Failed to load processor step '{full_class_path}'. "
@@ -478,7 +532,15 @@ class RobotProcessor(ModelHubMixin):
# Instantiate the step with its config # Instantiate the step with its config
try: try:
step_instance: ProcessorStep = step_class(**step_entry.get("config", {})) saved_cfg = step_entry.get("config", {})
step_overrides = overrides.get(step_key, {})
merged_cfg = {**saved_cfg, **step_overrides}
step_instance: ProcessorStep = step_class(**merged_cfg)
# Track which override keys were used
if step_key in override_keys:
override_keys.discard(step_key)
except Exception as e: except Exception as e:
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
raise ValueError( raise ValueError(
@@ -499,6 +561,23 @@ class RobotProcessor(ModelHubMixin):
steps.append(step_instance) steps.append(step_instance)
# Check for unused override keys
if override_keys:
available_keys = []
for step_entry in config["steps"]:
if "registry_name" in step_entry:
available_keys.append(step_entry["registry_name"])
else:
full_class_path = step_entry["class"]
class_name = full_class_path.rsplit(".", 1)[1]
available_keys.append(class_name)
raise KeyError(
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
f"Available step keys: {available_keys}. "
f"Make sure override keys match exact step class names or registry names."
)
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed")) return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
def __len__(self) -> int: def __len__(self) -> int:
+365 -2
View File
@@ -18,14 +18,14 @@ import json
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Dict
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from lerobot.processor.pipeline import EnvTransition, RobotProcessor from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
@dataclass @dataclass
@@ -839,6 +839,369 @@ def test_to_device_module_vs_non_module():
assert non_module_step.weights.device.type == "cpu" assert non_module_step.weights.device.type == "cpu"
# Tests for overrides functionality
class MockStepWithNonSerializableParam:
    """Mock step that requires a non-serializable parameter.

    This is deliberately a plain class rather than a ``@dataclass``: it defines
    its own ``__init__`` and declares no annotated fields, so the decorator
    would only contribute a field-less ``__eq__`` (making every instance
    compare equal), disable hashing, and generate an empty ``__repr__``.

    Args:
        name: Prefix used for the key written into the complementary data.
        multiplier: Factor applied to the reward; stored as a ``float``.
        env: Arbitrary object standing in for a live environment. It is never
            returned by ``get_config`` and therefore must be supplied again
            (e.g. through ``overrides``) when the step is re-created.

    Raises:
        ValueError: If ``multiplier`` is a string.
        TypeError: If ``multiplier`` is neither an ``int`` nor a ``float``.
    """

    def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
        self.name = name

        # Strings get their own ValueError so that a bad override value is
        # reported with the offending text; every other non-numeric type is a
        # TypeError.
        if isinstance(multiplier, str):
            raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
        if not isinstance(multiplier, (int, float)):
            raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")

        self.multiplier = float(multiplier)
        self.env = env  # Non-serializable parameter (like gym.Env)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Scale the reward and, when an env is set, record ``str(env)`` in comp_data."""
        obs, action, reward, done, truncated, info, comp_data = transition

        # Use the env parameter if provided; copy comp_data so the caller's
        # dict is not mutated.
        if self.env is not None:
            comp_data = {} if comp_data is None else dict(comp_data)
            comp_data[f"{self.name}_env_info"] = str(self.env)

        # Apply multiplier to reward
        if reward is not None:
            reward = reward * self.multiplier

        return (obs, action, reward, done, truncated, info, comp_data)

    def get_config(self) -> Dict[str, Any]:
        """Return the JSON-serializable configuration (``env`` is intentionally omitted)."""
        return {
            "name": self.name,
            "multiplier": self.multiplier,
        }

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: this step holds no tensor state."""
        return {}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Accept and ignore ``state``: there is nothing to restore."""
        pass

    def reset(self) -> None:
        """No-op: the step keeps no per-episode state."""
        pass
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
class RegisteredMockStep:
    """Registry-backed mock step that stamps its settings into the complementary data."""

    value: int = 42
    device: str = "cpu"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return the transition with ``value`` and ``device`` recorded in a copied comp_data."""
        obs, action, reward, done, truncated, info, comp_data = transition

        # Work on a copy (or a fresh dict) so the incoming mapping is left untouched.
        annotated = dict(comp_data) if comp_data is not None else {}
        annotated["registered_step_value"] = self.value
        annotated["registered_step_device"] = self.device

        return (obs, action, reward, done, truncated, info, annotated)

    def get_config(self) -> Dict[str, Any]:
        """Return both fields as the serializable configuration."""
        config: Dict[str, Any] = {}
        config["value"] = self.value
        config["device"] = self.device
        return config

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: there is no tensor state."""
        return {}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Ignore ``state``: there is nothing to load."""
        pass

    def reset(self) -> None:
        """No-op reset."""
        pass
class MockEnvironment:
    """Minimal stand-in for an environment object that cannot be written to JSON."""

    def __init__(self, name: str):
        # The name is the only state; it is echoed back by __str__.
        self.name = name

    def __str__(self):
        """Render as ``MockEnvironment(<name>)``."""
        label = self.name
        return f"MockEnvironment({label})"
def test_from_pretrained_with_overrides():
    """Overrides given at load time replace saved values and inject the env object."""
    # One unregistered step (keyed by class name) and one registered step
    # (keyed by registry name).
    source_pipeline = RobotProcessor(
        [
            MockStepWithNonSerializableParam(name="env_step", multiplier=2.0),
            RegisteredMockStep(value=100, device="cpu"),
        ],
        name="TestOverrides",
    )

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        # The env never reaches disk, so it has to come back in via overrides.
        injected_env = MockEnvironment("test_env")
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                "MockStepWithNonSerializableParam": {
                    "env": injected_env,
                    "multiplier": 3.0,  # replaces the saved 2.0
                },
                "registered_mock_step": {"device": "cuda", "value": 200},
            },
        )

        # Pipeline-level metadata survives the round trip.
        assert len(restored) == 2
        assert restored.name == "TestOverrides"

        outcome = restored((None, None, 1.0, False, False, {}, {}))

        # Both steps report the overridden values through comp_data.
        extras = outcome[6]
        assert "env_step_env_info" in extras
        assert extras["env_step_env_info"] == "MockEnvironment(test_env)"
        assert extras["registered_step_value"] == 200
        assert extras["registered_step_device"] == "cuda"

        # Reward: 1.0 * 3.0 (overridden multiplier)
        assert outcome[2] == 3.0
def test_from_pretrained_with_partial_overrides():
    """A single class-name override key is applied to every saved step of that class."""
    first = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
    second = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)
    source_pipeline = RobotProcessor([first, second])

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        # Only one key is supplied, but both steps share the class name, so
        # the current implementation overrides the multiplier on both.
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={"MockStepWithNonSerializableParam": {"multiplier": 5.0}},
        )

        outcome = restored((None, None, 1.0, False, False, {}, {}))

        # step1: 1.0 * 5.0 = 5.0, then step2: 5.0 * 5.0 = 25.0
        assert outcome[2] == 25.0
def test_from_pretrained_invalid_override_key():
    """An override key that matches no saved step is rejected with KeyError."""
    source_pipeline = RobotProcessor([MockStepWithNonSerializableParam()])

    # No step in the pipeline is called "NonExistentStep".
    unknown_overrides = {"NonExistentStep": {"param": "value"}}

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        with pytest.raises(KeyError, match="Override keys.*do not match any step"):
            RobotProcessor.from_pretrained(save_dir, overrides=unknown_overrides)
def test_from_pretrained_multiple_invalid_override_keys():
    """Every unmatched override key appears in the KeyError message."""
    source_pipeline = RobotProcessor([MockStepWithNonSerializableParam()])

    unknown_overrides = {
        "NonExistentStep1": {"param": "value1"},
        "NonExistentStep2": {"param": "value2"},
    }

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        with pytest.raises(KeyError) as caught:
            RobotProcessor.from_pretrained(save_dir, overrides=unknown_overrides)

        # Both bad keys and the list of valid keys must be reported.
        message = str(caught.value)
        for fragment in ("NonExistentStep1", "NonExistentStep2", "Available step keys"):
            assert fragment in message
def test_from_pretrained_registered_step_override():
    """A registered step is addressed in overrides by its registry name."""
    source_pipeline = RobotProcessor([RegisteredMockStep(value=50, device="cpu")])

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        # Key is the registry name, not the class name.
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={"registered_mock_step": {"value": 999, "device": "cuda"}},
        )

        # The step echoes its settings into comp_data, which exposes the overrides.
        outcome = restored((None, None, 0.0, False, False, {}, {}))
        extras = outcome[6]

        assert extras["registered_step_value"] == 999
        assert extras["registered_step_device"] == "cuda"
def test_from_pretrained_mixed_registered_and_unregistered():
    """Class-name keys and registry-name keys can be combined in one overrides dict."""
    source_pipeline = RobotProcessor(
        [
            MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0),
            RegisteredMockStep(value=10, device="cpu"),
        ]
    )

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        injected_env = MockEnvironment("mixed_test")
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                "MockStepWithNonSerializableParam": {"env": injected_env, "multiplier": 4.0},
                "registered_mock_step": {"value": 777},
            },
        )

        # Run one transition through both steps.
        outcome = restored((None, None, 2.0, False, False, {}, {}))
        extras = outcome[6]

        assert extras["unregistered_env_info"] == "MockEnvironment(mixed_test)"
        assert extras["registered_step_value"] == 777
        assert outcome[2] == 8.0  # 2.0 * 4.0
def test_from_pretrained_no_overrides():
    """Calling from_pretrained without the overrides argument still works (backward compatibility)."""
    source_pipeline = RobotProcessor(
        [MockStepWithNonSerializableParam(name="no_override", multiplier=3.0)]
    )

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        # No overrides keyword at all.
        restored = RobotProcessor.from_pretrained(save_dir)
        assert len(restored) == 1

        # env stays None, so only the saved multiplier affects the result.
        outcome = restored((None, None, 1.0, False, False, {}, {}))
        assert outcome[2] == 3.0  # 1.0 * 3.0
def test_from_pretrained_empty_overrides():
    """An explicitly empty overrides dict loads the processor unchanged."""
    source_pipeline = RobotProcessor([MockStepWithNonSerializableParam(multiplier=2.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        no_overrides = {}
        restored = RobotProcessor.from_pretrained(save_dir, overrides=no_overrides)
        assert len(restored) == 1

        # The saved multiplier is the one in effect.
        outcome = restored((None, None, 1.0, False, False, {}, {}))
        assert outcome[2] == 2.0
def test_from_pretrained_override_instantiation_error():
    """A step constructor failure caused by an override surfaces as ValueError."""
    source_pipeline = RobotProcessor([MockStepWithNonSerializableParam(multiplier=1.0)])

    # The mock step rejects string multipliers in its constructor.
    bad_overrides = {
        "MockStepWithNonSerializableParam": {
            "multiplier": "invalid_type",
        }
    }

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        with pytest.raises(ValueError, match="Failed to instantiate processor step"):
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)
def test_from_pretrained_with_state_and_overrides():
    """Config overrides and saved tensor state are both applied to the loaded step."""
    source_pipeline = RobotProcessor(
        [MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5)]
    )

    # Feed ten transitions so the step accumulates state before saving.
    for reward in map(float, range(10)):
        source_pipeline((None, None, reward, False, False, {}, {}))

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                "MockStepWithTensorState": {
                    "learning_rate": 0.05,
                    "window_size": 3,
                }
            },
        )
        restored_step = restored.steps[0]

        # Constructor arguments come from the overrides...
        assert restored_step.learning_rate == 0.05
        assert restored_step.window_size == 3

        # ...while the tensors come from the saved state.
        assert restored_step.running_count.item() == 10

        # running_mean keeps the length it was saved with (window_size=5),
        # even though the step was rebuilt with window_size=3.
        assert restored_step.running_mean.shape[0] == 5
def test_from_pretrained_override_error_messages():
    """The KeyError for a bad override key names the bad key and lists the valid ones."""
    source_pipeline = RobotProcessor(
        [MockStepWithNonSerializableParam(name="step1"), RegisteredMockStep()]
    )

    with tempfile.TemporaryDirectory() as save_dir:
        source_pipeline.save_pretrained(save_dir)

        with pytest.raises(KeyError) as caught:
            RobotProcessor.from_pretrained(
                save_dir, overrides={"WrongStepName": {"param": "value"}}
            )

        message = str(caught.value)

        # The offending key is reported...
        assert "WrongStepName" in message
        # ...along with the keys that would have been accepted: the class name
        # for the unregistered step and the registry name for the registered one.
        assert "Available step keys" in message
        assert "MockStepWithNonSerializableParam" in message
        assert "registered_mock_step" in message
class MockStepWithMixedState: class MockStepWithMixedState:
"""Mock step demonstrating proper separation of tensor and non-tensor state. """Mock step demonstrating proper separation of tensor and non-tensor state.