mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect the new state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
This commit is contained in:
@@ -416,12 +416,13 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
"""
|
"""
|
||||||
os.makedirs(destination_path, exist_ok=True)
|
os.makedirs(destination_path, exist_ok=True)
|
||||||
|
|
||||||
# Determine config filename - sanitize the processor name for filesystem
|
# Sanitize processor name for use in filenames
|
||||||
if config_filename is None:
|
import re
|
||||||
# Sanitize name - replace any character that's not alphanumeric or underscore
|
|
||||||
import re
|
|
||||||
|
|
||||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
|
|
||||||
|
# Use sanitized name for config if not provided
|
||||||
|
if config_filename is None:
|
||||||
config_filename = f"{sanitized_name}.json"
|
config_filename = f"{sanitized_name}.json"
|
||||||
|
|
||||||
config: dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
@@ -463,12 +464,12 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
for key, tensor in state.items():
|
for key, tensor in state.items():
|
||||||
cloned_state[key] = tensor.clone()
|
cloned_state[key] = tensor.clone()
|
||||||
|
|
||||||
# Always include step index to ensure unique filenames
|
# Include pipeline name and step index to ensure unique filenames
|
||||||
# This prevents conflicts when the same processor type is used multiple times
|
# This prevents conflicts when multiple processors are saved in the same directory
|
||||||
if registry_name:
|
if registry_name:
|
||||||
state_filename = f"step_{step_index}_{registry_name}.safetensors"
|
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||||
else:
|
else:
|
||||||
state_filename = f"step_{step_index}.safetensors"
|
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||||
|
|
||||||
save_file(cloned_state, os.path.join(destination_path, state_filename))
|
save_file(cloned_state, os.path.join(destination_path, state_filename))
|
||||||
step_entry["state_file"] = state_filename
|
step_entry["state_file"] = state_filename
|
||||||
|
|||||||
@@ -630,7 +630,7 @@ def test_mixed_json_and_tensor_state():
|
|||||||
|
|
||||||
# Check that both config and state files were created
|
# Check that both config and state files were created
|
||||||
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
|
config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor"
|
||||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
|
||||||
assert config_path.exists()
|
assert config_path.exists()
|
||||||
assert state_path.exists()
|
assert state_path.exists()
|
||||||
|
|
||||||
@@ -1735,7 +1735,7 @@ def test_error_multiple_configs_no_filename():
|
|||||||
|
|
||||||
|
|
||||||
def test_state_file_naming_with_indices():
|
def test_state_file_naming_with_indices():
|
||||||
"""Test that state files include step indices to avoid conflicts."""
|
"""Test that state files include pipeline name and step indices to avoid conflicts."""
|
||||||
# Create multiple steps of same type with state
|
# Create multiple steps of same type with state
|
||||||
step1 = MockStepWithTensorState(name="norm1", window_size=5)
|
step1 = MockStepWithTensorState(name="norm1", window_size=5)
|
||||||
step2 = MockStepWithTensorState(name="norm2", window_size=10)
|
step2 = MockStepWithTensorState(name="norm2", window_size=10)
|
||||||
@@ -1755,14 +1755,18 @@ def test_state_file_naming_with_indices():
|
|||||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||||
assert len(state_files) == 3
|
assert len(state_files) == 3
|
||||||
|
|
||||||
# Files should be named with indices
|
# Files should be named with pipeline name prefix and indices
|
||||||
expected_names = ["step_0.safetensors", "step_1.safetensors", "step_2.safetensors"]
|
expected_names = [
|
||||||
|
"robotprocessor_step_0.safetensors",
|
||||||
|
"robotprocessor_step_1.safetensors",
|
||||||
|
"robotprocessor_step_2.safetensors",
|
||||||
|
]
|
||||||
actual_names = [f.name for f in state_files]
|
actual_names = [f.name for f in state_files]
|
||||||
assert actual_names == expected_names
|
assert actual_names == expected_names
|
||||||
|
|
||||||
|
|
||||||
def test_state_file_naming_with_registry():
|
def test_state_file_naming_with_registry():
|
||||||
"""Test state file naming for registered steps includes both index and name."""
|
"""Test state file naming for registered steps includes pipeline name, index and registry name."""
|
||||||
|
|
||||||
# Register a test step
|
# Register a test step
|
||||||
@ProcessorStepRegistry.register("test_stateful_step")
|
@ProcessorStepRegistry.register("test_stateful_step")
|
||||||
@@ -1799,10 +1803,10 @@ def test_state_file_naming_with_registry():
|
|||||||
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
state_files = sorted(Path(tmp_dir).glob("*.safetensors"))
|
||||||
assert len(state_files) == 2
|
assert len(state_files) == 2
|
||||||
|
|
||||||
# Should include both index and registry name
|
# Should include pipeline name, index and registry name
|
||||||
expected_names = [
|
expected_names = [
|
||||||
"step_0_test_stateful_step.safetensors",
|
"robotprocessor_step_0_test_stateful_step.safetensors",
|
||||||
"step_1_test_stateful_step.safetensors",
|
"robotprocessor_step_1_test_stateful_step.safetensors",
|
||||||
]
|
]
|
||||||
actual_names = [f.name for f in state_files]
|
actual_names = [f.name for f in state_files]
|
||||||
assert actual_names == expected_names
|
assert actual_names == expected_names
|
||||||
@@ -1995,6 +1999,42 @@ def test_config_filename_special_characters():
|
|||||||
assert json_files[0].name == expected_name
|
assert json_files[0].name == expected_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_file_naming_with_multiple_processors():
|
||||||
|
"""Test that state files are properly prefixed with pipeline names to avoid conflicts."""
|
||||||
|
# Create two processors with state
|
||||||
|
step1 = MockStepWithTensorState(name="norm", window_size=5)
|
||||||
|
preprocessor = RobotProcessor([step1], name="PreProcessor")
|
||||||
|
|
||||||
|
step2 = MockStepWithTensorState(name="norm", window_size=10)
|
||||||
|
postprocessor = RobotProcessor([step2], name="PostProcessor")
|
||||||
|
|
||||||
|
# Process some data to create state
|
||||||
|
for i in range(3):
|
||||||
|
transition = create_transition(reward=float(i))
|
||||||
|
preprocessor(transition)
|
||||||
|
postprocessor(transition)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Save both processors to the same directory
|
||||||
|
preprocessor.save_pretrained(tmp_dir)
|
||||||
|
postprocessor.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check that all files exist and are distinct
|
||||||
|
assert (Path(tmp_dir) / "preprocessor.json").exists()
|
||||||
|
assert (Path(tmp_dir) / "postprocessor.json").exists()
|
||||||
|
assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists()
|
||||||
|
assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists()
|
||||||
|
|
||||||
|
# Load both back and verify they work correctly
|
||||||
|
loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
|
||||||
|
loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")
|
||||||
|
|
||||||
|
assert loaded_pre.name == "PreProcessor"
|
||||||
|
assert loaded_post.name == "PostProcessor"
|
||||||
|
assert loaded_pre.steps[0].window_size == 5
|
||||||
|
assert loaded_post.steps[0].window_size == 10
|
||||||
|
|
||||||
|
|
||||||
def test_override_with_device_strings():
|
def test_override_with_device_strings():
|
||||||
"""Test overriding device parameters with string values."""
|
"""Test overriding device parameters with string values."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user