mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
refactor(pipeline): Rename parameters for clarity and enhance save/load functionality
- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method to ensure directory creation and file handling are consistent with the new parameter names.
- Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability.
This commit is contained in:
@@ -403,21 +403,21 @@ class RobotProcessor(ModelHubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal save method for ModelHubMixin compatibility."""
|
||||
# Extract config_filename from kwargs if provided
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
self.save_pretrained(destination_path, config_filename=config_filename)
|
||||
self.save_pretrained(save_directory, config_filename=config_filename)
|
||||
|
||||
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
||||
"""Serialize the processor definition and parameters to *destination_path*.
|
||||
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
|
||||
"""Serialize the processor definition and parameters to *save_directory*.
|
||||
|
||||
Args:
|
||||
destination_path: Directory where the processor will be saved.
|
||||
save_directory: Directory where the processor will be saved.
|
||||
config_filename: Optional custom config filename. If not provided, defaults to
|
||||
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
|
||||
"""
|
||||
os.makedirs(destination_path, exist_ok=True)
|
||||
os.makedirs(str(save_directory), exist_ok=True)
|
||||
|
||||
# Sanitize processor name for use in filenames
|
||||
import re
|
||||
@@ -439,16 +439,15 @@ class RobotProcessor(ModelHubMixin):
|
||||
# Check if step was registered
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
if registry_name:
|
||||
# Use registry name for registered steps
|
||||
step_entry: dict[str, Any] = {
|
||||
"registry_name": registry_name,
|
||||
}
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
# Fall back to full module path for unregistered steps
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}",
|
||||
}
|
||||
step_entry["class"] = (
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
@@ -475,22 +474,34 @@ class RobotProcessor(ModelHubMixin):
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
save_file(cloned_state, os.path.join(destination_path, state_filename))
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
|
||||
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
||||
cls,
|
||||
pretrained_model_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[str, str] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
config_filename: str | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> RobotProcessor:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||
|
||||
Args:
|
||||
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||
pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||
(e.g., "username/processor-name").
|
||||
config_filename: Optional specific config filename to load. If not provided, will:
|
||||
- For local paths: look for any .json file in the directory (error if multiple found)
|
||||
@@ -543,6 +554,9 @@ class RobotProcessor(ModelHubMixin):
|
||||
)
|
||||
```
|
||||
"""
|
||||
# Use the local variable name 'source' for clarity
|
||||
source = str(pretrained_model_name_or_path)
|
||||
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
@@ -560,7 +574,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
config_filename = json_files[0].name
|
||||
|
||||
with open(base_path / config_filename) as file_pointer:
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
loaded_config: dict[str, Any] = json.load(file_pointer)
|
||||
else:
|
||||
# Hugging Face Hub - download all required files
|
||||
if config_filename is None:
|
||||
@@ -574,7 +588,18 @@ class RobotProcessor(ModelHubMixin):
|
||||
config_path = None
|
||||
for name in common_names:
|
||||
try:
|
||||
config_path = hf_hub_download(source, name, repo_type="model")
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
name,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
config_filename = name
|
||||
break
|
||||
except (FileNotFoundError, OSError, HfHubHTTPError):
|
||||
@@ -590,10 +615,21 @@ class RobotProcessor(ModelHubMixin):
|
||||
)
|
||||
else:
|
||||
# Download specific config file
|
||||
config_path = hf_hub_download(source, config_filename, repo_type="model")
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
config_filename,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
with open(config_path) as file_pointer:
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
loaded_config = json.load(file_pointer)
|
||||
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
@@ -606,7 +642,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
override_keys = set(overrides.keys())
|
||||
|
||||
steps: list[ProcessorStep] = []
|
||||
for step_entry in config["steps"]:
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# Check if step uses registry name or module path
|
||||
if "registry_name" in step_entry:
|
||||
# Load from registry
|
||||
@@ -658,7 +694,18 @@ class RobotProcessor(ModelHubMixin):
|
||||
state_path = str(base_path / step_entry["state_file"])
|
||||
else:
|
||||
# Hugging Face Hub - download the state file
|
||||
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
|
||||
state_path = hf_hub_download(
|
||||
source,
|
||||
step_entry["state_file"],
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
step_instance.load_state_dict(load_file(state_path))
|
||||
|
||||
@@ -667,7 +714,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
# Check for unused override keys
|
||||
if override_keys:
|
||||
available_keys = []
|
||||
for step_entry in config["steps"]:
|
||||
for step_entry in loaded_config["steps"]:
|
||||
if "registry_name" in step_entry:
|
||||
available_keys.append(step_entry["registry_name"])
|
||||
else:
|
||||
@@ -681,7 +728,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
f"Make sure override keys match exact step class names or registry names."
|
||||
)
|
||||
|
||||
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
||||
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of steps in the processor."""
|
||||
|
||||
Reference in New Issue
Block a user