refactor(pipeline): Improve state file naming conventions for clarity and uniqueness

- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
This commit is contained in:
Adil Zouitine
2025-07-23 09:41:03 +02:00
parent 4ba23ea029
commit 907023f9f7
2 changed files with 58 additions and 17 deletions
+48 -8
View File
@@ -630,7 +630,7 @@ def test_mixed_json_and_tensor_state():
# Check that both config and state files were created
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
state_path = Path(tmp_dir) / "step_0.safetensors"
state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
assert config_path.exists()
assert state_path.exists()
@@ -1735,7 +1735,7 @@ def test_error_multiple_configs_no_filename():
def test_state_file_naming_with_indices():
"""Test that state files include step indices to avoid conflicts."""
"""Test that state files include pipeline name and step indices to avoid conflicts."""
# Create multiple steps of same type with state
step1 = MockStepWithTensorState(name="norm1", window_size=5)
step2 = MockStepWithTensorState(name="norm2", window_size=10)
@@ -1755,14 +1755,18 @@ def test_state_file_naming_with_indices():
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 3
# Files should be named with indices
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
# Files should be named with pipeline name prefix and indices
expected_names = [
"robotprocessor_step_0.safetensors",
"robotprocessor_step_1.safetensors",
"robotprocessor_step_2.safetensors",
]
actual_names = [f.name for f in state_files]
assert actual_names == expected_names
def test_state_file_naming_with_registry():
"""Test state file naming for registered steps includes both index and name."""
"""Test state file naming for registered steps includes pipeline name, index and registry name."""
# Register a test step
@ProcessorStepRegistry.register("test_stateful_step")
@@ -1799,10 +1803,10 @@ def test_state_file_naming_with_registry():
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 2
# Should include both index and registry name
# Should include pipeline name, index and registry name
expected_names = [
"step_0_test_stateful_step.safetensors",
"step_1_test_stateful_step.safetensors",
"robotprocessor_step_0_test_stateful_step.safetensors",
"robotprocessor_step_1_test_stateful_step.safetensors",
]
actual_names = [f.name for f in state_files]
assert actual_names == expected_names
@@ -1995,6 +1999,42 @@ def test_config_filename_special_characters():
assert json_files[0].name == expected_name
def test_state_file_naming_with_multiple_processors():
"""Test that state files are properly prefixed with pipeline names to avoid conflicts."""
# Create two processors with state
step1 = MockStepWithTensorState(name="norm", window_size=5)
preprocessor = RobotProcessor([step1], name="PreProcessor")
step2 = MockStepWithTensorState(name="norm", window_size=10)
postprocessor = RobotProcessor([step2], name="PostProcessor")
# Process some data to create state
for i in range(3):
transition = create_transition(reward=float(i))
preprocessor(transition)
postprocessor(transition)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save both processors to the same directory
preprocessor.save_pretrained(tmp_dir)
postprocessor.save_pretrained(tmp_dir)
# Check that all files exist and are distinct
assert (Path(tmp_dir) / "preprocessor.json").exists()
assert (Path(tmp_dir) / "postprocessor.json").exists()
assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists()
assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists()
# Load both back and verify they work correctly
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
assert loaded_pre.name == "PreProcessor"
assert loaded_post.name == "PostProcessor"
assert loaded_pre.steps[0].window_size == 5
assert loaded_post.steps[0].window_size == 10
def test_override_with_device_strings():
"""Test overriding device parameters with string values."""