feat(pipeline): Enhance configuration filename handling and state file naming

- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing filename conflicts when a processor contains multiple steps of the same type.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
This commit is contained in:
Adil Zouitine
2025-07-22 14:35:34 +02:00
parent 409ac0baca
commit 4ba23ea029
3 changed files with 535 additions and 16 deletions
+80 -12
View File
@@ -26,6 +26,7 @@ from typing import Any, Protocol, TypedDict
import torch import torch
from huggingface_hub import ModelHubMixin, hf_hub_download from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from lerobot.utils.utils import get_safe_torch_device from lerobot.utils.utils import get_safe_torch_device
@@ -293,8 +294,6 @@ class RobotProcessor(ModelHubMixin):
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
_CFG_NAME = "processor.json"
def __call__(self, data: EnvTransition | dict[str, Any]): def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps. """Process data through all steps.
@@ -386,7 +385,9 @@ class RobotProcessor(ModelHubMixin):
def _save_pretrained(self, destination_path: str, **kwargs): def _save_pretrained(self, destination_path: str, **kwargs):
"""Internal save method for ModelHubMixin compatibility.""" """Internal save method for ModelHubMixin compatibility."""
self.save_pretrained(destination_path) # Extract config_filename from kwargs if provided
config_filename = kwargs.pop("config_filename", None)
self.save_pretrained(destination_path, config_filename=config_filename)
def _generate_model_card(self, destination_path: str) -> None: def _generate_model_card(self, destination_path: str) -> None:
"""Generate README.md from the RobotProcessor model card template.""" """Generate README.md from the RobotProcessor model card template."""
@@ -405,10 +406,24 @@ class RobotProcessor(ModelHubMixin):
with open(readme_path, "w") as f: with open(readme_path, "w") as f:
f.write(model_card_content) f.write(model_card_content)
def save_pretrained(self, destination_path: str, **kwargs): def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
"""Serialize the processor definition and parameters to *destination_path*.""" """Serialize the processor definition and parameters to *destination_path*.
Args:
destination_path: Directory where the processor will be saved.
config_filename: Optional custom config filename. If not provided, defaults to
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
"""
os.makedirs(destination_path, exist_ok=True) os.makedirs(destination_path, exist_ok=True)
# Determine config filename - sanitize the processor name for filesystem
if config_filename is None:
# Sanitize name - replace any character that's not alphanumeric or underscore
import re
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
config_filename = f"{sanitized_name}.json"
config: dict[str, Any] = { config: dict[str, Any] = {
"name": self.name, "name": self.name,
"seed": self.seed, "seed": self.seed,
@@ -448,9 +463,10 @@ class RobotProcessor(ModelHubMixin):
for key, tensor in state.items(): for key, tensor in state.items():
cloned_state[key] = tensor.clone() cloned_state[key] = tensor.clone()
# Use registry name for more meaningful filenames when available # Always include step index to ensure unique filenames
# This prevents conflicts when the same processor type is used multiple times
if registry_name: if registry_name:
state_filename = f"{registry_name}.safetensors" state_filename = f"step_{step_index}_{registry_name}.safetensors"
else: else:
state_filename = f"step_{step_index}.safetensors" state_filename = f"step_{step_index}.safetensors"
@@ -459,7 +475,7 @@ class RobotProcessor(ModelHubMixin):
config["steps"].append(step_entry) config["steps"].append(step_entry)
with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer: with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2) json.dump(config, file_pointer, indent=2)
# Generate README.md from template # Generate README.md from template
@@ -484,12 +500,17 @@ class RobotProcessor(ModelHubMixin):
return self return self
@classmethod @classmethod
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor: def from_pretrained(
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier). """Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args: Args:
source: Local path to a saved processor directory or Hugging Face Hub identifier source: Local path to a saved processor directory or Hugging Face Hub identifier
(e.g., "username/processor-name"). (e.g., "username/processor-name").
config_filename: Optional specific config filename to load. If not provided, will:
- For local paths: look for any .json file in the directory (error if multiple found)
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json", "robotprocessor.json")
overrides: Optional dictionary mapping step names to configuration overrides. overrides: Optional dictionary mapping step names to configuration overrides.
Keys must match exact step class names (for unregistered steps) or registry names Keys must match exact step class names (for unregistered steps) or registry names
(for registered steps). Values are dictionaries containing parameter overrides (for registered steps). Values are dictionaries containing parameter overrides
@@ -510,6 +531,13 @@ class RobotProcessor(ModelHubMixin):
processor = RobotProcessor.from_pretrained("path/to/processor") processor = RobotProcessor.from_pretrained("path/to/processor")
``` ```
Loading specific config file:
```python
processor = RobotProcessor.from_pretrained(
"username/multi-processor-repo", config_filename="preprocessor.json"
)
```
Loading with overrides for non-serializable objects: Loading with overrides for non-serializable objects:
```python ```python
import gym import gym
@@ -534,12 +562,52 @@ class RobotProcessor(ModelHubMixin):
if Path(source).is_dir(): if Path(source).is_dir():
# Local path - use it directly # Local path - use it directly
base_path = Path(source) base_path = Path(source)
with open(base_path / cls._CFG_NAME) as file_pointer:
if config_filename is None:
# Look for any .json file in the directory
json_files = list(base_path.glob("*.json"))
if len(json_files) == 0:
raise FileNotFoundError(f"No .json configuration files found in {source}")
elif len(json_files) > 1:
raise ValueError(
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
f"Please specify which one to load using the config_filename parameter."
)
config_filename = json_files[0].name
with open(base_path / config_filename) as file_pointer:
config: dict[str, Any] = json.load(file_pointer) config: dict[str, Any] = json.load(file_pointer)
else: else:
# Hugging Face Hub - download all required files # Hugging Face Hub - download all required files
# First download the config file if config_filename is None:
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model") # Try common config names
common_names = [
"processor.json",
"preprocessor.json",
"postprocessor.json",
"robotprocessor.json",
]
config_path = None
for name in common_names:
try:
config_path = hf_hub_download(source, name, repo_type="model")
config_filename = name
break
except (FileNotFoundError, OSError, HfHubHTTPError):
# FileNotFoundError: local file issues
# OSError: network/system errors
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
continue
if config_path is None:
raise FileNotFoundError(
f"No processor configuration file found in {source}. "
f"Tried: {common_names}. Please specify the config_filename parameter."
)
else:
# Download specific config file
config_path = hf_hub_download(source, config_filename, repo_type="model")
with open(config_path) as file_pointer: with open(config_path) as file_pointer:
config: dict[str, Any] = json.load(file_pointer) config: dict[str, Any] = json.load(file_pointer)
+453 -2
View File
@@ -563,7 +563,7 @@ def test_save_and_load_pretrained():
pipeline.save_pretrained(tmp_dir) pipeline.save_pretrained(tmp_dir)
# Check files were created # Check files were created
config_path = Path(tmp_dir) / "processor.json" config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline"
assert config_path.exists() assert config_path.exists()
# Check config content # Check config content
@@ -629,7 +629,7 @@ def test_mixed_json_and_tensor_state():
pipeline.save_pretrained(tmp_dir) pipeline.save_pretrained(tmp_dir)
# Check that both config and state files were created # Check that both config and state files were created
config_path = Path(tmp_dir) / "processor.json" config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
state_path = Path(tmp_dir) / "step_0.safetensors" state_path = Path(tmp_dir) / "step_0.safetensors"
assert config_path.exists() assert config_path.exists()
assert state_path.exists() assert state_path.exists()
@@ -1655,3 +1655,454 @@ def test_repr_edge_case_long_names():
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)" expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
assert repr_str == expected assert repr_str == expected
# Tests for config filename features and multiple processors
def test_save_with_custom_config_filename():
    """Round-trip a processor through a caller-chosen config filename."""
    custom_name = "my_custom_config.json"
    processor = RobotProcessor([MockStep("test")], name="TestProcessor")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # The explicit filename must be used for the written config.
        processor.save_pretrained(tmp_dir, config_filename=custom_name)
        saved_config = Path(tmp_dir) / custom_name
        assert saved_config.exists()

        # The serialized config still records the processor's own name.
        with open(saved_config) as handle:
            payload = json.load(handle)
        assert payload["name"] == "TestProcessor"

        # Loading with the same filename restores the processor.
        restored = RobotProcessor.from_pretrained(tmp_dir, config_filename=custom_name)
        assert restored.name == "TestProcessor"
def test_multiple_processors_same_directory():
    """Two differently named processors can be saved to, and loaded from, one directory."""
    pre = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor")
    post = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save both processors into the very same directory.
        for processor in (pre, post):
            processor.save_pretrained(tmp_dir)

        # Each processor must have produced its own config file.
        root = Path(tmp_dir)
        assert (root / "preprocessor.json").exists()
        assert (root / "postprocessor.json").exists()

        # Each config is addressable individually through config_filename.
        restored_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
        restored_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")

        assert restored_pre.name == "preprocessor"
        assert restored_post.name == "postprocessor"
        assert len(restored_pre) == 2
        assert len(restored_post) == 1
def test_auto_detect_single_config():
    """Loading without config_filename succeeds when the directory holds a single JSON file."""
    processor = RobotProcessor([MockStepWithTensorState()], name="SingleConfig")

    with tempfile.TemporaryDirectory() as tmp_dir:
        processor.save_pretrained(tmp_dir)

        # No config_filename given: the loader has to discover the file itself.
        restored = RobotProcessor.from_pretrained(tmp_dir)
        assert restored.name == "SingleConfig"
def test_error_multiple_configs_no_filename():
    """Auto-detection must refuse to guess when several JSON configs share a directory."""
    first = RobotProcessor([MockStep()], name="processor1")
    second = RobotProcessor([MockStep()], name="processor2")

    with tempfile.TemporaryDirectory() as tmp_dir:
        first.save_pretrained(tmp_dir)
        second.save_pretrained(tmp_dir)

        # Two configs and no explicit choice: loading is expected to fail loudly.
        with pytest.raises(ValueError, match="Multiple .json files found"):
            RobotProcessor.from_pretrained(tmp_dir)
def test_state_file_naming_with_indices():
    """State files of unregistered steps are named after the step's position in the pipeline."""
    # Two steps of the same class plus a module step, all of which carry state.
    first_norm = MockStepWithTensorState(name="norm1", window_size=5)
    second_norm = MockStepWithTensorState(name="norm2", window_size=10)
    module_step = MockModuleStep(input_dim=5)
    pipeline = RobotProcessor([first_norm, second_norm, module_step])

    # Run a few transitions through the pipeline so the steps accumulate state.
    for reward_index in range(5):
        pipeline(create_transition(observation=torch.randn(2, 5), reward=float(reward_index)))

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        saved_names = [path.name for path in sorted(Path(tmp_dir).glob("*.safetensors"))]

        # One state file per step, each carrying the step index.
        assert len(saved_names) == 3
        assert saved_names == ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
def test_state_file_naming_with_registry():
    """State files of registered steps carry both the step index and the registry name."""

    @ProcessorStepRegistry.register("test_stateful_step")
    @dataclass
    class TestStatefulStep:
        # Minimal registered step holding one tensor of state.
        value: int = 0

        def __init__(self, value: int = 0):
            self.value = value
            self.state_tensor = torch.randn(3, 3)

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            # Identity step: only its persisted state matters for this test.
            return transition

        def get_config(self):
            return {"value": self.value}

        def state_dict(self):
            return {"state_tensor": self.state_tensor}

        def load_state_dict(self, state):
            self.state_tensor = state["state_tensor"]

    try:
        # Same registered step type twice, so a name-only filename would collide.
        pipeline = RobotProcessor([TestStatefulStep(1), TestStatefulStep(2)])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            saved_names = [path.name for path in sorted(Path(tmp_dir).glob("*.safetensors"))]
            assert len(saved_names) == 2
            assert saved_names == [
                "step_0_test_stateful_step.safetensors",
                "step_1_test_stateful_step.safetensors",
            ]
    finally:
        # Always drop the temporary registration so other tests see a clean registry.
        ProcessorStepRegistry.unregister("test_stateful_step")
# More comprehensive override tests
def test_override_with_nested_config():
    """An override whose value is a nested dictionary reaches the re-created step."""

    @ProcessorStepRegistry.register("complex_config_step")
    @dataclass
    class ComplexConfigStep:
        # Step that exposes a value read from its nested configuration.
        name: str = "complex"
        simple_param: int = 42
        nested_config: dict = None

        def __post_init__(self):
            # Fall back to a default nested structure when none is supplied.
            if self.nested_config is None:
                self.nested_config = {"level1": {"level2": "default"}}

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            # Work on a copy of the complementary data rather than the caller's dict.
            extras = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA, {}))
            extras["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing")

            updated = transition.copy()
            updated[TransitionKey.COMPLEMENTARY_DATA] = extras
            return updated

        def get_config(self):
            return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}

    try:
        pipeline = RobotProcessor([ComplexConfigStep()])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            nested_override = {"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}
            restored = RobotProcessor.from_pretrained(tmp_dir, overrides=nested_override)

            # The step reports the nested value it was built with.
            outcome = restored(create_transition())
            assert outcome[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden"
    finally:
        ProcessorStepRegistry.unregister("complex_config_step")
def test_override_preserves_defaults():
    """Overriding one parameter leaves the other saved parameters untouched."""
    pipeline = RobotProcessor([MockStepWithNonSerializableParam(name="test", multiplier=2.0)])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Only the multiplier is overridden; the name must come from the saved config.
        partial_override = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}}
        restored = RobotProcessor.from_pretrained(tmp_dir, overrides=partial_override)

        restored_step = restored.steps[0]
        assert restored_step.name == "test"  # value from the saved config
        assert restored_step.multiplier == 5.0  # value from the override
def test_override_type_validation():
    """An override of the wrong type surfaces as a ValueError while loading."""
    pipeline = RobotProcessor([MockStepWithTensorState(learning_rate=0.01)])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # window_size is given a string although an int is expected.
        bad_overrides = {"MockStepWithTensorState": {"window_size": "not_an_int"}}

        with pytest.raises(ValueError, match="Failed to instantiate"):
            RobotProcessor.from_pretrained(tmp_dir, overrides=bad_overrides)
def test_override_with_callables():
    """A callable that is not part of the saved config can be injected through overrides."""

    @ProcessorStepRegistry.register("callable_step")
    @dataclass
    class CallableStep:
        # Applies transform_fn to every observation value when one is configured.
        name: str = "callable_step"
        transform_fn: Any = None

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            observation = transition.get(TransitionKey.OBSERVATION)
            # Nothing to do without an observation or without a transform.
            if observation is None or self.transform_fn is None:
                return transition

            updated = transition.copy()
            updated[TransitionKey.OBSERVATION] = {
                key: self.transform_fn(value) for key, value in observation.items()
            }
            return updated

        def get_config(self):
            # transform_fn is deliberately left out: only the name is serialized.
            return {"name": self.name}

    def double_values(x):
        # Double numbers and tensors; hand anything else back unchanged.
        if isinstance(x, (int, float, torch.Tensor)):
            return x * 2
        return x

    try:
        pipeline = RobotProcessor([CallableStep()])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            restored = RobotProcessor.from_pretrained(
                tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
            )

            # The injected callable must actually be applied to the observation.
            outcome = restored(create_transition(observation={"value": torch.tensor(5.0)}))
            assert outcome[TransitionKey.OBSERVATION]["value"].item() == 10.0
    finally:
        ProcessorStepRegistry.unregister("callable_step")
def test_override_multiple_same_class_warning():
    """An override keyed by class name is applied to every step of that class."""
    pipeline = RobotProcessor(
        [
            MockStepWithNonSerializableParam(name="step1", multiplier=1.0),
            MockStepWithNonSerializableParam(name="step2", multiplier=2.0),
        ]
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        restored = RobotProcessor.from_pretrained(
            tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}
        )

        first, second = restored.steps[0], restored.steps[1]

        # Both instances receive the overridden multiplier ...
        assert first.multiplier == 10.0
        assert second.multiplier == 10.0

        # ... while each keeps the name stored in the saved config.
        assert first.name == "step1"
        assert second.name == "step2"
def test_config_filename_special_characters():
    """A processor name with filesystem-unsafe characters yields a sanitized config filename."""
    pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars")

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        written_configs = list(Path(tmp_dir).glob("*.json"))
        assert len(written_configs) == 1

        # Name is lowercased and every special character becomes an underscore.
        assert written_configs[0].name == "my_processor_with_special_chars.json"
def test_override_with_device_strings():
    """Test overriding a step's device parameter with a string value.

    Saving is exercised on every machine. The override itself needs a CUDA
    device, so on CPU-only machines the test is reported as skipped instead of
    passing without having asserted anything.
    """

    @ProcessorStepRegistry.register("device_aware_step")
    @dataclass
    class DeviceAwareStep:
        # Step that allocates a buffer on the device named in its config.
        device: str = "cpu"

        def __init__(self, device: str = "cpu"):
            self.device = device
            self.buffer = torch.zeros(10, device=device)

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            return transition

        def get_config(self):
            return {"device": str(self.device)}

        def state_dict(self):
            return {"buffer": self.buffer}

        def load_state_dict(self, state):
            self.buffer = state["buffer"]

    try:
        step = DeviceAwareStep(device="cpu")
        pipeline = RobotProcessor([step])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            # Make the missing coverage visible rather than silently succeeding.
            if not torch.cuda.is_available():
                pytest.skip("CUDA is not available; cannot override device to 'cuda:0'")

            # Override device
            loaded = RobotProcessor.from_pretrained(
                tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}}
            )

            loaded_step = loaded.steps[0]
            assert loaded_step.device == "cuda:0"
            # Note: buffer will still be on CPU from saved state
            # until .to() is called on the processor
    finally:
        # Runs on the skip path too, so the registry never keeps the test step.
        ProcessorStepRegistry.unregister("device_aware_step")
def test_from_pretrained_nonexistent_path():
    """Loading from sources that hold no processor must raise, whatever the kind of source."""
    from huggingface_hub.errors import HfHubHTTPError, HFValidationError

    # A path that is not a directory is treated as a Hub repo id; this one has
    # too many slashes to be a valid id, so HF validation rejects it.
    invalid_repo_id = "/path/that/does/not/exist"
    with pytest.raises(HFValidationError):
        RobotProcessor.from_pretrained(invalid_repo_id)

    # A well-formed repo id that does not exist on the Hub.
    missing_repo_id = "nonexistent-user/nonexistent-repo"
    with pytest.raises((FileNotFoundError, HfHubHTTPError)):
        RobotProcessor.from_pretrained(missing_repo_id)

    # An existing local directory that contains no config file at all.
    with (
        tempfile.TemporaryDirectory() as tmp_dir,
        pytest.raises(FileNotFoundError, match="No .json configuration files found"),
    ):
        RobotProcessor.from_pretrained(tmp_dir)
def test_save_load_with_custom_converter_functions():
    """Custom to_transition / to_output converters are not persisted by save_pretrained."""

    def custom_to_transition(batch):
        # Non-standard batch layout using short keys ("obs", "act", ...).
        return {
            TransitionKey.OBSERVATION: batch.get("obs"),
            TransitionKey.ACTION: batch.get("act"),
            TransitionKey.REWARD: batch.get("rew", 0.0),
            TransitionKey.DONE: batch.get("done", False),
            TransitionKey.TRUNCATED: batch.get("truncated", False),
            TransitionKey.INFO: {},
            TransitionKey.COMPLEMENTARY_DATA: {},
        }

    def custom_to_output(transition):
        # Mirror of custom_to_transition: emit the short-key layout again.
        short_keys = {
            "obs": TransitionKey.OBSERVATION,
            "act": TransitionKey.ACTION,
            "rew": TransitionKey.REWARD,
            "done": TransitionKey.DONE,
            "truncated": TransitionKey.TRUNCATED,
        }
        return {label: transition.get(key) for label, key in short_keys.items()}

    pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output)

    # A batch in the standard key layout, which the custom converters would not understand.
    standard_batch = {
        "observation.image": torch.randn(1, 3, 32, 32),
        "action": torch.randn(1, 7),
        "next.reward": torch.tensor([1.0]),
        "next.done": torch.tensor([False]),
        "next.truncated": torch.tensor([False]),
        "info": {},
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # The reloaded processor is expected to fall back to the default converters.
        restored = RobotProcessor.from_pretrained(tmp_dir)

        outcome = restored(standard_batch)
        assert "observation.image" in outcome  # Standard format preserved
+2 -2
View File
@@ -225,7 +225,7 @@ def test_save_and_load_pretrained():
pipeline.save_pretrained(tmp_dir) pipeline.save_pretrained(tmp_dir)
# Check files were created # Check files were created
config_path = Path(tmp_dir) / "processor.json" config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
assert config_path.exists() assert config_path.exists()
# No state files should be created for RenameProcessor # No state files should be created for RenameProcessor
@@ -283,7 +283,7 @@ def test_registry_based_save_load():
# Verify config uses registry name # Verify config uses registry name
import json import json
with open(Path(tmp_dir) / "processor.json") as f: with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
config = json.load(f) config = json.load(f)
assert "registry_name" in config["steps"][0] assert "registry_name" in config["steps"][0]