feat (overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
Adil Zouitine
2025-07-07 12:01:34 +02:00
parent 1c56779dd9
commit 3b8a3a32a0
3 changed files with 882 additions and 5 deletions
+365 -2
View File
@@ -18,14 +18,14 @@ import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, Dict
import numpy as np
import pytest
import torch
import torch.nn as nn
from lerobot.processor.pipeline import EnvTransition, RobotProcessor
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
@dataclass
@@ -839,6 +839,369 @@ def test_to_device_module_vs_non_module():
assert non_module_step.weights.device.type == "cpu"
# Tests for overrides functionality
@dataclass
class MockStepWithNonSerializableParam:
"""Mock step that requires a non-serializable parameter."""
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
self.name = name
# Add type validation for multiplier
if isinstance(multiplier, str):
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
if not isinstance(multiplier, (int, float)):
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
self.multiplier = float(multiplier)
self.env = env # Non-serializable parameter (like gym.Env)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
# Use the env parameter if provided
if self.env is not None:
comp_data = {} if comp_data is None else dict(comp_data)
comp_data[f"{self.name}_env_info"] = str(self.env)
# Apply multiplier to reward
if reward is not None:
reward = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> Dict[str, Any]:
# Note: env is intentionally NOT included here as it's not serializable
return {
"name": self.name,
"multiplier": self.multiplier,
}
def state_dict(self) -> Dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
class RegisteredMockStep:
"""Mock step registered in the registry."""
value: int = 42
device: str = "cpu"
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["registered_step_value"] = self.value
comp_data["registered_step_device"] = self.device
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> Dict[str, Any]:
return {
"value": self.value,
"device": self.device,
}
def state_dict(self) -> Dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
class MockEnvironment:
"""Mock environment for testing non-serializable parameters."""
def __init__(self, name: str):
self.name = name
def __str__(self):
return f"MockEnvironment({self.name})"
def test_from_pretrained_with_overrides():
"""Test loading processor with parameter overrides."""
# Create a processor with steps that need overrides
env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0)
registered_step = RegisteredMockStep(value=100, device="cpu")
pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save the pipeline
pipeline.save_pretrained(tmp_dir)
# Create a mock environment for override
mock_env = MockEnvironment("test_env")
# Load with overrides
overrides = {
"MockStepWithNonSerializableParam": {
"env": mock_env,
"multiplier": 3.0, # Override the multiplier too
},
"registered_mock_step": {"device": "cuda", "value": 200},
}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Verify the pipeline was loaded correctly
assert len(loaded_pipeline) == 2
assert loaded_pipeline.name == "TestOverrides"
# Test the loaded steps
transition = (None, None, 1.0, False, False, {}, {})
result = loaded_pipeline(transition)
# Check that overrides were applied
comp_data = result[6]
assert "env_step_env_info" in comp_data
assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)"
assert comp_data["registered_step_value"] == 200
assert comp_data["registered_step_device"] == "cuda"
# Check that multiplier override was applied
assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier)
def test_from_pretrained_with_partial_overrides():
"""Test loading processor with overrides for only some steps."""
step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)
pipeline = RobotProcessor([step1, step2])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Override only one step
overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}}
# The current implementation applies overrides to ALL steps with the same class name
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
transition = (None, None, 1.0, False, False, {}, {})
result = loaded_pipeline(transition)
# The reward should be affected by both steps, both getting the override
# First step: 1.0 * 5.0 = 5.0 (overridden)
# Second step: 5.0 * 5.0 = 25.0 (also overridden)
assert result[2] == 25.0
def test_from_pretrained_invalid_override_key():
"""Test that invalid override keys raise KeyError."""
step = MockStepWithNonSerializableParam()
pipeline = RobotProcessor([step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Try to override a non-existent step
overrides = {"NonExistentStep": {"param": "value"}}
with pytest.raises(KeyError, match="Override keys.*do not match any step"):
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
def test_from_pretrained_multiple_invalid_override_keys():
"""Test that multiple invalid override keys are reported."""
step = MockStepWithNonSerializableParam()
pipeline = RobotProcessor([step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Try to override multiple non-existent steps
overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}}
with pytest.raises(KeyError) as exc_info:
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
error_msg = str(exc_info.value)
assert "NonExistentStep1" in error_msg
assert "NonExistentStep2" in error_msg
assert "Available step keys" in error_msg
def test_from_pretrained_registered_step_override():
"""Test overriding registered steps using registry names."""
registered_step = RegisteredMockStep(value=50, device="cpu")
pipeline = RobotProcessor([registered_step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Override using registry name
overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test that overrides were applied
transition = (None, None, 0.0, False, False, {}, {})
result = loaded_pipeline(transition)
comp_data = result[6]
assert comp_data["registered_step_value"] == 999
assert comp_data["registered_step_device"] == "cuda"
def test_from_pretrained_mixed_registered_and_unregistered():
"""Test overriding both registered and unregistered steps."""
unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0)
registered_step = RegisteredMockStep(value=10, device="cpu")
pipeline = RobotProcessor([unregistered_step, registered_step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
mock_env = MockEnvironment("mixed_test")
overrides = {
"MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0},
"registered_mock_step": {"value": 777},
}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test both steps
transition = (None, None, 2.0, False, False, {}, {})
result = loaded_pipeline(transition)
comp_data = result[6]
assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)"
assert comp_data["registered_step_value"] == 777
assert result[2] == 8.0 # 2.0 * 4.0
def test_from_pretrained_no_overrides():
"""Test that from_pretrained works without overrides (backward compatibility)."""
step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0)
pipeline = RobotProcessor([step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Load without overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
# Test that the step works (env will be None)
transition = (None, None, 1.0, False, False, {}, {})
result = loaded_pipeline(transition)
assert result[2] == 3.0 # 1.0 * 3.0
def test_from_pretrained_empty_overrides():
"""Test that from_pretrained works with empty overrides dict."""
step = MockStepWithNonSerializableParam(multiplier=2.0)
pipeline = RobotProcessor([step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Load with empty overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={})
assert len(loaded_pipeline) == 1
# Test that the step works normally
transition = (None, None, 1.0, False, False, {}, {})
result = loaded_pipeline(transition)
assert result[2] == 2.0
def test_from_pretrained_override_instantiation_error():
"""Test that instantiation errors with overrides are properly reported."""
step = MockStepWithNonSerializableParam(multiplier=1.0)
pipeline = RobotProcessor([step])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Try to override with invalid parameter type
overrides = {
"MockStepWithNonSerializableParam": {
"multiplier": "invalid_type" # Should be float, not string
}
}
with pytest.raises(ValueError, match="Failed to instantiate processor step"):
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
def test_from_pretrained_with_state_and_overrides():
"""Test that overrides work correctly with steps that have tensor state."""
step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5)
pipeline = RobotProcessor([step])
# Process some data to create state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
pipeline(transition)
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Load with overrides
overrides = {
"MockStepWithTensorState": {
"learning_rate": 0.05, # Override learning rate
"window_size": 3, # Override window size
}
}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_step = loaded_pipeline.steps[0]
# Check that config overrides were applied
assert loaded_step.learning_rate == 0.05
assert loaded_step.window_size == 3
# Check that tensor state was preserved
assert loaded_step.running_count.item() == 10
# The running_mean should still have the original window_size (5) from saved state
# but the new step will use window_size=3 for future operations
assert loaded_step.running_mean.shape[0] == 5 # From saved state
def test_from_pretrained_override_error_messages():
"""Test that error messages for override failures are helpful."""
step1 = MockStepWithNonSerializableParam(name="step1")
step2 = RegisteredMockStep()
pipeline = RobotProcessor([step1, step2])
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Test with invalid override key
overrides = {"WrongStepName": {"param": "value"}}
with pytest.raises(KeyError) as exc_info:
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
error_msg = str(exc_info.value)
assert "WrongStepName" in error_msg
assert "Available step keys" in error_msg
assert "MockStepWithNonSerializableParam" in error_msg
assert "registered_mock_step" in error_msg
class MockStepWithMixedState:
"""Mock step demonstrating proper separation of tensor and non-tensor state.