feat(pipeline): Enhance configuration filename handling and state file naming

- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
This commit is contained in:
Adil Zouitine
2025-07-22 14:35:34 +02:00
parent 409ac0baca
commit 4ba23ea029
3 changed files with 535 additions and 16 deletions
+453 -2
View File
@@ -563,7 +563,7 @@ def test_save_and_load_pretrained():
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "processor.json"
config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline"
assert config_path.exists()
# Check config content
@@ -629,7 +629,7 @@ def test_mixed_json_and_tensor_state():
pipeline.save_pretrained(tmp_dir)
# Check that both config and state files were created
config_path = Path(tmp_dir) / "processor.json"
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
state_path = Path(tmp_dir) / "step_0.safetensors"
assert config_path.exists()
assert state_path.exists()
@@ -1655,3 +1655,454 @@ def test_repr_edge_case_long_names():
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
assert repr_str == expected
# Tests for config filename features and multiple processors
def test_save_with_custom_config_filename():
    """Round-trip a processor through a caller-chosen config filename.

    Saves with ``config_filename="my_custom_config.json"``, verifies that the
    file exists and records the processor name, then loads the processor back
    by passing the same filename.
    """
    processor = RobotProcessor([MockStep("test")], name="TestProcessor")

    with tempfile.TemporaryDirectory() as save_dir:
        processor.save_pretrained(save_dir, config_filename="my_custom_config.json")

        # The config must land under the requested name.
        custom_config = Path(save_dir) / "my_custom_config.json"
        assert custom_config.exists()

        # The serialized config carries the processor name.
        with open(custom_config) as handle:
            saved = json.load(handle)
        assert saved["name"] == "TestProcessor"

        # Loading requires pointing at the same custom filename.
        restored = RobotProcessor.from_pretrained(save_dir, config_filename="my_custom_config.json")
        assert restored.name == "TestProcessor"
def test_multiple_processors_same_directory():
    """Two differently named processors can be saved side by side.

    Each processor is saved into the same directory, is expected to produce a
    config file derived from its own name, and is loaded back individually by
    passing that filename.
    """
    pre = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor")
    post = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor")

    with tempfile.TemporaryDirectory() as tmp_dir:
        for processor in (pre, post):
            processor.save_pretrained(tmp_dir)

        # One config file per processor, named after the processor.
        root = Path(tmp_dir)
        for config_name in ("preprocessor.json", "postprocessor.json"):
            assert (root / config_name).exists()

        restored_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
        restored_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")

        assert restored_pre.name == "preprocessor"
        assert restored_post.name == "postprocessor"
        # Step counts match what was saved (2 and 1 respectively).
        assert len(restored_pre) == 2
        assert len(restored_post) == 1
def test_auto_detect_single_config():
    """Loading without ``config_filename`` works when one config was saved."""
    processor = RobotProcessor([MockStepWithTensorState()], name="SingleConfig")

    with tempfile.TemporaryDirectory() as save_dir:
        processor.save_pretrained(save_dir)

        # No config_filename given: from_pretrained has to find the file itself.
        restored = RobotProcessor.from_pretrained(save_dir)

        assert restored.name == "SingleConfig"
def test_error_multiple_configs_no_filename():
    """Loading without a filename must fail when two configs share a directory."""
    first = RobotProcessor([MockStep()], name="processor1")
    second = RobotProcessor([MockStep()], name="processor2")

    with tempfile.TemporaryDirectory() as shared_dir:
        for processor in (first, second):
            processor.save_pretrained(shared_dir)

        # Two configs and no explicit choice: a ValueError is expected.
        ambiguous = pytest.raises(ValueError, match="Multiple .json files found")
        with ambiguous:
            RobotProcessor.from_pretrained(shared_dir)
def test_state_file_naming_with_indices():
    """State files are expected to be named ``step_<index>.safetensors``.

    Builds a pipeline of three stateful steps (two of the same class), runs a
    few transitions through it, saves it, and checks the resulting
    ``*.safetensors`` filenames.
    """
    stateful_steps = [
        MockStepWithTensorState(name="norm1", window_size=5),
        MockStepWithTensorState(name="norm2", window_size=10),
        MockModuleStep(input_dim=5),
    ]
    pipeline = RobotProcessor(stateful_steps)

    # Feed some data through so the steps accumulate state.
    for reward_idx in range(5):
        pipeline(create_transition(observation=torch.randn(2, 5), reward=float(reward_idx)))

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        saved_states = sorted(Path(tmp_dir).glob("*.safetensors"))
        assert len(saved_states) == 3

        # One file per step, distinguished by the step index.
        actual_names = [state_file.name for state_file in saved_states]
        assert actual_names == ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
def test_state_file_naming_with_registry():
    """State files of registered steps carry the index and the registry name.

    A step class is registered as ``"test_stateful_step"`` for the duration of
    the test; two instances are saved and the ``*.safetensors`` filenames are
    compared against ``step_<index>_<registry name>.safetensors``.
    """

    @ProcessorStepRegistry.register("test_stateful_step")
    @dataclass
    class TestStatefulStep:
        value: int = 0

        def __init__(self, value: int = 0):
            self.value = value
            self.state_tensor = torch.randn(3, 3)

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            # Pass-through: only the saved state matters for this test.
            return transition

        def get_config(self):
            return {"value": self.value}

        def state_dict(self):
            return {"state_tensor": self.state_tensor}

        def load_state_dict(self, state):
            self.state_tensor = state["state_tensor"]

    try:
        pipeline = RobotProcessor([TestStatefulStep(1), TestStatefulStep(2)])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            saved_states = sorted(Path(tmp_dir).glob("*.safetensors"))
            assert len(saved_states) == 2

            # Filenames combine the step index with the registry name.
            actual_names = [state_file.name for state_file in saved_states]
            assert actual_names == [
                "step_0_test_stateful_step.safetensors",
                "step_1_test_stateful_step.safetensors",
            ]
    finally:
        # Always drop the temporary registration, even if an assertion failed.
        ProcessorStepRegistry.unregister("test_stateful_step")
# More comprehensive override tests
def test_override_with_nested_config():
    """A nested dictionary passed as an override reaches the reloaded step.

    The registered step copies ``nested_config["level1"]["level2"]`` into the
    transition's complementary data, so the value seen after loading shows
    whether the override was applied.
    """

    @ProcessorStepRegistry.register("complex_config_step")
    @dataclass
    class ComplexConfigStep:
        name: str = "complex"
        simple_param: int = 42
        nested_config: dict = None

        def __post_init__(self):
            # Fill in the default nested structure when none was supplied.
            if self.nested_config is None:
                self.nested_config = {"level1": {"level2": "default"}}

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            # Work on copies so the incoming transition is left untouched.
            extras = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA, {}))
            extras["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing")

            updated = transition.copy()
            updated[TransitionKey.COMPLEMENTARY_DATA] = extras
            return updated

        def get_config(self):
            return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}

    try:
        pipeline = RobotProcessor([ComplexConfigStep()])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            nested_override = {"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}
            loaded = RobotProcessor.from_pretrained(
                tmp_dir,
                overrides=nested_override,
            )

            # The step reports the value it was constructed with.
            output = loaded(create_transition())
            assert output[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden"
    finally:
        # Remove the temporary registration regardless of the outcome.
        ProcessorStepRegistry.unregister("complex_config_step")
def test_override_preserves_defaults():
    """Overriding one parameter leaves the other saved parameters intact."""
    pipeline = RobotProcessor([MockStepWithNonSerializableParam(name="test", multiplier=2.0)])

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        # Only ``multiplier`` is overridden; ``name`` is not mentioned.
        multiplier_only = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}}
        loaded = RobotProcessor.from_pretrained(
            save_dir,
            overrides=multiplier_only,
        )

        restored_step = loaded.steps[0]
        assert restored_step.name == "test"  # value from the saved config
        assert restored_step.multiplier == 5.0  # value from the override
def test_override_type_validation():
    """An override of the wrong type is expected to surface as a ValueError."""
    pipeline = RobotProcessor([MockStepWithTensorState(learning_rate=0.01)])

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        # ``window_size`` is meant to be an int; pass a string instead.
        bad_overrides = {"MockStepWithTensorState": {"window_size": "not_an_int"}}

        instantiation_failure = pytest.raises(ValueError, match="Failed to instantiate")
        with instantiation_failure:
            RobotProcessor.from_pretrained(save_dir, overrides=bad_overrides)
def test_override_with_callables():
    """A callable can be injected into a step through ``overrides``.

    The registered step's config contains only ``name``; ``transform_fn`` is
    supplied at load time and must then be applied to each observation value.
    """

    @ProcessorStepRegistry.register("callable_step")
    @dataclass
    class CallableStep:
        name: str = "callable_step"
        transform_fn: Any = None

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            observation = transition.get(TransitionKey.OBSERVATION)
            # Nothing to do without an observation or without a transform.
            if observation is None or self.transform_fn is None:
                return transition

            transformed = {key: self.transform_fn(value) for key, value in observation.items()}
            updated = transition.copy()
            updated[TransitionKey.OBSERVATION] = transformed
            return updated

        def get_config(self):
            # transform_fn is deliberately left out of the saved config.
            return {"name": self.name}

    try:
        pipeline = RobotProcessor([CallableStep()])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            def double_values(x):
                # Numbers and tensors are doubled; anything else passes through.
                if isinstance(x, (int, float, torch.Tensor)):
                    return x * 2
                return x

            loaded = RobotProcessor.from_pretrained(
                tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
            )

            # 5.0 in, 10.0 out proves the injected callable ran.
            output = loaded(create_transition(observation={"value": torch.tensor(5.0)}))
            assert output[TransitionKey.OBSERVATION]["value"].item() == 10.0
    finally:
        # Remove the temporary registration regardless of the outcome.
        ProcessorStepRegistry.unregister("callable_step")
def test_override_multiple_same_class_warning():
    """A class-keyed override applies to every step of that class.

    Two steps of the same class are saved with different names and
    multipliers; after loading with a ``multiplier`` override both must carry
    the overridden value while keeping their own saved names.
    """
    same_class_steps = [
        MockStepWithNonSerializableParam(name="step1", multiplier=1.0),
        MockStepWithNonSerializableParam(name="step2", multiplier=2.0),
    ]
    pipeline = RobotProcessor(same_class_steps)

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        loaded = RobotProcessor.from_pretrained(
            save_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}
        )
        first, second = loaded.steps[0], loaded.steps[1]

        # The override lands on both instances ...
        assert first.multiplier == 10.0
        assert second.multiplier == 10.0

        # ... while the per-step names come from the saved config.
        assert first.name == "step1"
        assert second.name == "step2"
def test_config_filename_special_characters():
    """A processor name with path-hostile characters yields a sanitized filename.

    The name contains ``/``, ``\\``, ``:`` and ``*``; the single JSON file
    written is expected to be lower-cased with those characters replaced by
    underscores.
    """
    pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars")

    with tempfile.TemporaryDirectory() as save_dir:
        pipeline.save_pretrained(save_dir)

        written_configs = list(Path(save_dir).glob("*.json"))
        assert len(written_configs) == 1

        # Special characters become underscores and the name is lower-cased.
        assert written_configs[0].name == "my_processor_with_special_chars.json"
def test_override_with_device_strings():
    """Check that a string ``device`` override reaches the reloaded step.

    A registered step holding a tensor buffer is saved from CPU and reloaded
    with ``device`` overridden to ``"cuda:0"``. That override can only be
    exercised on a host with CUDA, so when CUDA is unavailable the test is
    reported as skipped rather than passing without a single assertion. The
    save step still runs first, and the temporary registry entry is always
    removed.
    """

    @ProcessorStepRegistry.register("device_aware_step")
    @dataclass
    class DeviceAwareStep:
        device: str = "cpu"

        def __init__(self, device: str = "cpu"):
            self.device = device
            # Allocating on ``device`` makes an invalid device fail at construction.
            self.buffer = torch.zeros(10, device=device)

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            return transition

        def get_config(self):
            return {"device": str(self.device)}

        def state_dict(self):
            return {"buffer": self.buffer}

        def load_state_dict(self, state):
            self.buffer = state["buffer"]

    try:
        step = DeviceAwareStep(device="cpu")
        pipeline = RobotProcessor([step])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            # Make the missing coverage visible instead of passing vacuously.
            if not torch.cuda.is_available():
                pytest.skip("CUDA is required to exercise the 'cuda:0' device override")

            loaded = RobotProcessor.from_pretrained(
                tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}}
            )

            loaded_step = loaded.steps[0]
            assert loaded_step.device == "cuda:0"
            # Note: buffer will still be on CPU from saved state
            # until .to() is called on the processor
    finally:
        # Runs on success, failure and skip alike.
        ProcessorStepRegistry.unregister("device_aware_step")
def test_from_pretrained_nonexistent_path():
    """``from_pretrained`` raises for sources that cannot be resolved.

    Three cases are covered: a string that is not a valid Hub repo ID, a
    well-formed repo ID that does not exist, and an existing local directory
    without any config file.
    """
    from huggingface_hub.errors import HfHubHTTPError, HFValidationError

    # Too many slashes for a repo ID, so Hub validation rejects it.
    invalid_repo_id = pytest.raises(HFValidationError)
    with invalid_repo_id:
        RobotProcessor.from_pretrained("/path/that/does/not/exist")

    # Valid repo ID format, but nothing behind it.
    missing_repo = pytest.raises((FileNotFoundError, HfHubHTTPError))
    with missing_repo:
        RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo")

    # An empty local directory has no config to load.
    with tempfile.TemporaryDirectory() as empty_dir:
        no_config = pytest.raises(FileNotFoundError, match="No .json configuration files found")
        with no_config:
            RobotProcessor.from_pretrained(empty_dir)
def test_save_load_with_custom_converter_functions():
    """Custom ``to_transition`` / ``to_output`` converters are not persisted.

    A processor built with custom converters is saved and reloaded; the
    reloaded processor is then called with a standard-format batch and is
    expected to return that standard format.
    """

    def custom_to_transition(batch):
        # Maps the short custom keys onto transition keys.
        return {
            TransitionKey.OBSERVATION: batch.get("obs"),
            TransitionKey.ACTION: batch.get("act"),
            TransitionKey.REWARD: batch.get("rew", 0.0),
            TransitionKey.DONE: batch.get("done", False),
            TransitionKey.TRUNCATED: batch.get("truncated", False),
            TransitionKey.INFO: {},
            TransitionKey.COMPLEMENTARY_DATA: {},
        }

    def custom_to_output(transition):
        # Maps a transition back onto the short custom keys.
        return {
            "obs": transition.get(TransitionKey.OBSERVATION),
            "act": transition.get(TransitionKey.ACTION),
            "rew": transition.get(TransitionKey.REWARD),
            "done": transition.get(TransitionKey.DONE),
            "truncated": transition.get(TransitionKey.TRUNCATED),
        }

    processor = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output)

    with tempfile.TemporaryDirectory() as save_dir:
        processor.save_pretrained(save_dir)

        # No converters are passed here, so the defaults apply.
        restored = RobotProcessor.from_pretrained(save_dir)

        standard_batch = {
            "observation.image": torch.randn(1, 3, 32, 32),
            "action": torch.randn(1, 7),
            "next.reward": torch.tensor([1.0]),
            "next.done": torch.tensor([False]),
            "next.truncated": torch.tensor([False]),
            "info": {},
        }

        # The custom converter would look for "obs"/"act" keys instead.
        output = restored(standard_batch)
        assert "observation.image" in output  # standard format preserved
+2 -2
View File
@@ -225,7 +225,7 @@ def test_save_and_load_pretrained():
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "processor.json"
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
assert config_path.exists()
# No state files should be created for RenameProcessor
@@ -283,7 +283,7 @@ def test_registry_based_save_load():
# Verify config uses registry name
import json
with open(Path(tmp_dir) / "processor.json") as f:
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
config = json.load(f)
assert "registry_name" in config["steps"][0]