From 3b8a3a32a093c1ceb752c57f577dca9715f0562a Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 7 Jul 2025 12:01:34 +0200 Subject: [PATCH] feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. --- docs/source/processor_tutorial.mdx | 435 +++++++++++++++++++++++++++++ src/lerobot/processor/pipeline.py | 85 +++++- tests/processor/test_pipeline.py | 367 +++++++++++++++++++++++- 3 files changed, 882 insertions(+), 5 deletions(-) diff --git a/docs/source/processor_tutorial.mdx b/docs/source/processor_tutorial.mdx index 3f59666cd..5e55e0b45 100644 --- a/docs/source/processor_tutorial.mdx +++ b/docs/source/processor_tutorial.mdx @@ -746,6 +746,441 @@ processor = RobotProcessor([ ]) ``` +## Loading Processors with Overrides: Handling Non-Serializable Objects + +One of the most powerful features of RobotProcessor is the ability to override step configurations when loading from saved checkpoints. This is particularly useful for handling non-serializable objects like environment instances, database connections, or hardware interfaces that can't be saved to JSON. + +### The Problem: Non-Serializable Parameters + +Imagine you have a processor step that needs a live gym environment: + +```python +from dataclasses import dataclass +import gym + +@ProcessorStepRegistry.register("action_repeat_step") +@dataclass +class ActionRepeatStep: + """Step that repeats actions using environment feedback.""" + + repeat_count: int = 3 + env: gym.Env = None # This can't be serialized to JSON! + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs, action, reward, done, truncated, info, comp_data = transition + + if self.env is not None and action is not None: + # Repeat action multiple times in environment + total_reward = 0 + for _ in range(self.repeat_count): + _, r, d, t, _ = self.env.step(action) + total_reward += r + if d or t: + break + reward = total_reward + + return (obs, action, reward, done, truncated, info, comp_data) + + def get_config(self) -> dict[str, Any]: + # Note: env is NOT included because it's not serializable + return {"repeat_count": self.repeat_count} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass +``` + +If you try to save and load this processor normally, the `env` parameter will be lost because it can't be serialized to JSON. + +### The Solution: Override Parameters + +The `overrides` parameter in `from_pretrained()` allows you to provide non-serializable objects when loading: + +```python +from lerobot.processor.pipeline import RobotProcessor + +# Create processor with environment +env = gym.make("CartPole-v1") +action_step = ActionRepeatStep(repeat_count=2, env=env) +processor = RobotProcessor([action_step], name="CartPoleProcessor") + +# Save the processor (env won't be saved) +processor.save_pretrained("./cartpole_processor") + +# Later, load with environment override +new_env = gym.make("CartPole-v1") +loaded_processor = RobotProcessor.from_pretrained( + "./cartpole_processor", + overrides={ + "action_repeat_step": {"env": new_env} # Provide the environment + } +) +``` + +### How Overrides Work + +The `overrides` parameter is a dictionary where: +- **Keys** are step identifiers (class names for unregistered steps, registry names for registered steps) +- **Values** are dictionaries of parameter overrides that get merged with saved configurations + +```python +overrides = { + "StepClassName": {"param1": "new_value", "param2": 42}, + "registered_step_name": {"device": "cuda", "learning_rate": 0.01} +} +``` + +The loading process: +1. Load saved configuration from JSON +2. For each step, check if overrides exist for that step +3. Merge override parameters with saved parameters (overrides take precedence) +4. Instantiate the step with merged configuration +5. Load any saved tensor state + +### Real-World Examples + +#### Example 1: Environment-Dependent Steps + +```python +@ProcessorStepRegistry.register("ik_solver_step") +@dataclass +class InverseKinematicsStep: + """Convert Cartesian positions to joint angles.""" + + robot_model: str = "ur5" + solver_timeout: float = 0.1 + kinematics_solver: Any = None # Non-serializable solver instance + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs, action, reward, done, truncated, info, comp_data = transition + + if self.kinematics_solver is not None and action is not None: + # Convert Cartesian action to joint angles + joint_angles = self.kinematics_solver.solve(action, timeout=self.solver_timeout) + action = joint_angles + + return (obs, action, reward, done, truncated, info, comp_data) + + def get_config(self) -> dict[str, Any]: + return { + "robot_model": self.robot_model, + "solver_timeout": self.solver_timeout + } + +# Save processor without solver +processor = RobotProcessor([InverseKinematicsStep(robot_model="ur5")]) +processor.save_pretrained("./robot_processor") + +# Load with solver instance +from robotics_toolbox import URKinematics +solver = URKinematics("ur5") + +loaded_processor = RobotProcessor.from_pretrained( + "./robot_processor", + overrides={ + "ik_solver_step": { + "kinematics_solver": solver, + "solver_timeout": 0.05 # Also override timeout + } + } +) +``` + +#### Example 2: Device and Hardware Configuration + +```python +@ProcessorStepRegistry.register("camera_capture_step") +@dataclass +class CameraCaptureStep: + """Capture images from physical camera.""" + + camera_id: int = 0 + resolution: tuple = (640, 480) + camera_interface: Any = None # Hardware interface + + def get_config(self) -> dict[str, Any]: + return { + "camera_id": self.camera_id, + "resolution": self.resolution + } + +# Deploy on different robots with different camera setups +# Robot A +camera_a = CameraInterface("/dev/video0") +processor_a = RobotProcessor.from_pretrained( + "shared/vision_processor", + overrides={ + "camera_capture_step": { + "camera_interface": camera_a, + "resolution": (1920, 1080) # High-res camera + } + } +) + +# Robot B +camera_b = CameraInterface("/dev/video1") +processor_b = RobotProcessor.from_pretrained( + "shared/vision_processor", + overrides={ + "camera_capture_step": { + "camera_interface": camera_b, + "resolution": (640, 480) # Lower-res camera + } + } +) +``` + +#### Example 3: Multiple Environment Deployment + +```python +# Training processor that works with simulation +@ProcessorStepRegistry.register("physics_validator") +@dataclass +class PhysicsValidatorStep: + """Validate actions against physics constraints.""" + + max_force: float = 100.0 + physics_engine: Any = None + + def get_config(self) -> dict[str, Any]: + return {"max_force": self.max_force} + +# Different physics engines for different environments +import pybullet as pb +import mujoco + +# Simulation deployment +sim_engine = pb.connect(pb.DIRECT) +sim_processor = RobotProcessor.from_pretrained( + "shared/control_processor", + overrides={ + "physics_validator": { + "physics_engine": sim_engine, + "max_force": 150.0 # Higher limits in sim + } + } +) + +# Real robot deployment +real_engine = RealRobotInterface() +real_processor = RobotProcessor.from_pretrained( + "shared/control_processor", + overrides={ + "physics_validator": { + "physics_engine": real_engine, + "max_force": 50.0 # Conservative limits on real robot + } + } +) +``` + +### Override Key Matching Rules + +The override system uses exact string matching: + +#### For Registered Steps +Use the registry name (the string passed to `@ProcessorStepRegistry.register()`): + +```python +@ProcessorStepRegistry.register("my_custom_step") +class MyStep: + pass + +# Use registry name in overrides +overrides = {"my_custom_step": {"param": "value"}} +``` + +#### For Unregistered Steps +Use the exact class name: + +```python +class MyUnregisteredStep: + pass + +# Use class name in overrides +overrides = {"MyUnregisteredStep": {"param": "value"}} +``` + +### Error Handling and Validation + +The override system provides helpful error messages: + +#### Invalid Override Keys +```python +# This will raise KeyError with helpful message +overrides = {"NonExistentStep": {"param": "value"}} + +try: + processor = RobotProcessor.from_pretrained("./processor", overrides=overrides) +except KeyError as e: + print(e) + # Output: Override keys ['NonExistentStep'] do not match any step in the saved configuration. + # Available step keys: ['ActualStepName', 'AnotherStepName'] +``` + +#### Instantiation Errors +```python +# Invalid parameter types are caught +overrides = {"MyStep": {"numeric_param": "not_a_number"}} + +try: + processor = RobotProcessor.from_pretrained("./processor", overrides=overrides) +except ValueError as e: + print(e) + # Output: Failed to instantiate processor step 'MyStep' with config: {...} +``` + +### Multiple Steps with Same Class Name + +When you have multiple steps of the same class, all instances get the same override: + +```python +step1 = MyStep(param=1) +step2 = MyStep(param=2) +processor = RobotProcessor([step1, step2]) + +# Both steps will get the override +overrides = {"MyStep": {"param": 999}} +loaded = RobotProcessor.from_pretrained("./processor", overrides=overrides) +# Both steps now have param=999 +``` + +To override steps individually, use different classes or register with different names: + +```python +@ProcessorStepRegistry.register("step_1") +class MyStep: + pass + +@ProcessorStepRegistry.register("step_2") +class MyStep: # Same class, different registry names + pass + +# Now you can override them separately +overrides = { + "step_1": {"param": 1}, + "step_2": {"param": 2} +} +``` + +### Best Practices for Overrides + +#### 1. Design Steps for Overrides +When creating steps that need non-serializable objects, design them with overrides in mind: + +```python +@dataclass +class WellDesignedStep: + # Serializable configuration + timeout: float = 1.0 + retry_count: int = 3 + + # Non-serializable objects with default None + database: Any = None + api_client: Any = None + + def __post_init__(self): + # Validate that required non-serializable objects are provided + if self.database is None: + raise ValueError("database must be provided via overrides") + + def get_config(self) -> dict[str, Any]: + # Only include serializable parameters + return { + "timeout": self.timeout, + "retry_count": self.retry_count + } +``` + +#### 2. Use Registry Names for Clarity +Register steps with descriptive names to make overrides clearer: + +```python +@ProcessorStepRegistry.register("robot_arm_controller") +class ArmControlStep: + pass + +@ProcessorStepRegistry.register("gripper_controller") +class GripperControlStep: + pass + +# Clear override keys +overrides = { + "robot_arm_controller": {"joint_limits": arm_limits}, + "gripper_controller": {"force_limit": gripper_force} +} +``` + +#### 3. Document Override Requirements +Include clear documentation about what overrides are needed: + +```python +@ProcessorStepRegistry.register("vision_processor") +class VisionProcessor: + """Process camera images for robot vision. + + Required overrides when loading: + camera_interface: Hardware camera interface object + + Optional overrides: + resolution: Camera resolution tuple (default: (640, 480)) + fps: Camera frame rate (default: 30) + """ + camera_interface: Any = None + resolution: tuple = (640, 480) + fps: int = 30 +``` + +#### 4. Environment-Specific Configuration Files +Create configuration helpers for different deployment environments: + +```python +# config/simulation.py +def get_simulation_overrides(): + return { + "camera_step": {"camera_interface": SimCamera()}, + "physics_step": {"engine": SimPhysics()}, + "control_step": {"safety_limits": False} + } + +# config/production.py +def get_production_overrides(): + return { + "camera_step": {"camera_interface": RealCamera("/dev/video0")}, + "physics_step": {"engine": RealPhysics()}, + "control_step": {"safety_limits": True} + } + +# Usage +from config.production import get_production_overrides +processor = RobotProcessor.from_pretrained( + "shared/processor", + overrides=get_production_overrides() +) +``` + +### Integration with Hub Sharing + +Overrides work seamlessly with Hugging Face Hub sharing: + +```python +# Save processor without non-serializable objects +processor.push_to_hub("my-lab/robot-processor") + +# Anyone can load and provide their own environment +import gym +local_env = gym.make("MyRobotEnv-v1") + +processor = RobotProcessor.from_pretrained( + "my-lab/robot-processor", + overrides={"env_step": {"env": local_env}} +) +``` + +This enables sharing of preprocessing logic while allowing each user to provide their own environment-specific dependencies. + ## Complete Example: Device-Aware Processing Pipeline Here's a complete example showing proper device management and all features: diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 5e5f4c177..66deea8a9 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -433,8 +433,53 @@ class RobotProcessor(ModelHubMixin): return self @classmethod - def from_pretrained(cls, source: str) -> RobotProcessor: - """Load a serialized processor from source (local path or Hugging Face Hub identifier).""" + def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor: + """Load a serialized processor from source (local path or Hugging Face Hub identifier). + + Args: + source: Local path to a saved processor directory or Hugging Face Hub identifier + (e.g., "username/processor-name"). + overrides: Optional dictionary mapping step names to configuration overrides. + Keys must match exact step class names (for unregistered steps) or registry names + (for registered steps). Values are dictionaries containing parameter overrides + that will be merged with the saved configuration. This is useful for providing + non-serializable objects like environment instances. + + Returns: + A RobotProcessor instance loaded from the saved configuration. + + Raises: + ImportError: If a processor step class cannot be loaded or imported. + ValueError: If a step cannot be instantiated with the provided configuration. + KeyError: If an override key doesn't match any step in the saved configuration. + + Examples: + Basic loading: + ```python + processor = RobotProcessor.from_pretrained("path/to/processor") + ``` + + Loading with overrides for non-serializable objects: + ```python + import gym + env = gym.make("CartPole-v1") + processor = RobotProcessor.from_pretrained( + "username/cartpole-processor", + overrides={"ActionRepeatStep": {"env": env}} + ) + ``` + + Multiple overrides: + ```python + processor = RobotProcessor.from_pretrained( + "path/to/processor", + overrides={ + "CustomStep": {"param1": "new_value"}, + "device_processor": {"device": "cuda:1"} # For registered steps + } + ) + ``` + """ if Path(source).is_dir(): # Local path - use it directly base_path = Path(source) @@ -450,6 +495,13 @@ class RobotProcessor(ModelHubMixin): # Store downloaded files in the same directory as the config base_path = Path(config_path).parent + # Handle None overrides + if overrides is None: + overrides = {} + + # Validate that all override keys will be matched + override_keys = set(overrides.keys()) + steps: list[ProcessorStep] = [] for step_entry in config["steps"]: # Check if step uses registry name or module path @@ -457,6 +509,7 @@ class RobotProcessor(ModelHubMixin): # Load from registry try: step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + step_key = step_entry["registry_name"] except KeyError as e: raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e else: @@ -468,6 +521,7 @@ class RobotProcessor(ModelHubMixin): try: module = importlib.import_module(module_path) step_class = getattr(module, class_name) + step_key = class_name except (ImportError, AttributeError) as e: raise ImportError( f"Failed to load processor step '{full_class_path}'. " @@ -478,7 +532,15 @@ class RobotProcessor(ModelHubMixin): # Instantiate the step with its config try: - step_instance: ProcessorStep = step_class(**step_entry.get("config", {})) + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + step_instance: ProcessorStep = step_class(**merged_cfg) + + # Track which override keys were used + if step_key in override_keys: + override_keys.discard(step_key) + except Exception as e: step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) raise ValueError( @@ -499,6 +561,23 @@ class RobotProcessor(ModelHubMixin): steps.append(step_instance) + # Check for unused override keys + if override_keys: + available_keys = [] + for step_entry in config["steps"]: + if "registry_name" in step_entry: + available_keys.append(step_entry["registry_name"]) + else: + full_class_path = step_entry["class"] + class_name = full_class_path.rsplit(".", 1)[1] + available_keys.append(class_name) + + raise KeyError( + f"Override keys {list(override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." + ) + return cls(steps, config.get("name", "RobotProcessor"), config.get("seed")) def __len__(self) -> int: diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index d452623ad..280913e49 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -18,14 +18,14 @@ import json import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Dict import numpy as np import pytest import torch import torch.nn as nn -from lerobot.processor.pipeline import EnvTransition, RobotProcessor +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor @dataclass @@ -839,6 +839,369 @@ def test_to_device_module_vs_non_module(): assert non_module_step.weights.device.type == "cpu" +# Tests for overrides functionality +@dataclass +class MockStepWithNonSerializableParam: + """Mock step that requires a non-serializable parameter.""" + + def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): + self.name = name + # Add type validation for multiplier + if isinstance(multiplier, str): + raise ValueError(f"multiplier must be a number, got string '{multiplier}'") + if not isinstance(multiplier, (int, float)): + raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") + self.multiplier = float(multiplier) + self.env = env # Non-serializable parameter (like gym.Env) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs, action, reward, done, truncated, info, comp_data = transition + + # Use the env parameter if provided + if self.env is not None: + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_env_info"] = str(self.env) + + # Apply multiplier to reward + if reward is not None: + reward = reward * self.multiplier + + return (obs, action, reward, done, truncated, info, comp_data) + + def get_config(self) -> Dict[str, Any]: + # Note: env is intentionally NOT included here as it's not serializable + return { + "name": self.name, + "multiplier": self.multiplier, + } + + def state_dict(self) -> Dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + +@ProcessorStepRegistry.register("registered_mock_step") +@dataclass +class RegisteredMockStep: + """Mock step registered in the registry.""" + + value: int = 42 + device: str = "cpu" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs, action, reward, done, truncated, info, comp_data = transition + + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["registered_step_value"] = self.value + comp_data["registered_step_device"] = self.device + + return (obs, action, reward, done, truncated, info, comp_data) + + def get_config(self) -> Dict[str, Any]: + return { + "value": self.value, + "device": self.device, + } + + def state_dict(self) -> Dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + +class MockEnvironment: + """Mock environment for testing non-serializable parameters.""" + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return f"MockEnvironment({self.name})" + + +def test_from_pretrained_with_overrides(): + """Test loading processor with parameter overrides.""" + # Create a processor with steps that need overrides + env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) + registered_step = RegisteredMockStep(value=100, device="cpu") + + pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save the pipeline + pipeline.save_pretrained(tmp_dir) + + # Create a mock environment for override + mock_env = MockEnvironment("test_env") + + # Load with overrides + overrides = { + "MockStepWithNonSerializableParam": { + "env": mock_env, + "multiplier": 3.0, # Override the multiplier too + }, + "registered_mock_step": {"device": "cuda", "value": 200}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Verify the pipeline was loaded correctly + assert len(loaded_pipeline) == 2 + assert loaded_pipeline.name == "TestOverrides" + + # Test the loaded steps + transition = (None, None, 1.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + # Check that overrides were applied + comp_data = result[6] + assert "env_step_env_info" in comp_data + assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" + assert comp_data["registered_step_value"] == 200 + assert comp_data["registered_step_device"] == "cuda" + + # Check that multiplier override was applied + assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier) + + +def test_from_pretrained_with_partial_overrides(): + """Test loading processor with overrides for only some steps.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one step + overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}} + + # The current implementation applies overrides to ALL steps with the same class name + # Both steps will get the override + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + transition = (None, None, 1.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + # The reward should be affected by both steps, both getting the override + # First step: 1.0 * 5.0 = 5.0 (overridden) + # Second step: 5.0 * 5.0 = 25.0 (also overridden) + assert result[2] == 25.0 + + +def test_from_pretrained_invalid_override_key(): + """Test that invalid override keys raise KeyError.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override a non-existent step + overrides = {"NonExistentStep": {"param": "value"}} + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_multiple_invalid_override_keys(): + """Test that multiple invalid override keys are reported.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override multiple non-existent steps + overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "NonExistentStep1" in error_msg + assert "NonExistentStep2" in error_msg + assert "Available step keys" in error_msg + + +def test_from_pretrained_registered_step_override(): + """Test overriding registered steps using registry names.""" + registered_step = RegisteredMockStep(value=50, device="cpu") + pipeline = RobotProcessor([registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override using registry name + overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test that overrides were applied + transition = (None, None, 0.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + comp_data = result[6] + assert comp_data["registered_step_value"] == 999 + assert comp_data["registered_step_device"] == "cuda" + + +def test_from_pretrained_mixed_registered_and_unregistered(): + """Test overriding both registered and unregistered steps.""" + unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) + registered_step = RegisteredMockStep(value=10, device="cpu") + + pipeline = RobotProcessor([unregistered_step, registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + mock_env = MockEnvironment("mixed_test") + + overrides = { + "MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0}, + "registered_mock_step": {"value": 777}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test both steps + transition = (None, None, 2.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + comp_data = result[6] + assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" + assert comp_data["registered_step_value"] == 777 + assert result[2] == 8.0 # 2.0 * 4.0 + + +def test_from_pretrained_no_overrides(): + """Test that from_pretrained works without overrides (backward compatibility).""" + step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert len(loaded_pipeline) == 1 + + # Test that the step works (env will be None) + transition = (None, None, 1.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + assert result[2] == 3.0 # 1.0 * 3.0 + + +def test_from_pretrained_empty_overrides(): + """Test that from_pretrained works with empty overrides dict.""" + step = MockStepWithNonSerializableParam(multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with empty overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + + assert len(loaded_pipeline) == 1 + + # Test that the step works normally + transition = (None, None, 1.0, False, False, {}, {}) + result = loaded_pipeline(transition) + + assert result[2] == 2.0 + + +def test_from_pretrained_override_instantiation_error(): + """Test that instantiation errors with overrides are properly reported.""" + step = MockStepWithNonSerializableParam(multiplier=1.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with invalid parameter type + overrides = { + "MockStepWithNonSerializableParam": { + "multiplier": "invalid_type" # Should be float, not string + } + } + + with pytest.raises(ValueError, match="Failed to instantiate processor step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_with_state_and_overrides(): + """Test that overrides work correctly with steps that have tensor state.""" + step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) + pipeline = RobotProcessor([step]) + + # Process some data to create state + for i in range(10): + transition = (None, None, float(i), False, False, {}, {}) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with overrides + overrides = { + "MockStepWithTensorState": { + "learning_rate": 0.05, # Override learning rate + "window_size": 3, # Override window size + } + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_step = loaded_pipeline.steps[0] + + # Check that config overrides were applied + assert loaded_step.learning_rate == 0.05 + assert loaded_step.window_size == 3 + + # Check that tensor state was preserved + assert loaded_step.running_count.item() == 10 + + # The running_mean should still have the original window_size (5) from saved state + # but the new step will use window_size=3 for future operations + assert loaded_step.running_mean.shape[0] == 5 # From saved state + + +def test_from_pretrained_override_error_messages(): + """Test that error messages for override failures are helpful.""" + step1 = MockStepWithNonSerializableParam(name="step1") + step2 = RegisteredMockStep() + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Test with invalid override key + overrides = {"WrongStepName": {"param": "value"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "WrongStepName" in error_msg + assert "Available step keys" in error_msg + assert "MockStepWithNonSerializableParam" in error_msg + assert "registered_mock_step" in error_msg + + class MockStepWithMixedState: """Mock step demonstrating proper separation of tensor and non-tensor state.