mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937)
- Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving
This commit is contained in:
@@ -42,10 +42,11 @@ from pathlib import Path
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
@@ -67,7 +68,7 @@ class ProcessorStepRegistry:
|
||||
_registry: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str = None):
|
||||
def register(cls, name: str | None = None):
|
||||
"""A class decorator to register a ProcessorStep.
|
||||
|
||||
Args:
|
||||
@@ -237,7 +238,7 @@ class ProcessorKwargs(TypedDict, total=False):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||
|
||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||
@@ -324,24 +325,11 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `ModelHubMixin`'s saving mechanism."""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
self.save_pretrained(save_directory, config_filename=config_filename)
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
|
||||
"""Saves the pipeline's configuration and state to a directory.
|
||||
|
||||
This method creates a JSON configuration file that defines the pipeline's structure
|
||||
(name and steps). For each stateful step, it also saves a `.safetensors` file
|
||||
containing its state dictionary.
|
||||
|
||||
Args:
|
||||
save_directory: The directory where the pipeline will be saved.
|
||||
config_filename: The name of the JSON configuration file. If None, a name is
|
||||
generated from the pipeline's `name` attribute.
|
||||
**kwargs: Additional arguments (not used, but present for compatibility).
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
os.makedirs(str(save_directory), exist_ok=True)
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@@ -393,6 +381,60 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | Path | None = None,
|
||||
*,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
card_kwargs: dict[str, Any] | None = None,
|
||||
config_filename: str | None = None,
|
||||
**push_to_hub_kwargs,
|
||||
):
|
||||
"""Saves the pipeline's configuration and state to a directory.
|
||||
|
||||
This method creates a JSON configuration file that defines the pipeline's structure
|
||||
(name and steps). For each stateful step, it also saves a `.safetensors` file
|
||||
containing its state dictionary.
|
||||
|
||||
Args:
|
||||
save_directory: The directory where the pipeline will be saved. If None, saves to
|
||||
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
||||
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
||||
card_kwargs: Additional arguments passed to the card template to customize the card.
|
||||
config_filename: The name of the JSON configuration file. If None, a name is
|
||||
generated from the pipeline's `name` attribute.
|
||||
**push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method.
|
||||
"""
|
||||
if save_directory is None:
|
||||
# Use default directory in HF_LEROBOT_HOME
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
|
||||
|
||||
# For direct saves (not through hub), handle config_filename
|
||||
if not push_to_hub and config_filename is not None:
|
||||
# Call _save_pretrained directly with config_filename
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._save_pretrained(save_directory, config_filename=config_filename)
|
||||
return None
|
||||
|
||||
# Pass config_filename through kwargs for _save_pretrained when using hub
|
||||
if config_filename is not None:
|
||||
push_to_hub_kwargs["config_filename"] = config_filename
|
||||
|
||||
# Call parent's save_pretrained which will call our _save_pretrained
|
||||
return super().save_pretrained(
|
||||
save_directory=save_directory,
|
||||
repo_id=repo_id,
|
||||
push_to_hub=push_to_hub,
|
||||
card_kwargs=card_kwargs,
|
||||
**push_to_hub_kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -430,9 +472,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
cache_dir: The path to a specific cache folder to store downloaded files.
|
||||
local_files_only: If True, avoid downloading files from the Hub.
|
||||
revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
|
||||
config_filename: The name of the pipeline's JSON configuration file. Required when
|
||||
loading from the Hub. If loading from a local directory, it's inferred if there's
|
||||
only one `.json` file.
|
||||
config_filename: The name of the pipeline's JSON configuration file. If not provided,
|
||||
it's auto-detected in local directories (if only one .json file exists). This parameter
|
||||
is mandatory when loading from Hugging Face Hub repositories.
|
||||
overrides: A dictionary to override the configuration of specific steps. Keys should
|
||||
match the step's class name or registry name.
|
||||
to_transition: A custom function to convert input data to `EnvTransition`.
|
||||
@@ -448,63 +490,85 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
ImportError: If a step's class cannot be imported.
|
||||
KeyError: If an override key doesn't match any step in the pipeline.
|
||||
"""
|
||||
source = str(pretrained_model_name_or_path)
|
||||
model_id = str(pretrained_model_name_or_path)
|
||||
loaded_config: dict[str, Any] | None = None
|
||||
base_path: Path | None = None
|
||||
|
||||
# Heuristic to distinguish a local path from a Hub repository ID.
|
||||
is_local_path = (
|
||||
Path(source).is_dir()
|
||||
or Path(source).is_absolute()
|
||||
or source.startswith("./")
|
||||
or source.startswith("../")
|
||||
# A simple heuristic: repo IDs usually don't have more than one slash.
|
||||
or source.count("/") > 1
|
||||
or "\\" in source
|
||||
)
|
||||
# Standard pattern: try local directory first
|
||||
if Path(model_id).is_dir():
|
||||
base_path = Path(model_id)
|
||||
|
||||
# Load configuration from a local directory.
|
||||
if is_local_path:
|
||||
base_path = Path(source)
|
||||
|
||||
# If config filename is not provided, try to find a unique .json file.
|
||||
# Handle config filename
|
||||
if config_filename is None:
|
||||
json_files = list(base_path.glob("*.json"))
|
||||
if len(json_files) == 0:
|
||||
raise FileNotFoundError(f"No .json configuration files found in {source}")
|
||||
elif len(json_files) > 1:
|
||||
# No config files found locally, will try Hub next
|
||||
pass
|
||||
elif len(json_files) == 1:
|
||||
config_filename = json_files[0].name
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
|
||||
f"Multiple .json files found in {model_id}: {[f.name for f in json_files]}. "
|
||||
f"Please specify which one to load using the config_filename parameter."
|
||||
)
|
||||
config_filename = json_files[0].name
|
||||
|
||||
with open(base_path / config_filename) as file_pointer:
|
||||
loaded_config: dict[str, Any] = json.load(file_pointer)
|
||||
# Load configuration from the Hugging Face Hub.
|
||||
else:
|
||||
if config_filename is None:
|
||||
raise ValueError(
|
||||
f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. "
|
||||
f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')"
|
||||
)
|
||||
# Download the configuration file from the Hub.
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
config_filename,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
# Try to load config from local directory
|
||||
if config_filename and (base_path / config_filename).exists():
|
||||
with open(base_path / config_filename) as f:
|
||||
loaded_config = json.load(f)
|
||||
|
||||
# If not found locally, try Hub
|
||||
if loaded_config is None:
|
||||
# Check if this looks like a local path that doesn't exist
|
||||
# Hub repo IDs have format "user/repo" with exactly one slash
|
||||
# Local paths typically have multiple slashes, backslashes, or start with ./ or ../
|
||||
looks_like_local_path = (
|
||||
model_id.count("/") > 1 # Multiple slashes suggest local path
|
||||
or "\\" in model_id # Backslashes are only in local paths
|
||||
or Path(model_id).is_absolute() # Absolute paths are local
|
||||
or model_id.startswith("./")
|
||||
or model_id.startswith("../") # Relative path indicators
|
||||
)
|
||||
|
||||
with open(config_path) as file_pointer:
|
||||
loaded_config = json.load(file_pointer)
|
||||
if looks_like_local_path:
|
||||
# This appears to be a local path that doesn't exist
|
||||
raise FileNotFoundError(f"Local path '{model_id}' does not exist")
|
||||
# For Hub repositories, config_filename is mandatory
|
||||
if config_filename is None:
|
||||
raise ValueError(
|
||||
f"When loading from Hugging Face Hub, 'config_filename' must be specified. "
|
||||
f"Example: DataProcessorPipeline.from_pretrained('{model_id}', config_filename='processor.json')"
|
||||
)
|
||||
|
||||
# The base path for other files (like state tensors) is the directory of the config file.
|
||||
base_path = Path(config_path).parent
|
||||
try:
|
||||
# Download the configuration file from the Hub
|
||||
config_path = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=config_filename,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
with open(config_path) as f:
|
||||
loaded_config = json.load(f)
|
||||
|
||||
# The base path for other files (like state tensors) is the directory of the config file
|
||||
base_path = Path(config_path).parent
|
||||
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
f"Could not find {config_filename} on the HuggingFace Hub at {model_id}"
|
||||
) from e
|
||||
|
||||
# At this point, loaded_config must be loaded successfully
|
||||
if loaded_config is None:
|
||||
raise RuntimeError("Failed to load configuration from local directory or Hub")
|
||||
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
@@ -556,13 +620,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
||||
|
||||
# Load the step's state if a state file is specified.
|
||||
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
||||
if is_local_path:
|
||||
# Check if state file exists locally first
|
||||
if base_path and (base_path / step_entry["state_file"]).exists():
|
||||
state_path = str(base_path / step_entry["state_file"])
|
||||
else:
|
||||
# Download the state file from the Hub.
|
||||
state_path = hf_hub_download(
|
||||
source,
|
||||
step_entry["state_file"],
|
||||
repo_id=model_id,
|
||||
filename=step_entry["state_file"],
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
||||
@@ -1736,12 +1736,14 @@ def test_from_pretrained_nonexistent_path():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
DataProcessorPipeline.from_pretrained("/path/that/does/not/exist")
|
||||
|
||||
# Test with a Hub repo format that would be a local path (too many slashes)
|
||||
# Test with a path that doesn't exist as a directory
|
||||
with pytest.raises(FileNotFoundError):
|
||||
DataProcessorPipeline.from_pretrained("user/repo/extra/path")
|
||||
|
||||
# Test with a non-existent but valid Hub repo format (now requires config_filename)
|
||||
with pytest.raises(ValueError, match="you must specify the config_filename parameter"):
|
||||
# Test with a Hub repo without specifying config_filename (should raise ValueError)
|
||||
with pytest.raises(
|
||||
ValueError, match="When loading from Hugging Face Hub, 'config_filename' must be specified"
|
||||
):
|
||||
DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo")
|
||||
|
||||
# Test with a non-existent Hub repo when config_filename is provided
|
||||
@@ -1752,7 +1754,8 @@ def test_from_pretrained_nonexistent_path():
|
||||
|
||||
# Test with a local directory that exists but has no config files
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
|
||||
# Since the directory exists but has no config, it will try Hub and fail
|
||||
with pytest.raises(FileNotFoundError):
|
||||
DataProcessorPipeline.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user