refactor(pipeline): Improve state file naming conventions for clarity and uniqueness

- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
This commit is contained in:
Adil Zouitine
2025-07-23 09:41:03 +02:00
parent 4ba23ea029
commit 907023f9f7
2 changed files with 58 additions and 17 deletions
+10 -9
View File
@@ -416,12 +416,13 @@ class RobotProcessor(ModelHubMixin):
""" """
os.makedirs(destination_path, exist_ok=True) os.makedirs(destination_path, exist_ok=True)
# Determine config filename - sanitize the processor name for filesystem # Sanitize processor name for use in filenames
if config_filename is None: import re
# Sanitize name - replace any character that's not alphanumeric or underscore
import re
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
# Use sanitized name for config if not provided
if config_filename is None:
config_filename = f"{sanitized_name}.json" config_filename = f"{sanitized_name}.json"
config: dict[str, Any] = { config: dict[str, Any] = {
@@ -463,12 +464,12 @@ class RobotProcessor(ModelHubMixin):
for key, tensor in state.items(): for key, tensor in state.items():
cloned_state[key] = tensor.clone() cloned_state[key] = tensor.clone()
# Always include step index to ensure unique filenames # Include pipeline name and step index to ensure unique filenames
# This prevents conflicts when the same processor type is used multiple times # This prevents conflicts when multiple processors are saved in the same directory
if registry_name: if registry_name:
state_filename = f"step_{step_index}_{registry_name}.safetensors" state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else: else:
state_filename = f"step_{step_index}.safetensors" state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
save_file(cloned_state, os.path.join(destination_path, state_filename)) save_file(cloned_state, os.path.join(destination_path, state_filename))
step_entry["state_file"] = state_filename step_entry["state_file"] = state_filename
+48 -8
View File
@@ -630,7 +630,7 @@ def test_mixed_json_and_tensor_state():
# Check that both config and state files were created # Check that both config and state files were created
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
state_path = Path(tmp_dir) / "step_0.safetensors" state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
assert config_path.exists() assert config_path.exists()
assert state_path.exists() assert state_path.exists()
@@ -1735,7 +1735,7 @@ def test_error_multiple_configs_no_filename():
def test_state_file_naming_with_indices(): def test_state_file_naming_with_indices():
"""Test that state files include step indices to avoid conflicts.""" """Test that state files include pipeline name and step indices to avoid conflicts."""
# Create multiple steps of same type with state # Create multiple steps of same type with state
step1 = MockStepWithTensorState(name="norm1", window_size=5) step1 = MockStepWithTensorState(name="norm1", window_size=5)
step2 = MockStepWithTensorState(name="norm2", window_size=10) step2 = MockStepWithTensorState(name="norm2", window_size=10)
@@ -1755,14 +1755,18 @@ def test_state_file_naming_with_indices():
state_files = sorted(Path(tmp_dir).glob("*.safetensors")) state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 3 assert len(state_files) == 3
# Files should be named with indices # Files should be named with pipeline name prefix and indices
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"] expected_names = [
"robotprocessor_step_0.safetensors",
"robotprocessor_step_1.safetensors",
"robotprocessor_step_2.safetensors",
]
actual_names = [f.name for f in state_files] actual_names = [f.name for f in state_files]
assert actual_names == expected_names assert actual_names == expected_names
def test_state_file_naming_with_registry(): def test_state_file_naming_with_registry():
"""Test state file naming for registered steps includes both index and name.""" """Test state file naming for registered steps includes pipeline name, index and registry name."""
# Register a test step # Register a test step
@ProcessorStepRegistry.register("test_stateful_step") @ProcessorStepRegistry.register("test_stateful_step")
@@ -1799,10 +1803,10 @@ def test_state_file_naming_with_registry():
state_files = sorted(Path(tmp_dir).glob("*.safetensors")) state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 2 assert len(state_files) == 2
# Should include both index and registry name # Should include pipeline name, index and registry name
expected_names = [ expected_names = [
"step_0_test_stateful_step.safetensors", "robotprocessor_step_0_test_stateful_step.safetensors",
"step_1_test_stateful_step.safetensors", "robotprocessor_step_1_test_stateful_step.safetensors",
] ]
actual_names = [f.name for f in state_files] actual_names = [f.name for f in state_files]
assert actual_names == expected_names assert actual_names == expected_names
@@ -1995,6 +1999,42 @@ def test_config_filename_special_characters():
assert json_files[0].name == expected_name assert json_files[0].name == expected_name
def test_state_file_naming_with_multiple_processors():
"""Test that state files are properly prefixed with pipeline names to avoid conflicts."""
# Create two processors with state
step1 = MockStepWithTensorState(name="norm", window_size=5)
preprocessor = RobotProcessor([step1], name="PreProcessor")
step2 = MockStepWithTensorState(name="norm", window_size=10)
postprocessor = RobotProcessor([step2], name="PostProcessor")
# Process some data to create state
for i in range(3):
transition = create_transition(reward=float(i))
preprocessor(transition)
postprocessor(transition)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save both processors to the same directory
preprocessor.save_pretrained(tmp_dir)
postprocessor.save_pretrained(tmp_dir)
# Check that all files exist and are distinct
assert (Path(tmp_dir) / "preprocessor.json").exists()
assert (Path(tmp_dir) / "postprocessor.json").exists()
assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists()
assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists()
# Load both back and verify they work correctly
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
assert loaded_pre.name == "PreProcessor"
assert loaded_post.name == "PostProcessor"
assert loaded_pre.steps[0].window_size == 5
assert loaded_post.steps[0].window_size == 10
def test_override_with_device_strings(): def test_override_with_device_strings():
"""Test overriding device parameters with string values.""" """Test overriding device parameters with string values."""