refactor(pipeline): Rename parameters for clarity and enhance save/load functionality

- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method to ensure directory creation and file handling are consistent with the new parameter names.
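- Streamlined the loading process in from_pretrained: the parsed configuration is now held in loaded_config for better clarity and maintainability, and the Hub download options (force_download, resume_download, proxies, token, cache_dir, local_files_only, revision) are forwarded to every hf_hub_download call.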
This commit is contained in:
Adil Zouitine
2025-08-05 17:44:21 +02:00
parent 05bd18f453
commit 7beb040e8e
+71 -24
View File
@@ -403,21 +403,21 @@ class RobotProcessor(ModelHubMixin):
transition = processor_step(transition)
yield transition
def _save_pretrained(self, destination_path: str, **kwargs):
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal save method for ModelHubMixin compatibility."""
# Extract config_filename from kwargs if provided
config_filename = kwargs.pop("config_filename", None)
self.save_pretrained(destination_path, config_filename=config_filename)
self.save_pretrained(save_directory, config_filename=config_filename)
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
"""Serialize the processor definition and parameters to *destination_path*.
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
"""Serialize the processor definition and parameters to *save_directory*.
Args:
destination_path: Directory where the processor will be saved.
save_directory: Directory where the processor will be saved.
config_filename: Optional custom config filename. If not provided, defaults to
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
"""
os.makedirs(destination_path, exist_ok=True)
os.makedirs(str(save_directory), exist_ok=True)
# Sanitize processor name for use in filenames
import re
@@ -439,16 +439,15 @@ class RobotProcessor(ModelHubMixin):
# Check if step was registered
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
if registry_name:
# Use registry name for registered steps
step_entry: dict[str, Any] = {
"registry_name": registry_name,
}
step_entry["registry_name"] = registry_name
else:
# Fall back to full module path for unregistered steps
step_entry: dict[str, Any] = {
"class": f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}",
}
step_entry["class"] = (
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
@@ -475,22 +474,34 @@ class RobotProcessor(ModelHubMixin):
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
save_file(cloned_state, os.path.join(destination_path, state_filename))
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
config["steps"].append(step_entry)
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
@classmethod
def from_pretrained(
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
cls,
pretrained_model_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[str, str] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
config_filename: str | None = None,
overrides: dict[str, Any] | None = None,
**kwargs,
) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
source: Local path to a saved processor directory or Hugging Face Hub identifier
pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier
(e.g., "username/processor-name").
config_filename: Optional specific config filename to load. If not provided, will:
- For local paths: look for any .json file in the directory (error if multiple found)
@@ -543,6 +554,9 @@ class RobotProcessor(ModelHubMixin):
)
```
"""
# Use the local variable name 'source' for clarity
source = str(pretrained_model_name_or_path)
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
@@ -560,7 +574,7 @@ class RobotProcessor(ModelHubMixin):
config_filename = json_files[0].name
with open(base_path / config_filename) as file_pointer:
config: dict[str, Any] = json.load(file_pointer)
loaded_config: dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
if config_filename is None:
@@ -574,7 +588,18 @@ class RobotProcessor(ModelHubMixin):
config_path = None
for name in common_names:
try:
config_path = hf_hub_download(source, name, repo_type="model")
config_path = hf_hub_download(
source,
name,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
config_filename = name
break
except (FileNotFoundError, OSError, HfHubHTTPError):
@@ -590,10 +615,21 @@ class RobotProcessor(ModelHubMixin):
)
else:
# Download specific config file
config_path = hf_hub_download(source, config_filename, repo_type="model")
config_path = hf_hub_download(
source,
config_filename,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
with open(config_path) as file_pointer:
config: dict[str, Any] = json.load(file_pointer)
loaded_config = json.load(file_pointer)
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
@@ -606,7 +642,7 @@ class RobotProcessor(ModelHubMixin):
override_keys = set(overrides.keys())
steps: list[ProcessorStep] = []
for step_entry in config["steps"]:
for step_entry in loaded_config["steps"]:
# Check if step uses registry name or module path
if "registry_name" in step_entry:
# Load from registry
@@ -658,7 +694,18 @@ class RobotProcessor(ModelHubMixin):
state_path = str(base_path / step_entry["state_file"])
else:
# Hugging Face Hub - download the state file
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
state_path = hf_hub_download(
source,
step_entry["state_file"],
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
step_instance.load_state_dict(load_file(state_path))
@@ -667,7 +714,7 @@ class RobotProcessor(ModelHubMixin):
# Check for unused override keys
if override_keys:
available_keys = []
for step_entry in config["steps"]:
for step_entry in loaded_config["steps"]:
if "registry_name" in step_entry:
available_keys.append(step_entry["registry_name"])
else:
@@ -681,7 +728,7 @@ class RobotProcessor(ModelHubMixin):
f"Make sure override keys match exact step class names or registry names."
)
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
def __len__(self) -> int:
"""Return the number of steps in the processor."""