refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937)

- Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality.
- Refactored save_pretrained method to handle saving to a default directory, custom config filenames, and optional pushing to the Hub.
This commit is contained in:
Adil Zouitine
2025-09-15 13:13:35 +02:00
committed by GitHub
parent 40e9ddd1ed
commit 066308ceb8
2 changed files with 142 additions and 74 deletions
+135 -70
View File
@@ -42,10 +42,11 @@ from pathlib import Path
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.hub import HubMixin
from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
@@ -67,7 +68,7 @@ class ProcessorStepRegistry:
_registry: dict[str, type] = {}
@classmethod
def register(cls, name: str = None):
def register(cls, name: str | None = None):
"""A class decorator to register a ProcessorStep.
Args:
@@ -237,7 +238,7 @@ class ProcessorKwargs(TypedDict, total=False):
@dataclass
class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
This class chains together multiple `ProcessorStep` instances to form a complete
@@ -324,24 +325,11 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
yield transition
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `ModelHubMixin`'s saving mechanism."""
config_filename = kwargs.pop("config_filename", None)
self.save_pretrained(save_directory, config_filename=config_filename)
"""Internal method to comply with `HubMixin`'s saving mechanism.
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
"""Saves the pipeline's configuration and state to a directory.
This method creates a JSON configuration file that defines the pipeline's structure
(name and steps). For each stateful step, it also saves a `.safetensors` file
containing its state dictionary.
Args:
save_directory: The directory where the pipeline will be saved.
config_filename: The name of the JSON configuration file. If None, a name is
generated from the pipeline's `name` attribute.
**kwargs: Additional arguments (not used, but present for compatibility).
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
os.makedirs(str(save_directory), exist_ok=True)
config_filename = kwargs.pop("config_filename", None)
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
@@ -393,6 +381,60 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
def save_pretrained(
    self,
    save_directory: str | Path | None = None,
    *,
    repo_id: str | None = None,
    push_to_hub: bool = False,
    card_kwargs: dict[str, Any] | None = None,
    config_filename: str | None = None,
    **push_to_hub_kwargs,
):
    """Save the pipeline to a directory and optionally push it to the Hub.

    This is the public entry point; the files themselves are written by
    `_save_pretrained`. Two routes exist:

    * Local save with an explicit `config_filename`: the target directory is
      created and `_save_pretrained` is called directly. Returns None.
    * Every other case: the call is delegated to the parent class'
      `save_pretrained`, and its return value is passed back unchanged.

    Args:
        save_directory: Where to save the pipeline. When None, defaults to
            HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
        repo_id: Hub repository ID, forwarded to the parent `save_pretrained`.
        push_to_hub: Whether the pipeline should be pushed to the Hugging Face
            Hub after saving.
        card_kwargs: Extra arguments for the card template, forwarded to the
            parent `save_pretrained`.
        config_filename: Name of the JSON configuration file. When None, the
            choice is left to `_save_pretrained`.
        **push_to_hub_kwargs: Additional keyword arguments forwarded to the
            parent `save_pretrained`.
    """
    if save_directory is None:
        # Imported lazily, only when the default location is actually needed.
        from lerobot.constants import HF_LEROBOT_HOME

        # Same sanitisation rule: anything outside [a-zA-Z0-9_] becomes "_".
        default_dirname = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
        save_directory = HF_LEROBOT_HOME / "processors" / default_dirname

    has_custom_config = config_filename is not None

    if has_custom_config and not push_to_hub:
        # Purely local save: skip the parent and hand the filename straight
        # to the internal writer.
        target_dir = Path(save_directory)
        target_dir.mkdir(parents=True, exist_ok=True)
        self._save_pretrained(target_dir, config_filename=config_filename)
        return None

    if has_custom_config:
        # NOTE(review): on the Hub route the filename travels inside the
        # forwarded kwargs; confirm the parent hands it to `_save_pretrained`
        # rather than to `push_to_hub`.
        push_to_hub_kwargs["config_filename"] = config_filename

    # The parent implementation is expected to call back into `_save_pretrained`.
    return super().save_pretrained(
        save_directory=save_directory,
        repo_id=repo_id,
        push_to_hub=push_to_hub,
        card_kwargs=card_kwargs,
        **push_to_hub_kwargs,
    )
@classmethod
def from_pretrained(
cls,
@@ -430,9 +472,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
cache_dir: The path to a specific cache folder to store downloaded files.
local_files_only: If True, avoid downloading files from the Hub.
revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
config_filename: The name of the pipeline's JSON configuration file. Required when
loading from the Hub. If loading from a local directory, it's inferred if there's
only one `.json` file.
config_filename: The name of the pipeline's JSON configuration file. If not provided,
it's auto-detected in local directories (if only one .json file exists). This parameter
is mandatory when loading from Hugging Face Hub repositories.
overrides: A dictionary to override the configuration of specific steps. Keys should
match the step's class name or registry name.
to_transition: A custom function to convert input data to `EnvTransition`.
@@ -448,63 +490,85 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
ImportError: If a step's class cannot be imported.
KeyError: If an override key doesn't match any step in the pipeline.
"""
source = str(pretrained_model_name_or_path)
model_id = str(pretrained_model_name_or_path)
loaded_config: dict[str, Any] | None = None
base_path: Path | None = None
# Heuristic to distinguish a local path from a Hub repository ID.
is_local_path = (
Path(source).is_dir()
or Path(source).is_absolute()
or source.startswith("./")
or source.startswith("../")
# A simple heuristic: repo IDs usually don't have more than one slash.
or source.count("/") > 1
or "\\" in source
)
# Standard pattern: try local directory first
if Path(model_id).is_dir():
base_path = Path(model_id)
# Load configuration from a local directory.
if is_local_path:
base_path = Path(source)
# If config filename is not provided, try to find a unique .json file.
# Handle config filename
if config_filename is None:
json_files = list(base_path.glob("*.json"))
if len(json_files) == 0:
raise FileNotFoundError(f"No .json configuration files found in {source}")
elif len(json_files) > 1:
# No config files found locally, will try Hub next
pass
elif len(json_files) == 1:
config_filename = json_files[0].name
else:
raise ValueError(
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
f"Multiple .json files found in {model_id}: {[f.name for f in json_files]}. "
f"Please specify which one to load using the config_filename parameter."
)
config_filename = json_files[0].name
with open(base_path / config_filename) as file_pointer:
loaded_config: dict[str, Any] = json.load(file_pointer)
# Load configuration from the Hugging Face Hub.
else:
if config_filename is None:
raise ValueError(
f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. "
f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')"
)
# Download the configuration file from the Hub.
config_path = hf_hub_download(
source,
config_filename,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
# Try to load config from local directory
if config_filename and (base_path / config_filename).exists():
with open(base_path / config_filename) as f:
loaded_config = json.load(f)
# If not found locally, try Hub
if loaded_config is None:
# Check if this looks like a local path that doesn't exist
# Hub repo IDs have format "user/repo" with exactly one slash
# Local paths typically have multiple slashes, backslashes, or start with ./ or ../
looks_like_local_path = (
model_id.count("/") > 1 # Multiple slashes suggest local path
or "\\" in model_id # Backslashes are only in local paths
or Path(model_id).is_absolute() # Absolute paths are local
or model_id.startswith("./")
or model_id.startswith("../") # Relative path indicators
)
with open(config_path) as file_pointer:
loaded_config = json.load(file_pointer)
if looks_like_local_path:
# This appears to be a local path that doesn't exist
raise FileNotFoundError(f"Local path '{model_id}' does not exist")
# For Hub repositories, config_filename is mandatory
if config_filename is None:
raise ValueError(
f"When loading from Hugging Face Hub, 'config_filename' must be specified. "
f"Example: DataProcessorPipeline.from_pretrained('{model_id}', config_filename='processor.json')"
)
# The base path for other files (like state tensors) is the directory of the config file.
base_path = Path(config_path).parent
try:
# Download the configuration file from the Hub
config_path = hf_hub_download(
repo_id=model_id,
filename=config_filename,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
with open(config_path) as f:
loaded_config = json.load(f)
# The base path for other files (like state tensors) is the directory of the config file
base_path = Path(config_path).parent
except Exception as e:
raise FileNotFoundError(
f"Could not find {config_filename} on the HuggingFace Hub at {model_id}"
) from e
# At this point, loaded_config must be loaded successfully
if loaded_config is None:
raise RuntimeError("Failed to load configuration from local directory or Hub")
if overrides is None:
overrides = {}
@@ -556,13 +620,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
# Load the step's state if a state file is specified.
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
if is_local_path:
# Check if state file exists locally first
if base_path and (base_path / step_entry["state_file"]).exists():
state_path = str(base_path / step_entry["state_file"])
else:
# Download the state file from the Hub.
state_path = hf_hub_download(
source,
step_entry["state_file"],
repo_id=model_id,
filename=step_entry["state_file"],
repo_type="model",
force_download=force_download,
resume_download=resume_download,
+7 -4
View File
@@ -1736,12 +1736,14 @@ def test_from_pretrained_nonexistent_path():
with pytest.raises(FileNotFoundError):
DataProcessorPipeline.from_pretrained("/path/that/does/not/exist")
# Test with a Hub repo format that would be a local path (too many slashes)
# Test with a path that doesn't exist as a directory
with pytest.raises(FileNotFoundError):
DataProcessorPipeline.from_pretrained("user/repo/extra/path")
# Test with a non-existent but valid Hub repo format (now requires config_filename)
with pytest.raises(ValueError, match="you must specify the config_filename parameter"):
# Test with a Hub repo without specifying config_filename (should raise ValueError)
with pytest.raises(
ValueError, match="When loading from Hugging Face Hub, 'config_filename' must be specified"
):
DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo")
# Test with a non-existent Hub repo when config_filename is provided
@@ -1752,7 +1754,8 @@ def test_from_pretrained_nonexistent_path():
# Test with a local directory that exists but has no config files
with tempfile.TemporaryDirectory() as tmp_dir:
with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
# The directory exists but has no config, so loading falls through to the local-path check, which raises FileNotFoundError
with pytest.raises(FileNotFoundError):
DataProcessorPipeline.from_pretrained(tmp_dir)