mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
feat(pipeline): Enhance configuration filename handling and state file naming
- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection.
This commit is contained in:
@@ -563,7 +563,7 @@ def test_save_and_load_pretrained():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "processor.json"
|
||||
config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline"
|
||||
assert config_path.exists()
|
||||
|
||||
# Check config content
|
||||
@@ -629,7 +629,7 @@ def test_mixed_json_and_tensor_state():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check that both config and state files were created
|
||||
config_path = Path(tmp_dir) / "processor.json"
|
||||
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
|
||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
||||
assert config_path.exists()
|
||||
assert state_path.exists()
|
||||
@@ -1655,3 +1655,454 @@ def test_repr_edge_case_long_names():
|
||||
|
||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
# Tests for config filename features and multiple processors
|
||||
def test_save_with_custom_config_filename():
|
||||
"""Test saving processor with custom config filename."""
|
||||
step = MockStep("test")
|
||||
pipeline = RobotProcessor([step], name="TestProcessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save with custom filename
|
||||
pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json")
|
||||
|
||||
# Check file exists
|
||||
config_path = Path(tmp_dir) / "my_custom_config.json"
|
||||
assert config_path.exists()
|
||||
|
||||
# Check content
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
assert config["name"] == "TestProcessor"
|
||||
|
||||
# Load with specific filename
|
||||
loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json")
|
||||
assert loaded.name == "TestProcessor"
|
||||
|
||||
|
||||
def test_multiple_processors_same_directory():
|
||||
"""Test saving multiple processors to the same directory with different config files."""
|
||||
# Create different processors
|
||||
preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor")
|
||||
|
||||
postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save both to same directory
|
||||
preprocessor.save_pretrained(tmp_dir)
|
||||
postprocessor.save_pretrained(tmp_dir)
|
||||
|
||||
# Check both config files exist
|
||||
assert (Path(tmp_dir) / "preprocessor.json").exists()
|
||||
assert (Path(tmp_dir) / "postprocessor.json").exists()
|
||||
|
||||
# Load them back
|
||||
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
|
||||
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
|
||||
|
||||
assert loaded_pre.name == "preprocessor"
|
||||
assert loaded_post.name == "postprocessor"
|
||||
assert len(loaded_pre) == 2
|
||||
assert len(loaded_post) == 1
|
||||
|
||||
|
||||
def test_auto_detect_single_config():
|
||||
"""Test automatic config detection when there's only one JSON file."""
|
||||
step = MockStepWithTensorState()
|
||||
pipeline = RobotProcessor([step], name="SingleConfig")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load without specifying config_filename
|
||||
loaded = RobotProcessor.from_pretrained(tmp_dir)
|
||||
assert loaded.name == "SingleConfig"
|
||||
|
||||
|
||||
def test_error_multiple_configs_no_filename():
|
||||
"""Test error when multiple configs exist and no filename specified."""
|
||||
proc1 = RobotProcessor([MockStep()], name="processor1")
|
||||
proc2 = RobotProcessor([MockStep()], name="processor2")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
proc1.save_pretrained(tmp_dir)
|
||||
proc2.save_pretrained(tmp_dir)
|
||||
|
||||
# Should raise error
|
||||
with pytest.raises(ValueError, match="Multiple .json files found"):
|
||||
RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
def test_state_file_naming_with_indices():
|
||||
"""Test that state files include step indices to avoid conflicts."""
|
||||
# Create multiple steps of same type with state
|
||||
step1 = MockStepWithTensorState(name="norm1", window_size=5)
|
||||
step2 = MockStepWithTensorState(name="norm2", window_size=10)
|
||||
step3 = MockModuleStep(input_dim=5)
|
||||
|
||||
pipeline = RobotProcessor([step1, step2, step3])
|
||||
|
||||
# Process some data to create state
|
||||
for i in range(5):
|
||||
transition = create_transition(observation=torch.randn(2, 5), reward=float(i))
|
||||
pipeline(transition)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check state files have indices
|
||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 3
|
||||
|
||||
# Files should be named with indices
|
||||
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
|
||||
actual_names = [f.name for f in state_files]
|
||||
assert actual_names == expected_names
|
||||
|
||||
|
||||
def test_state_file_naming_with_registry():
|
||||
"""Test state file naming for registered steps includes both index and name."""
|
||||
|
||||
# Register a test step
|
||||
@ProcessorStepRegistry.register("test_stateful_step")
|
||||
@dataclass
|
||||
class TestStatefulStep:
|
||||
value: int = 0
|
||||
|
||||
def __init__(self, value: int = 0):
|
||||
self.value = value
|
||||
self.state_tensor = torch.randn(3, 3)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def get_config(self):
|
||||
return {"value": self.value}
|
||||
|
||||
def state_dict(self):
|
||||
return {"state_tensor": self.state_tensor}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
self.state_tensor = state["state_tensor"]
|
||||
|
||||
try:
|
||||
# Create pipeline with registered steps
|
||||
step1 = TestStatefulStep(1)
|
||||
step2 = TestStatefulStep(2)
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check state files
|
||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 2
|
||||
|
||||
# Should include both index and registry name
|
||||
expected_names = [
|
||||
"step_0_test_stateful_step.safetensors",
|
||||
"step_1_test_stateful_step.safetensors",
|
||||
]
|
||||
actual_names = [f.name for f in state_files]
|
||||
assert actual_names == expected_names
|
||||
|
||||
finally:
|
||||
# Cleanup registry
|
||||
ProcessorStepRegistry.unregister("test_stateful_step")
|
||||
|
||||
|
||||
# More comprehensive override tests
|
||||
def test_override_with_nested_config():
|
||||
"""Test overrides with nested configuration dictionaries."""
|
||||
|
||||
@ProcessorStepRegistry.register("complex_config_step")
|
||||
@dataclass
|
||||
class ComplexConfigStep:
|
||||
name: str = "complex"
|
||||
simple_param: int = 42
|
||||
nested_config: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.nested_config is None:
|
||||
self.nested_config = {"level1": {"level2": "default"}}
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = dict(comp_data)
|
||||
comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing")
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self):
|
||||
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
||||
|
||||
try:
|
||||
step = ComplexConfigStep()
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load with nested override
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir,
|
||||
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
|
||||
)
|
||||
|
||||
# Test that override worked
|
||||
transition = create_transition()
|
||||
result = loaded(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden"
|
||||
finally:
|
||||
ProcessorStepRegistry.unregister("complex_config_step")
|
||||
|
||||
|
||||
def test_override_preserves_defaults():
|
||||
"""Test that overrides only affect specified parameters."""
|
||||
step = MockStepWithNonSerializableParam(name="test", multiplier=2.0)
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Override only one parameter
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir,
|
||||
overrides={
|
||||
"MockStepWithNonSerializableParam": {
|
||||
"multiplier": 5.0 # Only override multiplier
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Check that name was preserved from saved config
|
||||
loaded_step = loaded.steps[0]
|
||||
assert loaded_step.name == "test" # Original value
|
||||
assert loaded_step.multiplier == 5.0 # Overridden value
|
||||
|
||||
|
||||
def test_override_type_validation():
|
||||
"""Test that type errors in overrides are caught properly."""
|
||||
step = MockStepWithTensorState(learning_rate=0.01)
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Try to override with wrong type
|
||||
overrides = {
|
||||
"MockStepWithTensorState": {
|
||||
"window_size": "not_an_int" # Should be int
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to instantiate"):
|
||||
RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
|
||||
def test_override_with_callables():
|
||||
"""Test overriding with callable objects."""
|
||||
|
||||
@ProcessorStepRegistry.register("callable_step")
|
||||
@dataclass
|
||||
class CallableStep:
|
||||
name: str = "callable_step"
|
||||
transform_fn: Any = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs is not None and self.transform_fn is not None:
|
||||
processed_obs = {}
|
||||
for k, v in obs.items():
|
||||
processed_obs[k] = self.transform_fn(v)
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
return transition
|
||||
|
||||
def get_config(self):
|
||||
return {"name": self.name}
|
||||
|
||||
try:
|
||||
step = CallableStep()
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Define a transform function
|
||||
def double_values(x):
|
||||
if isinstance(x, (int, float)):
|
||||
return x * 2
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x * 2
|
||||
return x
|
||||
|
||||
# Load with callable override
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
|
||||
)
|
||||
|
||||
# Test it works
|
||||
transition = create_transition(observation={"value": torch.tensor(5.0)})
|
||||
result = loaded(transition)
|
||||
assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0
|
||||
finally:
|
||||
ProcessorStepRegistry.unregister("callable_step")
|
||||
|
||||
|
||||
def test_override_multiple_same_class_warning():
|
||||
"""Test behavior when multiple steps of same class exist."""
|
||||
step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
|
||||
step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Override affects all instances of the class
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}
|
||||
)
|
||||
|
||||
# Both steps get the same override
|
||||
assert loaded.steps[0].multiplier == 10.0
|
||||
assert loaded.steps[1].multiplier == 10.0
|
||||
|
||||
# But original names are preserved
|
||||
assert loaded.steps[0].name == "step1"
|
||||
assert loaded.steps[1].name == "step2"
|
||||
|
||||
|
||||
def test_config_filename_special_characters():
|
||||
"""Test config filenames with special characters are sanitized."""
|
||||
# Processor name with special characters
|
||||
pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check that filename was sanitized
|
||||
json_files = list(Path(tmp_dir).glob("*.json"))
|
||||
assert len(json_files) == 1
|
||||
|
||||
# Should have replaced special chars with underscores
|
||||
expected_name = "my_processor_with_special_chars.json"
|
||||
assert json_files[0].name == expected_name
|
||||
|
||||
|
||||
def test_override_with_device_strings():
|
||||
"""Test overriding device parameters with string values."""
|
||||
|
||||
@ProcessorStepRegistry.register("device_aware_step")
|
||||
@dataclass
|
||||
class DeviceAwareStep:
|
||||
device: str = "cpu"
|
||||
|
||||
def __init__(self, device: str = "cpu"):
|
||||
self.device = device
|
||||
self.buffer = torch.zeros(10, device=device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def get_config(self):
|
||||
return {"device": str(self.device)}
|
||||
|
||||
def state_dict(self):
|
||||
return {"buffer": self.buffer}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
self.buffer = state["buffer"]
|
||||
|
||||
try:
|
||||
step = DeviceAwareStep(device="cpu")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Override device
|
||||
if torch.cuda.is_available():
|
||||
loaded = RobotProcessor.from_pretrained(
|
||||
tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}}
|
||||
)
|
||||
|
||||
loaded_step = loaded.steps[0]
|
||||
assert loaded_step.device == "cuda:0"
|
||||
# Note: buffer will still be on CPU from saved state
|
||||
# until .to() is called on the processor
|
||||
|
||||
finally:
|
||||
ProcessorStepRegistry.unregister("device_aware_step")
|
||||
|
||||
|
||||
def test_from_pretrained_nonexistent_path():
|
||||
"""Test error handling when loading from non-existent sources."""
|
||||
from huggingface_hub.errors import HfHubHTTPError, HFValidationError
|
||||
|
||||
# Test with an invalid repo ID (too many slashes) - caught by HF validation
|
||||
with pytest.raises(HFValidationError):
|
||||
RobotProcessor.from_pretrained("/path/that/does/not/exist")
|
||||
|
||||
# Test with a non-existent but valid Hub repo format
|
||||
with pytest.raises((FileNotFoundError, HfHubHTTPError)):
|
||||
RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo")
|
||||
|
||||
# Test with a local directory that exists but has no config files
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
|
||||
RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
def test_save_load_with_custom_converter_functions():
|
||||
"""Test that custom to_transition and to_output functions are NOT saved."""
|
||||
|
||||
def custom_to_transition(batch):
|
||||
# Custom conversion logic
|
||||
return {
|
||||
TransitionKey.OBSERVATION: batch.get("obs"),
|
||||
TransitionKey.ACTION: batch.get("act"),
|
||||
TransitionKey.REWARD: batch.get("rew", 0.0),
|
||||
TransitionKey.DONE: batch.get("done", False),
|
||||
TransitionKey.TRUNCATED: batch.get("truncated", False),
|
||||
TransitionKey.INFO: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
def custom_to_output(transition):
|
||||
# Custom output format
|
||||
return {
|
||||
"obs": transition.get(TransitionKey.OBSERVATION),
|
||||
"act": transition.get(TransitionKey.ACTION),
|
||||
"rew": transition.get(TransitionKey.REWARD),
|
||||
"done": transition.get(TransitionKey.DONE),
|
||||
"truncated": transition.get(TransitionKey.TRUNCATED),
|
||||
}
|
||||
|
||||
# Create processor with custom converters
|
||||
pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load - should use default converters
|
||||
loaded = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
# Verify it uses default converters by checking with standard batch format
|
||||
batch = {
|
||||
"observation.image": torch.randn(1, 3, 32, 32),
|
||||
"action": torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
"next.truncated": torch.tensor([False]),
|
||||
"info": {},
|
||||
}
|
||||
|
||||
# Should work with standard format (wouldn't work with custom converter)
|
||||
result = loaded(batch)
|
||||
assert "observation.image" in result # Standard format preserved
|
||||
|
||||
@@ -225,7 +225,7 @@ def test_save_and_load_pretrained():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "processor.json"
|
||||
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
||||
assert config_path.exists()
|
||||
|
||||
# No state files should be created for RenameProcessor
|
||||
@@ -283,7 +283,7 @@ def test_registry_based_save_load():
|
||||
# Verify config uses registry name
|
||||
import json
|
||||
|
||||
with open(Path(tmp_dir) / "processor.json") as f:
|
||||
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
||||
config = json.load(f)
|
||||
|
||||
assert "registry_name" in config["steps"][0]
|
||||
|
||||
Reference in New Issue
Block a user