mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937)
- Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving
This commit is contained in:
@@ -42,10 +42,11 @@ from pathlib import Path
|
|||||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.utils.hub import HubMixin
|
||||||
|
|
||||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||||
@@ -67,7 +68,7 @@ class ProcessorStepRegistry:
|
|||||||
_registry: dict[str, type] = {}
|
_registry: dict[str, type] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str = None):
|
def register(cls, name: str | None = None):
|
||||||
"""A class decorator to register a ProcessorStep.
|
"""A class decorator to register a ProcessorStep.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -237,7 +238,7 @@ class ProcessorKwargs(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||||
|
|
||||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||||
@@ -324,24 +325,11 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
|||||||
yield transition
|
yield transition
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||||
"""Internal method to comply with `ModelHubMixin`'s saving mechanism."""
|
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||||
config_filename = kwargs.pop("config_filename", None)
|
|
||||||
self.save_pretrained(save_directory, config_filename=config_filename)
|
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
|
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||||
"""Saves the pipeline's configuration and state to a directory.
|
|
||||||
|
|
||||||
This method creates a JSON configuration file that defines the pipeline's structure
|
|
||||||
(name and steps). For each stateful step, it also saves a `.safetensors` file
|
|
||||||
containing its state dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
save_directory: The directory where the pipeline will be saved.
|
|
||||||
config_filename: The name of the JSON configuration file. If None, a name is
|
|
||||||
generated from the pipeline's `name` attribute.
|
|
||||||
**kwargs: Additional arguments (not used, but present for compatibility).
|
|
||||||
"""
|
"""
|
||||||
os.makedirs(str(save_directory), exist_ok=True)
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
|
|
||||||
# Sanitize the pipeline name to create a valid filename prefix.
|
# Sanitize the pipeline name to create a valid filename prefix.
|
||||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
@@ -393,6 +381,60 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
|||||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||||
json.dump(config, file_pointer, indent=2)
|
json.dump(config, file_pointer, indent=2)
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: str | Path | None = None,
|
||||||
|
*,
|
||||||
|
repo_id: str | None = None,
|
||||||
|
push_to_hub: bool = False,
|
||||||
|
card_kwargs: dict[str, Any] | None = None,
|
||||||
|
config_filename: str | None = None,
|
||||||
|
**push_to_hub_kwargs,
|
||||||
|
):
|
||||||
|
"""Saves the pipeline's configuration and state to a directory.
|
||||||
|
|
||||||
|
This method creates a JSON configuration file that defines the pipeline's structure
|
||||||
|
(name and steps). For each stateful step, it also saves a `.safetensors` file
|
||||||
|
containing its state dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_directory: The directory where the pipeline will be saved. If None, saves to
|
||||||
|
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
||||||
|
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
||||||
|
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
||||||
|
card_kwargs: Additional arguments passed to the card template to customize the card.
|
||||||
|
config_filename: The name of the JSON configuration file. If None, a name is
|
||||||
|
generated from the pipeline's `name` attribute.
|
||||||
|
**push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method.
|
||||||
|
"""
|
||||||
|
if save_directory is None:
|
||||||
|
# Use default directory in HF_LEROBOT_HOME
|
||||||
|
from lerobot.constants import HF_LEROBOT_HOME
|
||||||
|
|
||||||
|
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
|
save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
|
||||||
|
|
||||||
|
# For direct saves (not through hub), handle config_filename
|
||||||
|
if not push_to_hub and config_filename is not None:
|
||||||
|
# Call _save_pretrained directly with config_filename
|
||||||
|
save_directory = Path(save_directory)
|
||||||
|
save_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._save_pretrained(save_directory, config_filename=config_filename)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Pass config_filename through kwargs for _save_pretrained when using hub
|
||||||
|
if config_filename is not None:
|
||||||
|
push_to_hub_kwargs["config_filename"] = config_filename
|
||||||
|
|
||||||
|
# Call parent's save_pretrained which will call our _save_pretrained
|
||||||
|
return super().save_pretrained(
|
||||||
|
save_directory=save_directory,
|
||||||
|
repo_id=repo_id,
|
||||||
|
push_to_hub=push_to_hub,
|
||||||
|
card_kwargs=card_kwargs,
|
||||||
|
**push_to_hub_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -430,9 +472,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
|||||||
cache_dir: The path to a specific cache folder to store downloaded files.
|
cache_dir: The path to a specific cache folder to store downloaded files.
|
||||||
local_files_only: If True, avoid downloading files from the Hub.
|
local_files_only: If True, avoid downloading files from the Hub.
|
||||||
revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
|
revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
|
||||||
config_filename: The name of the pipeline's JSON configuration file. Required when
|
config_filename: The name of the pipeline's JSON configuration file. If not provided,
|
||||||
loading from the Hub. If loading from a local directory, it's inferred if there's
|
it's auto-detected in local directories (if only one .json file exists). This parameter
|
||||||
only one `.json` file.
|
is mandatory when loading from Hugging Face Hub repositories.
|
||||||
overrides: A dictionary to override the configuration of specific steps. Keys should
|
overrides: A dictionary to override the configuration of specific steps. Keys should
|
||||||
match the step's class name or registry name.
|
match the step's class name or registry name.
|
||||||
to_transition: A custom function to convert input data to `EnvTransition`.
|
to_transition: A custom function to convert input data to `EnvTransition`.
|
||||||
@@ -448,63 +490,85 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
|||||||
ImportError: If a step's class cannot be imported.
|
ImportError: If a step's class cannot be imported.
|
||||||
KeyError: If an override key doesn't match any step in the pipeline.
|
KeyError: If an override key doesn't match any step in the pipeline.
|
||||||
"""
|
"""
|
||||||
source = str(pretrained_model_name_or_path)
|
model_id = str(pretrained_model_name_or_path)
|
||||||
|
loaded_config: dict[str, Any] | None = None
|
||||||
|
base_path: Path | None = None
|
||||||
|
|
||||||
# Heuristic to distinguish a local path from a Hub repository ID.
|
# Standard pattern: try local directory first
|
||||||
is_local_path = (
|
if Path(model_id).is_dir():
|
||||||
Path(source).is_dir()
|
base_path = Path(model_id)
|
||||||
or Path(source).is_absolute()
|
|
||||||
or source.startswith("./")
|
|
||||||
or source.startswith("../")
|
|
||||||
# A simple heuristic: repo IDs usually don't have more than one slash.
|
|
||||||
or source.count("/") > 1
|
|
||||||
or "\\" in source
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load configuration from a local directory.
|
# Handle config filename
|
||||||
if is_local_path:
|
|
||||||
base_path = Path(source)
|
|
||||||
|
|
||||||
# If config filename is not provided, try to find a unique .json file.
|
|
||||||
if config_filename is None:
|
if config_filename is None:
|
||||||
json_files = list(base_path.glob("*.json"))
|
json_files = list(base_path.glob("*.json"))
|
||||||
if len(json_files) == 0:
|
if len(json_files) == 0:
|
||||||
raise FileNotFoundError(f"No .json configuration files found in {source}")
|
# No config files found locally, will try Hub next
|
||||||
elif len(json_files) > 1:
|
pass
|
||||||
|
elif len(json_files) == 1:
|
||||||
|
config_filename = json_files[0].name
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
|
f"Multiple .json files found in {model_id}: {[f.name for f in json_files]}. "
|
||||||
f"Please specify which one to load using the config_filename parameter."
|
f"Please specify which one to load using the config_filename parameter."
|
||||||
)
|
)
|
||||||
config_filename = json_files[0].name
|
|
||||||
|
|
||||||
with open(base_path / config_filename) as file_pointer:
|
# Try to load config from local directory
|
||||||
loaded_config: dict[str, Any] = json.load(file_pointer)
|
if config_filename and (base_path / config_filename).exists():
|
||||||
# Load configuration from the Hugging Face Hub.
|
with open(base_path / config_filename) as f:
|
||||||
else:
|
loaded_config = json.load(f)
|
||||||
if config_filename is None:
|
|
||||||
raise ValueError(
|
# If not found locally, try Hub
|
||||||
f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. "
|
if loaded_config is None:
|
||||||
f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')"
|
# Check if this looks like a local path that doesn't exist
|
||||||
)
|
# Hub repo IDs have format "user/repo" with exactly one slash
|
||||||
# Download the configuration file from the Hub.
|
# Local paths typically have multiple slashes, backslashes, or start with ./ or ../
|
||||||
config_path = hf_hub_download(
|
looks_like_local_path = (
|
||||||
source,
|
model_id.count("/") > 1 # Multiple slashes suggest local path
|
||||||
config_filename,
|
or "\\" in model_id # Backslashes are only in local paths
|
||||||
repo_type="model",
|
or Path(model_id).is_absolute() # Absolute paths are local
|
||||||
force_download=force_download,
|
or model_id.startswith("./")
|
||||||
resume_download=resume_download,
|
or model_id.startswith("../") # Relative path indicators
|
||||||
proxies=proxies,
|
|
||||||
token=token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
revision=revision,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(config_path) as file_pointer:
|
if looks_like_local_path:
|
||||||
loaded_config = json.load(file_pointer)
|
# This appears to be a local path that doesn't exist
|
||||||
|
raise FileNotFoundError(f"Local path '{model_id}' does not exist")
|
||||||
|
# For Hub repositories, config_filename is mandatory
|
||||||
|
if config_filename is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"When loading from Hugging Face Hub, 'config_filename' must be specified. "
|
||||||
|
f"Example: DataProcessorPipeline.from_pretrained('{model_id}', config_filename='processor.json')"
|
||||||
|
)
|
||||||
|
|
||||||
# The base path for other files (like state tensors) is the directory of the config file.
|
try:
|
||||||
base_path = Path(config_path).parent
|
# Download the configuration file from the Hub
|
||||||
|
config_path = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename=config_filename,
|
||||||
|
repo_type="model",
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
loaded_config = json.load(f)
|
||||||
|
|
||||||
|
# The base path for other files (like state tensors) is the directory of the config file
|
||||||
|
base_path = Path(config_path).parent
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find {config_filename} on the HuggingFace Hub at {model_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# At this point, loaded_config must be loaded successfully
|
||||||
|
if loaded_config is None:
|
||||||
|
raise RuntimeError("Failed to load configuration from local directory or Hub")
|
||||||
|
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
@@ -556,13 +620,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]):
|
|||||||
|
|
||||||
# Load the step's state if a state file is specified.
|
# Load the step's state if a state file is specified.
|
||||||
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
||||||
if is_local_path:
|
# Check if state file exists locally first
|
||||||
|
if base_path and (base_path / step_entry["state_file"]).exists():
|
||||||
state_path = str(base_path / step_entry["state_file"])
|
state_path = str(base_path / step_entry["state_file"])
|
||||||
else:
|
else:
|
||||||
# Download the state file from the Hub.
|
# Download the state file from the Hub.
|
||||||
state_path = hf_hub_download(
|
state_path = hf_hub_download(
|
||||||
source,
|
repo_id=model_id,
|
||||||
step_entry["state_file"],
|
filename=step_entry["state_file"],
|
||||||
repo_type="model",
|
repo_type="model",
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
|||||||
@@ -1736,12 +1736,14 @@ def test_from_pretrained_nonexistent_path():
|
|||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
DataProcessorPipeline.from_pretrained("/path/that/does/not/exist")
|
DataProcessorPipeline.from_pretrained("/path/that/does/not/exist")
|
||||||
|
|
||||||
# Test with a Hub repo format that would be a local path (too many slashes)
|
# Test with a path that doesn't exist as a directory
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
DataProcessorPipeline.from_pretrained("user/repo/extra/path")
|
DataProcessorPipeline.from_pretrained("user/repo/extra/path")
|
||||||
|
|
||||||
# Test with a non-existent but valid Hub repo format (now requires config_filename)
|
# Test with a Hub repo without specifying config_filename (should raise ValueError)
|
||||||
with pytest.raises(ValueError, match="you must specify the config_filename parameter"):
|
with pytest.raises(
|
||||||
|
ValueError, match="When loading from Hugging Face Hub, 'config_filename' must be specified"
|
||||||
|
):
|
||||||
DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo")
|
DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo")
|
||||||
|
|
||||||
# Test with a non-existent Hub repo when config_filename is provided
|
# Test with a non-existent Hub repo when config_filename is provided
|
||||||
@@ -1752,7 +1754,8 @@ def test_from_pretrained_nonexistent_path():
|
|||||||
|
|
||||||
# Test with a local directory that exists but has no config files
|
# Test with a local directory that exists but has no config files
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
|
# Since the directory exists but has no config, it will try Hub and fail
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
DataProcessorPipeline.from_pretrained(tmp_dir)
|
DataProcessorPipeline.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user