mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
This commit is contained in:
@@ -630,7 +630,7 @@ def test_mixed_json_and_tensor_state():
|
||||
|
||||
# Check that both config and state files were created
|
||||
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
|
||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
||||
state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
|
||||
assert config_path.exists()
|
||||
assert state_path.exists()
|
||||
|
||||
@@ -1735,7 +1735,7 @@ def test_error_multiple_configs_no_filename():
|
||||
|
||||
|
||||
def test_state_file_naming_with_indices():
|
||||
"""Test that state files include step indices to avoid conflicts."""
|
||||
"""Test that state files include pipeline name and step indices to avoid conflicts."""
|
||||
# Create multiple steps of same type with state
|
||||
step1 = MockStepWithTensorState(name="norm1", window_size=5)
|
||||
step2 = MockStepWithTensorState(name="norm2", window_size=10)
|
||||
@@ -1755,14 +1755,18 @@ def test_state_file_naming_with_indices():
|
||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 3
|
||||
|
||||
# Files should be named with indices
|
||||
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
|
||||
# Files should be named with pipeline name prefix and indices
|
||||
expected_names = [
|
||||
"robotprocessor_step_0.safetensors",
|
||||
"robotprocessor_step_1.safetensors",
|
||||
"robotprocessor_step_2.safetensors",
|
||||
]
|
||||
actual_names = [f.name for f in state_files]
|
||||
assert actual_names == expected_names
|
||||
|
||||
|
||||
def test_state_file_naming_with_registry():
|
||||
"""Test state file naming for registered steps includes both index and name."""
|
||||
"""Test state file naming for registered steps includes pipeline name, index and registry name."""
|
||||
|
||||
# Register a test step
|
||||
@ProcessorStepRegistry.register("test_stateful_step")
|
||||
@@ -1799,10 +1803,10 @@ def test_state_file_naming_with_registry():
|
||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 2
|
||||
|
||||
# Should include both index and registry name
|
||||
# Should include pipeline name, index and registry name
|
||||
expected_names = [
|
||||
"step_0_test_stateful_step.safetensors",
|
||||
"step_1_test_stateful_step.safetensors",
|
||||
"robotprocessor_step_0_test_stateful_step.safetensors",
|
||||
"robotprocessor_step_1_test_stateful_step.safetensors",
|
||||
]
|
||||
actual_names = [f.name for f in state_files]
|
||||
assert actual_names == expected_names
|
||||
@@ -1995,6 +1999,42 @@ def test_config_filename_special_characters():
|
||||
assert json_files[0].name == expected_name
|
||||
|
||||
|
||||
def test_state_file_naming_with_multiple_processors():
    """Verify that two processors saved side by side write distinct, name-prefixed files.

    A "PreProcessor" and a "PostProcessor", each holding one stateful step,
    are run on a few transitions, saved into the same temporary directory,
    and reloaded through their respective config filenames.
    """
    # One stateful step per processor; the steps share a name and differ only in window size.
    preprocessor = RobotProcessor([MockStepWithTensorState(name="norm", window_size=5)], name="PreProcessor")
    postprocessor = RobotProcessor([MockStepWithTensorState(name="norm", window_size=10)], name="PostProcessor")

    # Feed the same transitions through both processors so each step accumulates state.
    for reward in map(float, range(3)):
        transition = create_transition(reward=reward)
        preprocessor(transition)
        postprocessor(transition)

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_dir = Path(tmp_dir)

        # Both processors deliberately target the same directory.
        preprocessor.save_pretrained(tmp_dir)
        postprocessor.save_pretrained(tmp_dir)

        # Every config and state file must be present under its own prefixed name.
        for expected_file in (
            "preprocessor.json",
            "postprocessor.json",
            "preprocessor_step_0.safetensors",
            "postprocessor_step_0.safetensors",
        ):
            assert (save_dir / expected_file).exists()

        # Reload each processor by its config file and confirm nothing got mixed up.
        loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
        loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")

        assert loaded_pre.name == "PreProcessor"
        assert loaded_post.name == "PostProcessor"
        assert loaded_pre.steps[0].window_size == 5
        assert loaded_post.steps[0].window_size == 10
|
||||
|
||||
|
||||
def test_override_with_device_strings():
|
||||
"""Test overriding device parameters with string values."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user