mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
refactor(pipeline): Rename parameters for clarity and enhance save/load functionality
- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method so that directory creation and file handling are consistent with the new parameter names.
- Streamlined the loading process in from_pretrained to use loaded_config for better clarity and maintainability.
This commit is contained in:
@@ -403,21 +403,21 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
transition = processor_step(transition)
|
transition = processor_step(transition)
|
||||||
yield transition
|
yield transition
|
||||||
|
|
||||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||||
"""Internal save method for ModelHubMixin compatibility."""
|
"""Internal save method for ModelHubMixin compatibility."""
|
||||||
# Extract config_filename from kwargs if provided
|
# Extract config_filename from kwargs if provided
|
||||||
config_filename = kwargs.pop("config_filename", None)
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
self.save_pretrained(destination_path, config_filename=config_filename)
|
self.save_pretrained(save_directory, config_filename=config_filename)
|
||||||
|
|
||||||
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
|
||||||
"""Serialize the processor definition and parameters to *destination_path*.
|
"""Serialize the processor definition and parameters to *save_directory*.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination_path: Directory where the processor will be saved.
|
save_directory: Directory where the processor will be saved.
|
||||||
config_filename: Optional custom config filename. If not provided, defaults to
|
config_filename: Optional custom config filename. If not provided, defaults to
|
||||||
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
|
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
|
||||||
"""
|
"""
|
||||||
os.makedirs(destination_path, exist_ok=True)
|
os.makedirs(str(save_directory), exist_ok=True)
|
||||||
|
|
||||||
# Sanitize processor name for use in filenames
|
# Sanitize processor name for use in filenames
|
||||||
import re
|
import re
|
||||||
@@ -439,16 +439,15 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
# Check if step was registered
|
# Check if step was registered
|
||||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
|
||||||
|
step_entry: dict[str, Any] = {}
|
||||||
if registry_name:
|
if registry_name:
|
||||||
# Use registry name for registered steps
|
# Use registry name for registered steps
|
||||||
step_entry: dict[str, Any] = {
|
step_entry["registry_name"] = registry_name
|
||||||
"registry_name": registry_name,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# Fall back to full module path for unregistered steps
|
# Fall back to full module path for unregistered steps
|
||||||
step_entry: dict[str, Any] = {
|
step_entry["class"] = (
|
||||||
"class": f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}",
|
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||||
}
|
)
|
||||||
|
|
||||||
if hasattr(processor_step, "get_config"):
|
if hasattr(processor_step, "get_config"):
|
||||||
step_entry["config"] = processor_step.get_config()
|
step_entry["config"] = processor_step.get_config()
|
||||||
@@ -475,22 +474,34 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
else:
|
else:
|
||||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||||
|
|
||||||
save_file(cloned_state, os.path.join(destination_path, state_filename))
|
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||||
step_entry["state_file"] = state_filename
|
step_entry["state_file"] = state_filename
|
||||||
|
|
||||||
config["steps"].append(step_entry)
|
config["steps"].append(step_entry)
|
||||||
|
|
||||||
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||||
json.dump(config, file_pointer, indent=2)
|
json.dump(config, file_pointer, indent=2)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
cls,
|
||||||
|
pretrained_model_name_or_path: str | Path,
|
||||||
|
*,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict[str, str] | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
cache_dir: str | Path | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
config_filename: str | None = None,
|
||||||
|
overrides: dict[str, Any] | None = None,
|
||||||
|
**kwargs,
|
||||||
) -> RobotProcessor:
|
) -> RobotProcessor:
|
||||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||||
(e.g., "username/processor-name").
|
(e.g., "username/processor-name").
|
||||||
config_filename: Optional specific config filename to load. If not provided, will:
|
config_filename: Optional specific config filename to load. If not provided, will:
|
||||||
- For local paths: look for any .json file in the directory (error if multiple found)
|
- For local paths: look for any .json file in the directory (error if multiple found)
|
||||||
@@ -543,6 +554,9 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
# Use the local variable name 'source' for clarity
|
||||||
|
source = str(pretrained_model_name_or_path)
|
||||||
|
|
||||||
if Path(source).is_dir():
|
if Path(source).is_dir():
|
||||||
# Local path - use it directly
|
# Local path - use it directly
|
||||||
base_path = Path(source)
|
base_path = Path(source)
|
||||||
@@ -560,7 +574,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
config_filename = json_files[0].name
|
config_filename = json_files[0].name
|
||||||
|
|
||||||
with open(base_path / config_filename) as file_pointer:
|
with open(base_path / config_filename) as file_pointer:
|
||||||
config: dict[str, Any] = json.load(file_pointer)
|
loaded_config: dict[str, Any] = json.load(file_pointer)
|
||||||
else:
|
else:
|
||||||
# Hugging Face Hub - download all required files
|
# Hugging Face Hub - download all required files
|
||||||
if config_filename is None:
|
if config_filename is None:
|
||||||
@@ -574,7 +588,18 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
config_path = None
|
config_path = None
|
||||||
for name in common_names:
|
for name in common_names:
|
||||||
try:
|
try:
|
||||||
config_path = hf_hub_download(source, name, repo_type="model")
|
config_path = hf_hub_download(
|
||||||
|
source,
|
||||||
|
name,
|
||||||
|
repo_type="model",
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
config_filename = name
|
config_filename = name
|
||||||
break
|
break
|
||||||
except (FileNotFoundError, OSError, HfHubHTTPError):
|
except (FileNotFoundError, OSError, HfHubHTTPError):
|
||||||
@@ -590,10 +615,21 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Download specific config file
|
# Download specific config file
|
||||||
config_path = hf_hub_download(source, config_filename, repo_type="model")
|
config_path = hf_hub_download(
|
||||||
|
source,
|
||||||
|
config_filename,
|
||||||
|
repo_type="model",
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
with open(config_path) as file_pointer:
|
with open(config_path) as file_pointer:
|
||||||
config: dict[str, Any] = json.load(file_pointer)
|
loaded_config = json.load(file_pointer)
|
||||||
|
|
||||||
# Store downloaded files in the same directory as the config
|
# Store downloaded files in the same directory as the config
|
||||||
base_path = Path(config_path).parent
|
base_path = Path(config_path).parent
|
||||||
@@ -606,7 +642,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
override_keys = set(overrides.keys())
|
override_keys = set(overrides.keys())
|
||||||
|
|
||||||
steps: list[ProcessorStep] = []
|
steps: list[ProcessorStep] = []
|
||||||
for step_entry in config["steps"]:
|
for step_entry in loaded_config["steps"]:
|
||||||
# Check if step uses registry name or module path
|
# Check if step uses registry name or module path
|
||||||
if "registry_name" in step_entry:
|
if "registry_name" in step_entry:
|
||||||
# Load from registry
|
# Load from registry
|
||||||
@@ -658,7 +694,18 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
state_path = str(base_path / step_entry["state_file"])
|
state_path = str(base_path / step_entry["state_file"])
|
||||||
else:
|
else:
|
||||||
# Hugging Face Hub - download the state file
|
# Hugging Face Hub - download the state file
|
||||||
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
|
state_path = hf_hub_download(
|
||||||
|
source,
|
||||||
|
step_entry["state_file"],
|
||||||
|
repo_type="model",
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
step_instance.load_state_dict(load_file(state_path))
|
step_instance.load_state_dict(load_file(state_path))
|
||||||
|
|
||||||
@@ -667,7 +714,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
# Check for unused override keys
|
# Check for unused override keys
|
||||||
if override_keys:
|
if override_keys:
|
||||||
available_keys = []
|
available_keys = []
|
||||||
for step_entry in config["steps"]:
|
for step_entry in loaded_config["steps"]:
|
||||||
if "registry_name" in step_entry:
|
if "registry_name" in step_entry:
|
||||||
available_keys.append(step_entry["registry_name"])
|
available_keys.append(step_entry["registry_name"])
|
||||||
else:
|
else:
|
||||||
@@ -681,7 +728,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
f"Make sure override keys match exact step class names or registry names."
|
f"Make sure override keys match exact step class names or registry names."
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of steps in the processor."""
|
"""Return the number of steps in the processor."""
|
||||||
|
|||||||
Reference in New Issue
Block a user