From 7beb040e8e14a015d13f93e6307a47f0f7a12185 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Tue, 5 Aug 2025 17:44:21 +0200 Subject: [PATCH] refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling are consistent with the new parameter names. - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. --- src/lerobot/processor/pipeline.py | 95 +++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 24 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index a9221c7ae..7e78830fb 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -403,21 +403,21 @@ class RobotProcessor(ModelHubMixin): transition = processor_step(transition) yield transition - def _save_pretrained(self, destination_path: str, **kwargs): + def _save_pretrained(self, save_directory: Path, **kwargs): """Internal save method for ModelHubMixin compatibility.""" # Extract config_filename from kwargs if provided config_filename = kwargs.pop("config_filename", None) - self.save_pretrained(destination_path, config_filename=config_filename) + self.save_pretrained(save_directory, config_filename=config_filename) - def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs): - """Serialize the processor definition and parameters to *destination_path*. + def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): + """Serialize the processor definition and parameters to *save_directory*. Args: - destination_path: Directory where the processor will be saved.
+ save_directory: Directory where the processor will be saved. config_filename: Optional custom config filename. If not provided, defaults to "{self.name}.json" where self.name is sanitized for filesystem compatibility. """ - os.makedirs(destination_path, exist_ok=True) + os.makedirs(str(save_directory), exist_ok=True) # Sanitize processor name for use in filenames import re @@ -439,16 +439,15 @@ class RobotProcessor(ModelHubMixin): # Check if step was registered registry_name = getattr(processor_step.__class__, "_registry_name", None) + step_entry: dict[str, Any] = {} if registry_name: # Use registry name for registered steps - step_entry: dict[str, Any] = { - "registry_name": registry_name, - } + step_entry["registry_name"] = registry_name else: # Fall back to full module path for unregistered steps - step_entry: dict[str, Any] = { - "class": f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}", - } + step_entry["class"] = ( + f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" + ) if hasattr(processor_step, "get_config"): step_entry["config"] = processor_step.get_config() @@ -475,22 +474,34 @@ class RobotProcessor(ModelHubMixin): else: state_filename = f"{sanitized_name}_step_{step_index}.safetensors" - save_file(cloned_state, os.path.join(destination_path, state_filename)) + save_file(cloned_state, os.path.join(str(save_directory), state_filename)) step_entry["state_file"] = state_filename config["steps"].append(step_entry) - with open(os.path.join(destination_path, config_filename), "w") as file_pointer: + with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: json.dump(config, file_pointer, indent=2) @classmethod def from_pretrained( - cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None + cls, + pretrained_model_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[str, str] 
| None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + config_filename: str | None = None, + overrides: dict[str, Any] | None = None, + **kwargs, ) -> RobotProcessor: """Load a serialized processor from source (local path or Hugging Face Hub identifier). Args: - source: Local path to a saved processor directory or Hugging Face Hub identifier + pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier (e.g., "username/processor-name"). config_filename: Optional specific config filename to load. If not provided, will: - For local paths: look for any .json file in the directory (error if multiple found) @@ -543,6 +554,9 @@ class RobotProcessor(ModelHubMixin): ) ``` """ + # Use the local variable name 'source' for clarity + source = str(pretrained_model_name_or_path) + if Path(source).is_dir(): # Local path - use it directly base_path = Path(source) @@ -560,7 +574,7 @@ class RobotProcessor(ModelHubMixin): config_filename = json_files[0].name with open(base_path / config_filename) as file_pointer: - config: dict[str, Any] = json.load(file_pointer) + loaded_config: dict[str, Any] = json.load(file_pointer) else: # Hugging Face Hub - download all required files if config_filename is None: @@ -574,7 +588,18 @@ class RobotProcessor(ModelHubMixin): config_path = None for name in common_names: try: - config_path = hf_hub_download(source, name, repo_type="model") + config_path = hf_hub_download( + source, + name, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) config_filename = name break except (FileNotFoundError, OSError, HfHubHTTPError): @@ -590,10 +615,21 @@ class RobotProcessor(ModelHubMixin): ) else: # Download specific config file - config_path = hf_hub_download(source, 
config_filename, repo_type="model") + config_path = hf_hub_download( + source, + config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) with open(config_path) as file_pointer: - config: dict[str, Any] = json.load(file_pointer) + loaded_config = json.load(file_pointer) # Store downloaded files in the same directory as the config base_path = Path(config_path).parent @@ -606,7 +642,7 @@ class RobotProcessor(ModelHubMixin): override_keys = set(overrides.keys()) steps: list[ProcessorStep] = [] - for step_entry in config["steps"]: + for step_entry in loaded_config["steps"]: # Check if step uses registry name or module path if "registry_name" in step_entry: # Load from registry @@ -658,7 +694,18 @@ class RobotProcessor(ModelHubMixin): state_path = str(base_path / step_entry["state_file"]) else: # Hugging Face Hub - download the state file - state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model") + state_path = hf_hub_download( + source, + step_entry["state_file"], + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) step_instance.load_state_dict(load_file(state_path)) @@ -667,7 +714,7 @@ class RobotProcessor(ModelHubMixin): # Check for unused override keys if override_keys: available_keys = [] - for step_entry in config["steps"]: + for step_entry in loaded_config["steps"]: if "registry_name" in step_entry: available_keys.append(step_entry["registry_name"]) else: @@ -681,7 +728,7 @@ class RobotProcessor(ModelHubMixin): f"Make sure override keys match exact step class names or registry names." 
) - return cls(steps, config.get("name", "RobotProcessor"), config.get("seed")) + return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed")) def __len__(self) -> int: """Return the number of steps in the processor."""