mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
@@ -18,14 +18,14 @@ import json
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -839,6 +839,369 @@ def test_to_device_module_vs_non_module():
|
||||
assert non_module_step.weights.device.type == "cpu"
|
||||
|
||||
|
||||
# Tests for overrides functionality
|
||||
@dataclass
class MockStepWithNonSerializableParam:
    """Mock step that requires a non-serializable parameter.

    Calling the step multiplies the transition's reward by ``multiplier`` and,
    when an ``env`` object was supplied, stores ``str(env)`` in the
    complementary data under the key ``"<name>_env_info"``.

    ``env`` plays the role of an object (like ``gym.Env``) that cannot be
    written to a saved config, so it is intentionally absent from
    :meth:`get_config` and must be passed in again when the step is rebuilt.

    Args:
        name: Prefix used for the complementary-data key.
        multiplier: Numeric factor applied to the reward; stored as ``float``.
        env: Optional non-serializable object.

    Raises:
        ValueError: If ``multiplier`` is a string.
        TypeError: If ``multiplier`` is any other non-numeric value.
    """

    # These are declared as real dataclass fields (instead of a hand-written
    # ``__init__`` under ``@dataclass``) so that the generated ``__eq__`` and
    # ``__repr__`` actually reflect the step's parameters. With no fields, every
    # instance compared equal to every other one.
    name: str = "mock_env_step"
    multiplier: float = 1.0
    env: Any = None  # Non-serializable parameter (like gym.Env)

    def __post_init__(self) -> None:
        """Validate ``multiplier`` and normalise it to ``float``."""
        # Strings get a dedicated ValueError; everything else non-numeric is a TypeError.
        if isinstance(self.multiplier, str):
            raise ValueError(f"multiplier must be a number, got string '{self.multiplier}'")
        if not isinstance(self.multiplier, (int, float)):
            raise TypeError(f"multiplier must be a number, got {type(self.multiplier).__name__}")
        self.multiplier = float(self.multiplier)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Scale the reward and, if an env is set, annotate the complementary data."""
        obs, action, reward, done, truncated, info, comp_data = transition

        # Use the env parameter if provided; copy first so the input dict is not mutated.
        if self.env is not None:
            comp_data = {} if comp_data is None else dict(comp_data)
            comp_data[f"{self.name}_env_info"] = str(self.env)

        # Apply multiplier to reward (a missing reward is passed through as None).
        if reward is not None:
            reward = reward * self.multiplier

        return (obs, action, reward, done, truncated, info, comp_data)

    def get_config(self) -> Dict[str, Any]:
        """Return the serializable parameters of the step."""
        # Note: env is intentionally NOT included here as it's not serializable
        return {
            "name": self.name,
            "multiplier": self.multiplier,
        }

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: this step holds no tensor state."""
        return {}

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Accept and ignore ``state``: there is nothing to restore."""
        pass

    def reset(self) -> None:
        """No-op: the step keeps no per-episode state."""
        pass
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
class RegisteredMockStep:
    """Mock step that is registered under the name ``"registered_mock_step"``.

    Calling the step copies ``value`` and ``device`` into the transition's
    complementary data; every other element is returned as received.
    """

    value: int = 42
    device: str = "cpu"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return ``transition`` with this step's settings added to its complementary data."""
        observation, act, rew, is_done, is_truncated, extra_info, complementary = transition

        # Build a fresh dict so the incoming complementary data is left untouched.
        annotated = dict(complementary) if complementary is not None else {}
        annotated["registered_step_value"] = self.value
        annotated["registered_step_device"] = self.device

        return (observation, act, rew, is_done, is_truncated, extra_info, annotated)

    def get_config(self) -> Dict[str, Any]:
        """Return both fields as a plain dictionary."""
        config: Dict[str, Any] = {}
        config["value"] = self.value
        config["device"] = self.device
        return config

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return an empty mapping: the step has no tensor state."""
        return dict()

    def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; there is nothing to load."""
        return None

    def reset(self) -> None:
        """Do nothing; the step is stateless."""
        return None
|
||||
|
||||
|
||||
class MockEnvironment:
    """Minimal stand-in for an environment object passed as a non-serializable parameter.

    It only carries a ``name`` and renders as ``MockEnvironment(<name>)``.
    """

    def __init__(self, name: str) -> None:
        """Remember ``name`` for later display."""
        self.name = name

    def __str__(self) -> str:
        """Return the display form that embeds the environment's name."""
        return f"MockEnvironment({self.name})"
|
||||
|
||||
|
||||
def test_from_pretrained_with_overrides():
    """A saved pipeline reloads with the given overrides applied to both of its steps."""
    # One unregistered step (no env yet) followed by one registered step.
    steps = [
        MockStepWithNonSerializableParam(name="env_step", multiplier=2.0),
        RegisteredMockStep(value=100, device="cpu"),
    ]
    original = RobotProcessor(steps, name="TestOverrides")

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # The env is not part of the saved config, so it is supplied at load time.
        env = MockEnvironment("test_env")

        # The unregistered step is keyed by class name, the registered one by registry name.
        unregistered_kwargs = {
            "env": env,
            "multiplier": 3.0,  # replaces the saved multiplier of 2.0
        }
        registered_kwargs = {"device": "cuda", "value": 200}
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={
                "MockStepWithNonSerializableParam": unregistered_kwargs,
                "registered_mock_step": registered_kwargs,
            },
        )

        # Structure of the pipeline survives the round trip.
        assert len(restored) == 2
        assert restored.name == "TestOverrides"

        # Run one transition through the restored steps.
        output = restored((None, None, 1.0, False, False, {}, {}))
        extras = output[6]

        # Both steps must have seen their overridden parameters.
        assert "env_step_env_info" in extras
        assert extras["env_step_env_info"] == "MockEnvironment(test_env)"
        assert extras["registered_step_value"] == 200
        assert extras["registered_step_device"] == "cuda"

        # Reward reflects the overridden multiplier: 1.0 * 3.0.
        assert output[2] == 3.0
|
||||
|
||||
|
||||
def test_from_pretrained_with_partial_overrides():
    """An override that sets only some parameters is applied to every step of that class."""
    multipliers = {"step1": 1.0, "step2": 2.0}
    original = RobotProcessor(
        [MockStepWithNonSerializableParam(name=label, multiplier=factor) for label, factor in multipliers.items()]
    )

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # Only the multiplier is overridden. Per the current implementation the
        # key is a class name, so BOTH steps of that class receive the override.
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={"MockStepWithNonSerializableParam": {"multiplier": 5.0}},
        )

        output = restored((None, None, 1.0, False, False, {}, {}))

        # Reward passes through both overridden steps:
        #   step1: 1.0 * 5.0 = 5.0
        #   step2: 5.0 * 5.0 = 25.0
        assert output[2] == 25.0
|
||||
|
||||
|
||||
def test_from_pretrained_invalid_override_key():
    """An override key that matches no step makes loading fail with ``KeyError``."""
    original = RobotProcessor([MockStepWithNonSerializableParam()])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # No step in the saved pipeline goes by this name.
        unknown_overrides = {"NonExistentStep": {"param": "value"}}

        expected_message = "Override keys.*do not match any step"
        with pytest.raises(KeyError, match=expected_message):
            RobotProcessor.from_pretrained(save_dir, overrides=unknown_overrides)
|
||||
|
||||
|
||||
def test_from_pretrained_multiple_invalid_override_keys():
    """Every unmatched override key is named in the resulting ``KeyError`` message."""
    original = RobotProcessor([MockStepWithNonSerializableParam()])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # Two keys, neither of which corresponds to a saved step.
        unknown_overrides = {
            "NonExistentStep1": {"param": "value1"},
            "NonExistentStep2": {"param": "value2"},
        }

        with pytest.raises(KeyError) as raised:
            RobotProcessor.from_pretrained(save_dir, overrides=unknown_overrides)

        message = str(raised.value)
        # Both offending keys and the hint listing valid keys must be present.
        assert "NonExistentStep1" in message
        assert "NonExistentStep2" in message
        assert "Available step keys" in message
|
||||
|
||||
|
||||
def test_from_pretrained_registered_step_override():
    """A registered step can be overridden through its registry name."""
    original = RobotProcessor([RegisteredMockStep(value=50, device="cpu")])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # The key is the registry name, not the class name.
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={"registered_mock_step": {"value": 999, "device": "cuda"}},
        )

        # The step writes its parameters into the complementary data when called.
        output = restored((None, None, 0.0, False, False, {}, {}))
        extras = output[6]

        assert extras["registered_step_value"] == 999
        assert extras["registered_step_device"] == "cuda"
|
||||
|
||||
|
||||
def test_from_pretrained_mixed_registered_and_unregistered():
    """Overrides for a registered and an unregistered step can be given in one call."""
    plain_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0)
    registry_step = RegisteredMockStep(value=10, device="cpu")
    original = RobotProcessor([plain_step, registry_step])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        env = MockEnvironment("mixed_test")

        # Class-name key for the unregistered step, registry-name key for the other.
        combined_overrides = {
            "MockStepWithNonSerializableParam": {"env": env, "multiplier": 4.0},
            "registered_mock_step": {"value": 777},
        }
        restored = RobotProcessor.from_pretrained(save_dir, overrides=combined_overrides)

        # Exercise both steps with a single transition.
        output = restored((None, None, 2.0, False, False, {}, {}))
        extras = output[6]

        assert extras["unregistered_env_info"] == "MockEnvironment(mixed_test)"
        assert extras["registered_step_value"] == 777
        assert output[2] == 8.0  # 2.0 * 4.0
|
||||
|
||||
|
||||
def test_from_pretrained_no_overrides():
    """``from_pretrained`` still works when called without ``overrides`` (backward compatibility)."""
    original = RobotProcessor([MockStepWithNonSerializableParam(name="no_override", multiplier=3.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # No overrides argument at all.
        restored = RobotProcessor.from_pretrained(save_dir)

        assert len(restored) == 1

        # The step runs with its saved multiplier; its env stays None.
        output = restored((None, None, 1.0, False, False, {}, {}))

        assert output[2] == 3.0  # 1.0 * 3.0
|
||||
|
||||
|
||||
def test_from_pretrained_empty_overrides():
    """An empty ``overrides`` dict is accepted and changes nothing."""
    original = RobotProcessor([MockStepWithNonSerializableParam(multiplier=2.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # Explicitly pass an empty mapping.
        restored = RobotProcessor.from_pretrained(save_dir, overrides={})

        assert len(restored) == 1

        # The saved multiplier (2.0) is still in effect.
        output = restored((None, None, 1.0, False, False, {}, {}))

        assert output[2] == 2.0
|
||||
|
||||
|
||||
def test_from_pretrained_override_instantiation_error():
    """A step that cannot be built from its overrides surfaces as a ``ValueError``."""
    original = RobotProcessor([MockStepWithNonSerializableParam(multiplier=1.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # The step rejects string multipliers, so construction must fail.
        bad_step_kwargs = {"multiplier": "invalid_type"}  # Should be float, not string
        bad_overrides = {"MockStepWithNonSerializableParam": bad_step_kwargs}

        with pytest.raises(ValueError, match="Failed to instantiate processor step"):
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)
|
||||
|
||||
|
||||
def test_from_pretrained_with_state_and_overrides():
    """Config overrides and saved tensor state are both honoured when reloading a step."""
    original = RobotProcessor(
        [MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5)]
    )

    # Feed ten transitions (rewards 0.0 .. 9.0) so the step accumulates state before saving.
    for reward_index in range(10):
        original((None, None, float(reward_index), False, False, {}, {}))

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # Replace both config values of the saved step.
        step_kwargs = {
            "learning_rate": 0.05,  # Override learning rate
            "window_size": 3,  # Override window size
        }
        restored = RobotProcessor.from_pretrained(
            save_dir,
            overrides={"MockStepWithTensorState": step_kwargs},
        )
        restored_step = restored.steps[0]

        # Config values come from the overrides...
        assert restored_step.learning_rate == 0.05
        assert restored_step.window_size == 3

        # ...while the tensor state comes from the saved pipeline.
        assert restored_step.running_count.item() == 10

        # running_mean keeps the shape it was saved with (window_size=5), even
        # though the reloaded step is configured with window_size=3 from now on.
        assert restored_step.running_mean.shape[0] == 5  # From saved state
|
||||
|
||||
|
||||
def test_from_pretrained_override_error_messages():
    """The ``KeyError`` for a bad override key names the bad key and the valid ones."""
    plain_step = MockStepWithNonSerializableParam(name="step1")
    registry_step = RegisteredMockStep()
    original = RobotProcessor([plain_step, registry_step])

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir)

        # A key that matches neither the class name nor the registry name.
        bad_overrides = {"WrongStepName": {"param": "value"}}

        with pytest.raises(KeyError) as raised:
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)

        message = str(raised.value)
        # The message echoes the bad key and lists the keys that would have worked.
        assert "WrongStepName" in message
        assert "Available step keys" in message
        assert "MockStepWithNonSerializableParam" in message
        assert "registered_mock_step" in message
|
||||
|
||||
|
||||
class MockStepWithMixedState:
|
||||
"""Mock step demonstrating proper separation of tensor and non-tensor state.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user