mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
feat(pipeline): Enhance configuration filename handling and state file naming
- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection.
This commit is contained in:
@@ -26,6 +26,7 @@ from typing import Any, Protocol, TypedDict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
from lerobot.utils.utils import get_safe_torch_device
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
@@ -293,8 +294,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||||
|
|
||||||
_CFG_NAME = "processor.json"
|
|
||||||
|
|
||||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||||
"""Process data through all steps.
|
"""Process data through all steps.
|
||||||
|
|
||||||
@@ -386,7 +385,9 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||||
"""Internal save method for ModelHubMixin compatibility."""
|
"""Internal save method for ModelHubMixin compatibility."""
|
||||||
self.save_pretrained(destination_path)
|
# Extract config_filename from kwargs if provided
|
||||||
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
|
self.save_pretrained(destination_path, config_filename=config_filename)
|
||||||
|
|
||||||
def _generate_model_card(self, destination_path: str) -> None:
|
def _generate_model_card(self, destination_path: str) -> None:
|
||||||
"""Generate README.md from the RobotProcessor model card template."""
|
"""Generate README.md from the RobotProcessor model card template."""
|
||||||
@@ -405,10 +406,24 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
with open(readme_path, "w") as f:
|
with open(readme_path, "w") as f:
|
||||||
f.write(model_card_content)
|
f.write(model_card_content)
|
||||||
|
|
||||||
def save_pretrained(self, destination_path: str, **kwargs):
|
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
||||||
"""Serialize the processor definition and parameters to *destination_path*."""
|
"""Serialize the processor definition and parameters to *destination_path*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination_path: Directory where the processor will be saved.
|
||||||
|
config_filename: Optional custom config filename. If not provided, defaults to
|
||||||
|
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
|
||||||
|
"""
|
||||||
os.makedirs(destination_path, exist_ok=True)
|
os.makedirs(destination_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Determine config filename - sanitize the processor name for filesystem
|
||||||
|
if config_filename is None:
|
||||||
|
# Sanitize name - replace any character that's not alphanumeric or underscore
|
||||||
|
import re
|
||||||
|
|
||||||
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
|
config_filename = f"{sanitized_name}.json"
|
||||||
|
|
||||||
config: dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
@@ -448,9 +463,10 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
for key, tensor in state.items():
|
for key, tensor in state.items():
|
||||||
cloned_state[key] = tensor.clone()
|
cloned_state[key] = tensor.clone()
|
||||||
|
|
||||||
# Use registry name for more meaningful filenames when available
|
# Always include step index to ensure unique filenames
|
||||||
|
# This prevents conflicts when the same processor type is used multiple times
|
||||||
if registry_name:
|
if registry_name:
|
||||||
state_filename = f"{registry_name}.safetensors"
|
state_filename = f"step_{step_index}_{registry_name}.safetensors"
|
||||||
else:
|
else:
|
||||||
state_filename = f"step_{step_index}.safetensors"
|
state_filename = f"step_{step_index}.safetensors"
|
||||||
|
|
||||||
@@ -459,7 +475,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
config["steps"].append(step_entry)
|
config["steps"].append(step_entry)
|
||||||
|
|
||||||
with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer:
|
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
||||||
json.dump(config, file_pointer, indent=2)
|
json.dump(config, file_pointer, indent=2)
|
||||||
|
|
||||||
# Generate README.md from template
|
# Generate README.md from template
|
||||||
@@ -484,12 +500,17 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
|
def from_pretrained(
|
||||||
|
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
||||||
|
) -> RobotProcessor:
|
||||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||||
(e.g., "username/processor-name").
|
(e.g., "username/processor-name").
|
||||||
|
config_filename: Optional specific config filename to load. If not provided, will:
|
||||||
|
- For local paths: look for any .json file in the directory (error if multiple found)
|
||||||
|
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json")
|
||||||
overrides: Optional dictionary mapping step names to configuration overrides.
|
overrides: Optional dictionary mapping step names to configuration overrides.
|
||||||
Keys must match exact step class names (for unregistered steps) or registry names
|
Keys must match exact step class names (for unregistered steps) or registry names
|
||||||
(for registered steps). Values are dictionaries containing parameter overrides
|
(for registered steps). Values are dictionaries containing parameter overrides
|
||||||
@@ -510,6 +531,13 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
processor = RobotProcessor.from_pretrained("path/to/processor")
|
processor = RobotProcessor.from_pretrained("path/to/processor")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Loading specific config file:
|
||||||
|
```python
|
||||||
|
processor = RobotProcessor.from_pretrained(
|
||||||
|
"username/multi-processor-repo", config_filename="preprocessor.json"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
Loading with overrides for non-serializable objects:
|
Loading with overrides for non-serializable objects:
|
||||||
```python
|
```python
|
||||||
import gym
|
import gym
|
||||||
@@ -534,12 +562,52 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
if Path(source).is_dir():
|
if Path(source).is_dir():
|
||||||
# Local path - use it directly
|
# Local path - use it directly
|
||||||
base_path = Path(source)
|
base_path = Path(source)
|
||||||
with open(base_path / cls._CFG_NAME) as file_pointer:
|
|
||||||
|
if config_filename is None:
|
||||||
|
# Look for any .json file in the directory
|
||||||
|
json_files = list(base_path.glob("*.json"))
|
||||||
|
if len(json_files) == 0:
|
||||||
|
raise FileNotFoundError(f"No .json configuration files found in {source}")
|
||||||
|
elif len(json_files) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
|
||||||
|
f"Please specify which one to load using the config_filename parameter."
|
||||||
|
)
|
||||||
|
config_filename = json_files[0].name
|
||||||
|
|
||||||
|
with open(base_path / config_filename) as file_pointer:
|
||||||
config: dict[str, Any] = json.load(file_pointer)
|
config: dict[str, Any] = json.load(file_pointer)
|
||||||
else:
|
else:
|
||||||
# Hugging Face Hub - download all required files
|
# Hugging Face Hub - download all required files
|
||||||
# First download the config file
|
if config_filename is None:
|
||||||
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
|
# Try common config names
|
||||||
|
common_names = [
|
||||||
|
"processor.json",
|
||||||
|
"preprocessor.json",
|
||||||
|
"postprocessor.json",
|
||||||
|
"robotprocessor.json",
|
||||||
|
]
|
||||||
|
config_path = None
|
||||||
|
for name in common_names:
|
||||||
|
try:
|
||||||
|
config_path = hf_hub_download(source, name, repo_type="model")
|
||||||
|
config_filename = name
|
||||||
|
break
|
||||||
|
except (FileNotFoundError, OSError, HfHubHTTPError):
|
||||||
|
# FileNotFoundError: local file issues
|
||||||
|
# OSError: network/system errors
|
||||||
|
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
|
||||||
|
continue
|
||||||
|
|
||||||
|
if config_path is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No processor configuration file found in {source}. "
|
||||||
|
f"Tried: {common_names}. Please specify the config_filename parameter."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Download specific config file
|
||||||
|
config_path = hf_hub_download(source, config_filename, repo_type="model")
|
||||||
|
|
||||||
with open(config_path) as file_pointer:
|
with open(config_path) as file_pointer:
|
||||||
config: dict[str, Any] = json.load(file_pointer)
|
config: dict[str, Any] = json.load(file_pointer)
|
||||||
|
|
||||||
|
|||||||
@@ -563,7 +563,7 @@ def test_save_and_load_pretrained():
|
|||||||
pipeline.save_pretrained(tmp_dir)
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
# Check files were created
|
# Check files were created
|
||||||
config_path = Path(tmp_dir) / "processor.json"
|
config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline"
|
||||||
assert config_path.exists()
|
assert config_path.exists()
|
||||||
|
|
||||||
# Check config content
|
# Check config content
|
||||||
@@ -629,7 +629,7 @@ def test_mixed_json_and_tensor_state():
|
|||||||
pipeline.save_pretrained(tmp_dir)
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
# Check that both config and state files were created
|
# Check that both config and state files were created
|
||||||
config_path = Path(tmp_dir) / "processor.json"
|
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
|
||||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
state_path = Path(tmp_dir) / "step_0.safetensors"
|
||||||
assert config_path.exists()
|
assert config_path.exists()
|
||||||
assert state_path.exists()
|
assert state_path.exists()
|
||||||
@@ -1655,3 +1655,454 @@ def test_repr_edge_case_long_names():
|
|||||||
|
|
||||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
||||||
assert repr_str == expected
|
assert repr_str == expected
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for config filename features and multiple processors
|
||||||
|
def test_save_with_custom_config_filename():
|
||||||
|
"""Test saving processor with custom config filename."""
|
||||||
|
step = MockStep("test")
|
||||||
|
pipeline = RobotProcessor([step], name="TestProcessor")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Save with custom filename
|
||||||
|
pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json")
|
||||||
|
|
||||||
|
# Check file exists
|
||||||
|
config_path = Path(tmp_dir) / "my_custom_config.json"
|
||||||
|
assert config_path.exists()
|
||||||
|
|
||||||
|
# Check content
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
assert config["name"] == "TestProcessor"
|
||||||
|
|
||||||
|
# Load with specific filename
|
||||||
|
loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json")
|
||||||
|
assert loaded.name == "TestProcessor"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_processors_same_directory():
|
||||||
|
"""Test saving multiple processors to the same directory with different config files."""
|
||||||
|
# Create different processors
|
||||||
|
preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor")
|
||||||
|
|
||||||
|
postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Save both to same directory
|
||||||
|
preprocessor.save_pretrained(tmp_dir)
|
||||||
|
postprocessor.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check both config files exist
|
||||||
|
assert (Path(tmp_dir) / "preprocessor.json").exists()
|
||||||
|
assert (Path(tmp_dir) / "postprocessor.json").exists()
|
||||||
|
|
||||||
|
# Load them back
|
||||||
|
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
|
||||||
|
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
|
||||||
|
|
||||||
|
assert loaded_pre.name == "preprocessor"
|
||||||
|
assert loaded_post.name == "postprocessor"
|
||||||
|
assert len(loaded_pre) == 2
|
||||||
|
assert len(loaded_post) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_detect_single_config():
|
||||||
|
"""Test automatic config detection when there's only one JSON file."""
|
||||||
|
step = MockStepWithTensorState()
|
||||||
|
pipeline = RobotProcessor([step], name="SingleConfig")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Load without specifying config_filename
|
||||||
|
loaded = RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
assert loaded.name == "SingleConfig"
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_multiple_configs_no_filename():
|
||||||
|
"""Test error when multiple configs exist and no filename specified."""
|
||||||
|
proc1 = RobotProcessor([MockStep()], name="processor1")
|
||||||
|
proc2 = RobotProcessor([MockStep()], name="processor2")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
proc1.save_pretrained(tmp_dir)
|
||||||
|
proc2.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Should raise error
|
||||||
|
with pytest.raises(ValueError, match="Multiple .json files found"):
|
||||||
|
RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_file_naming_with_indices():
|
||||||
|
"""Test that state files include step indices to avoid conflicts."""
|
||||||
|
# Create multiple steps of same type with state
|
||||||
|
step1 = MockStepWithTensorState(name="norm1", window_size=5)
|
||||||
|
step2 = MockStepWithTensorState(name="norm2", window_size=10)
|
||||||
|
step3 = MockModuleStep(input_dim=5)
|
||||||
|
|
||||||
|
pipeline = RobotProcessor([step1, step2, step3])
|
||||||
|
|
||||||
|
# Process some data to create state
|
||||||
|
for i in range(5):
|
||||||
|
transition = create_transition(observation=torch.randn(2, 5), reward=float(i))
|
||||||
|
pipeline(transition)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check state files have indices
|
||||||
|
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||||
|
assert len(state_files) == 3
|
||||||
|
|
||||||
|
# Files should be named with indices
|
||||||
|
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
|
||||||
|
actual_names = [f.name for f in state_files]
|
||||||
|
assert actual_names == expected_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_file_naming_with_registry():
|
||||||
|
"""Test state file naming for registered steps includes both index and name."""
|
||||||
|
|
||||||
|
# Register a test step
|
||||||
|
@ProcessorStepRegistry.register("test_stateful_step")
|
||||||
|
@dataclass
|
||||||
|
class TestStatefulStep:
|
||||||
|
value: int = 0
|
||||||
|
|
||||||
|
def __init__(self, value: int = 0):
|
||||||
|
self.value = value
|
||||||
|
self.state_tensor = torch.randn(3, 3)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {"value": self.value}
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"state_tensor": self.state_tensor}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
self.state_tensor = state["state_tensor"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create pipeline with registered steps
|
||||||
|
step1 = TestStatefulStep(1)
|
||||||
|
step2 = TestStatefulStep(2)
|
||||||
|
pipeline = RobotProcessor([step1, step2])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check state files
|
||||||
|
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||||
|
assert len(state_files) == 2
|
||||||
|
|
||||||
|
# Should include both index and registry name
|
||||||
|
expected_names = [
|
||||||
|
"step_0_test_stateful_step.safetensors",
|
||||||
|
"step_1_test_stateful_step.safetensors",
|
||||||
|
]
|
||||||
|
actual_names = [f.name for f in state_files]
|
||||||
|
assert actual_names == expected_names
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup registry
|
||||||
|
ProcessorStepRegistry.unregister("test_stateful_step")
|
||||||
|
|
||||||
|
|
||||||
|
# More comprehensive override tests
|
||||||
|
def test_override_with_nested_config():
|
||||||
|
"""Test overrides with nested configuration dictionaries."""
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("complex_config_step")
|
||||||
|
@dataclass
|
||||||
|
class ComplexConfigStep:
|
||||||
|
name: str = "complex"
|
||||||
|
simple_param: int = 42
|
||||||
|
nested_config: dict = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.nested_config is None:
|
||||||
|
self.nested_config = {"level1": {"level2": "default"}}
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
comp_data = dict(comp_data)
|
||||||
|
comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing")
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
||||||
|
|
||||||
|
try:
|
||||||
|
step = ComplexConfigStep()
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Load with nested override
|
||||||
|
loaded = RobotProcessor.from_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that override worked
|
||||||
|
transition = create_transition()
|
||||||
|
result = loaded(transition)
|
||||||
|
assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden"
|
||||||
|
finally:
|
||||||
|
ProcessorStepRegistry.unregister("complex_config_step")
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_preserves_defaults():
|
||||||
|
"""Test that overrides only affect specified parameters."""
|
||||||
|
step = MockStepWithNonSerializableParam(name="test", multiplier=2.0)
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Override only one parameter
|
||||||
|
loaded = RobotProcessor.from_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
overrides={
|
||||||
|
"MockStepWithNonSerializableParam": {
|
||||||
|
"multiplier": 5.0 # Only override multiplier
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that name was preserved from saved config
|
||||||
|
loaded_step = loaded.steps[0]
|
||||||
|
assert loaded_step.name == "test" # Original value
|
||||||
|
assert loaded_step.multiplier == 5.0 # Overridden value
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_type_validation():
|
||||||
|
"""Test that type errors in overrides are caught properly."""
|
||||||
|
step = MockStepWithTensorState(learning_rate=0.01)
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Try to override with wrong type
|
||||||
|
overrides = {
|
||||||
|
"MockStepWithTensorState": {
|
||||||
|
"window_size": "not_an_int" # Should be int
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Failed to instantiate"):
|
||||||
|
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_with_callables():
|
||||||
|
"""Test overriding with callable objects."""
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("callable_step")
|
||||||
|
@dataclass
|
||||||
|
class CallableStep:
|
||||||
|
name: str = "callable_step"
|
||||||
|
transform_fn: Any = None
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
obs = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if obs is not None and self.transform_fn is not None:
|
||||||
|
processed_obs = {}
|
||||||
|
for k, v in obs.items():
|
||||||
|
processed_obs[k] = self.transform_fn(v)
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||||
|
return new_transition
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {"name": self.name}
|
||||||
|
|
||||||
|
try:
|
||||||
|
step = CallableStep()
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Define a transform function
|
||||||
|
def double_values(x):
|
||||||
|
if isinstance(x, (int, float)):
|
||||||
|
return x * 2
|
||||||
|
elif isinstance(x, torch.Tensor):
|
||||||
|
return x * 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Load with callable override
|
||||||
|
loaded = RobotProcessor.from_pretrained(
|
||||||
|
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test it works
|
||||||
|
transition = create_transition(observation={"value": torch.tensor(5.0)})
|
||||||
|
result = loaded(transition)
|
||||||
|
assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0
|
||||||
|
finally:
|
||||||
|
ProcessorStepRegistry.unregister("callable_step")
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_multiple_same_class_warning():
|
||||||
|
"""Test behavior when multiple steps of same class exist."""
|
||||||
|
step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
|
||||||
|
step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)
|
||||||
|
pipeline = RobotProcessor([step1, step2])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Override affects all instances of the class
|
||||||
|
loaded = RobotProcessor.from_pretrained(
|
||||||
|
tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both steps get the same override
|
||||||
|
assert loaded.steps[0].multiplier == 10.0
|
||||||
|
assert loaded.steps[1].multiplier == 10.0
|
||||||
|
|
||||||
|
# But original names are preserved
|
||||||
|
assert loaded.steps[0].name == "step1"
|
||||||
|
assert loaded.steps[1].name == "step2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_filename_special_characters():
|
||||||
|
"""Test config filenames with special characters are sanitized."""
|
||||||
|
# Processor name with special characters
|
||||||
|
pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check that filename was sanitized
|
||||||
|
json_files = list(Path(tmp_dir).glob("*.json"))
|
||||||
|
assert len(json_files) == 1
|
||||||
|
|
||||||
|
# Should have replaced special chars with underscores
|
||||||
|
expected_name = "my_processor_with_special_chars.json"
|
||||||
|
assert json_files[0].name == expected_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_with_device_strings():
|
||||||
|
"""Test overriding device parameters with string values."""
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("device_aware_step")
|
||||||
|
@dataclass
|
||||||
|
class DeviceAwareStep:
|
||||||
|
device: str = "cpu"
|
||||||
|
|
||||||
|
def __init__(self, device: str = "cpu"):
|
||||||
|
self.device = device
|
||||||
|
self.buffer = torch.zeros(10, device=device)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {"device": str(self.device)}
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"buffer": self.buffer}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
self.buffer = state["buffer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
step = DeviceAwareStep(device="cpu")
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Override device
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
loaded = RobotProcessor.from_pretrained(
|
||||||
|
tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_step = loaded.steps[0]
|
||||||
|
assert loaded_step.device == "cuda:0"
|
||||||
|
# Note: buffer will still be on CPU from saved state
|
||||||
|
# until .to() is called on the processor
|
||||||
|
|
||||||
|
finally:
|
||||||
|
ProcessorStepRegistry.unregister("device_aware_step")
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_pretrained_nonexistent_path():
|
||||||
|
"""Test error handling when loading from non-existent sources."""
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError, HFValidationError
|
||||||
|
|
||||||
|
# Test with an invalid repo ID (too many slashes) - caught by HF validation
|
||||||
|
with pytest.raises(HFValidationError):
|
||||||
|
RobotProcessor.from_pretrained("/path/that/does/not/exist")
|
||||||
|
|
||||||
|
# Test with a non-existent but valid Hub repo format
|
||||||
|
with pytest.raises((FileNotFoundError, HfHubHTTPError)):
|
||||||
|
RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo")
|
||||||
|
|
||||||
|
# Test with a local directory that exists but has no config files
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
|
||||||
|
RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_with_custom_converter_functions():
|
||||||
|
"""Test that custom to_transition and to_output functions are NOT saved."""
|
||||||
|
|
||||||
|
def custom_to_transition(batch):
|
||||||
|
# Custom conversion logic
|
||||||
|
return {
|
||||||
|
TransitionKey.OBSERVATION: batch.get("obs"),
|
||||||
|
TransitionKey.ACTION: batch.get("act"),
|
||||||
|
TransitionKey.REWARD: batch.get("rew", 0.0),
|
||||||
|
TransitionKey.DONE: batch.get("done", False),
|
||||||
|
TransitionKey.TRUNCATED: batch.get("truncated", False),
|
||||||
|
TransitionKey.INFO: {},
|
||||||
|
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
def custom_to_output(transition):
|
||||||
|
# Custom output format
|
||||||
|
return {
|
||||||
|
"obs": transition.get(TransitionKey.OBSERVATION),
|
||||||
|
"act": transition.get(TransitionKey.ACTION),
|
||||||
|
"rew": transition.get(TransitionKey.REWARD),
|
||||||
|
"done": transition.get(TransitionKey.DONE),
|
||||||
|
"truncated": transition.get(TransitionKey.TRUNCATED),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create processor with custom converters
|
||||||
|
pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Load - should use default converters
|
||||||
|
loaded = RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Verify it uses default converters by checking with standard batch format
|
||||||
|
batch = {
|
||||||
|
"observation.image": torch.randn(1, 3, 32, 32),
|
||||||
|
"action": torch.randn(1, 7),
|
||||||
|
"next.reward": torch.tensor([1.0]),
|
||||||
|
"next.done": torch.tensor([False]),
|
||||||
|
"next.truncated": torch.tensor([False]),
|
||||||
|
"info": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should work with standard format (wouldn't work with custom converter)
|
||||||
|
result = loaded(batch)
|
||||||
|
assert "observation.image" in result # Standard format preserved
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ def test_save_and_load_pretrained():
|
|||||||
pipeline.save_pretrained(tmp_dir)
|
pipeline.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
# Check files were created
|
# Check files were created
|
||||||
config_path = Path(tmp_dir) / "processor.json"
|
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
||||||
assert config_path.exists()
|
assert config_path.exists()
|
||||||
|
|
||||||
# No state files should be created for RenameProcessor
|
# No state files should be created for RenameProcessor
|
||||||
@@ -283,7 +283,7 @@ def test_registry_based_save_load():
|
|||||||
# Verify config uses registry name
|
# Verify config uses registry name
|
||||||
import json
|
import json
|
||||||
|
|
||||||
with open(Path(tmp_dir) / "processor.json") as f:
|
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
assert "registry_name" in config["steps"][0]
|
assert "registry_name" in config["steps"][0]
|
||||||
|
|||||||
Reference in New Issue
Block a user