From 066308ceb8198f95ebfe6c8349b872c1f5279cbe Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 15 Sep 2025 13:13:35 +0200 Subject: [PATCH] refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937) - Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving --- src/lerobot/processor/pipeline.py | 205 ++++++++++++++++++++---------- tests/processor/test_pipeline.py | 11 +- 2 files changed, 142 insertions(+), 74 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index e2f16d2cf..6664ba8a5 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -42,10 +42,11 @@ from pathlib import Path from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast import torch -from huggingface_hub import ModelHubMixin, hf_hub_download +from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey @@ -67,7 +68,7 @@ class ProcessorStepRegistry: _registry: dict[str, type] = {} @classmethod - def register(cls, name: str = None): + def register(cls, name: str | None = None): """A class decorator to register a ProcessorStep. Args: @@ -237,7 +238,7 @@ class ProcessorKwargs(TypedDict, total=False): @dataclass -class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): +class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): """A sequential pipeline for processing data, integrated with the Hugging Face Hub. This class chains together multiple `ProcessorStep` instances to form a complete @@ -324,24 +325,11 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): yield transition def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal method to comply with `ModelHubMixin`'s saving mechanism.""" - config_filename = kwargs.pop("config_filename", None) - self.save_pretrained(save_directory, config_filename=config_filename) + """Internal method to comply with `HubMixin`'s saving mechanism. - def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): - """Saves the pipeline's configuration and state to a directory. - - This method creates a JSON configuration file that defines the pipeline's structure - (name and steps). For each stateful step, it also saves a `.safetensors` file - containing its state dictionary. - - Args: - save_directory: The directory where the pipeline will be saved. - config_filename: The name of the JSON configuration file. If None, a name is - generated from the pipeline's `name` attribute. - **kwargs: Additional arguments (not used, but present for compatibility). + This method does the actual saving work and is called by HubMixin.save_pretrained. """ - os.makedirs(str(save_directory), exist_ok=True) + config_filename = kwargs.pop("config_filename", None) # Sanitize the pipeline name to create a valid filename prefix. sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) @@ -393,6 +381,60 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: json.dump(config, file_pointer, indent=2) + def save_pretrained( + self, + save_directory: str | Path | None = None, + *, + repo_id: str | None = None, + push_to_hub: bool = False, + card_kwargs: dict[str, Any] | None = None, + config_filename: str | None = None, + **push_to_hub_kwargs, + ): + """Saves the pipeline's configuration and state to a directory. + + This method creates a JSON configuration file that defines the pipeline's structure + (name and steps). For each stateful step, it also saves a `.safetensors` file + containing its state dictionary. + + Args: + save_directory: The directory where the pipeline will be saved. If None, saves to + HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. + card_kwargs: Additional arguments passed to the card template to customize the card. + config_filename: The name of the JSON configuration file. If None, a name is + generated from the pipeline's `name` attribute. + **push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method. + """ + if save_directory is None: + # Use default directory in HF_LEROBOT_HOME + from lerobot.constants import HF_LEROBOT_HOME + + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name + + # For direct saves (not through hub), handle config_filename + if not push_to_hub and config_filename is not None: + # Call _save_pretrained directly with config_filename + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + self._save_pretrained(save_directory, config_filename=config_filename) + return None + + # Pass config_filename through kwargs for _save_pretrained when using hub + if config_filename is not None: + push_to_hub_kwargs["config_filename"] = config_filename + + # Call parent's save_pretrained which will call our _save_pretrained + return super().save_pretrained( + save_directory=save_directory, + repo_id=repo_id, + push_to_hub=push_to_hub, + card_kwargs=card_kwargs, + **push_to_hub_kwargs, + ) + @classmethod def from_pretrained( cls, @@ -430,9 +472,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): cache_dir: The path to a specific cache folder to store downloaded files. local_files_only: If True, avoid downloading files from the Hub. revision: The specific model version to use (e.g., a branch name, tag name, or commit id). - config_filename: The name of the pipeline's JSON configuration file. Required when - loading from the Hub. If loading from a local directory, it's inferred if there's - only one `.json` file. + config_filename: The name of the pipeline's JSON configuration file. If not provided, + it's auto-detected in local directories (if only one .json file exists). This parameter + is mandatory when loading from Hugging Face Hub repositories. overrides: A dictionary to override the configuration of specific steps. Keys should match the step's class name or registry name. to_transition: A custom function to convert input data to `EnvTransition`. @@ -448,63 +490,85 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): ImportError: If a step's class cannot be imported. KeyError: If an override key doesn't match any step in the pipeline. """ - source = str(pretrained_model_name_or_path) + model_id = str(pretrained_model_name_or_path) + loaded_config: dict[str, Any] | None = None + base_path: Path | None = None - # Heuristic to distinguish a local path from a Hub repository ID. - is_local_path = ( - Path(source).is_dir() - or Path(source).is_absolute() - or source.startswith("./") - or source.startswith("../") - # A simple heuristic: repo IDs usually don't have more than one slash. - or source.count("/") > 1 - or "\\" in source - ) + # Standard pattern: try local directory first + if Path(model_id).is_dir(): + base_path = Path(model_id) - # Load configuration from a local directory. - if is_local_path: - base_path = Path(source) - - # If config filename is not provided, try to find a unique .json file. + # Handle config filename if config_filename is None: json_files = list(base_path.glob("*.json")) if len(json_files) == 0: - raise FileNotFoundError(f"No .json configuration files found in {source}") - elif len(json_files) > 1: + # No config files found locally, will try Hub next + pass + elif len(json_files) == 1: + config_filename = json_files[0].name + else: raise ValueError( - f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " + f"Multiple .json files found in {model_id}: {[f.name for f in json_files]}. " f"Please specify which one to load using the config_filename parameter." ) - config_filename = json_files[0].name - with open(base_path / config_filename) as file_pointer: - loaded_config: dict[str, Any] = json.load(file_pointer) - # Load configuration from the Hugging Face Hub. - else: - if config_filename is None: - raise ValueError( - f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. " - f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')" - ) - # Download the configuration file from the Hub. - config_path = hf_hub_download( - source, - config_filename, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, + # Try to load config from local directory + if config_filename and (base_path / config_filename).exists(): + with open(base_path / config_filename) as f: + loaded_config = json.load(f) + + # If not found locally, try Hub + if loaded_config is None: + # Check if this looks like a local path that doesn't exist + # Hub repo IDs have format "user/repo" with exactly one slash + # Local paths typically have multiple slashes, backslashes, or start with ./ or ../ + looks_like_local_path = ( + model_id.count("/") > 1 # Multiple slashes suggest local path + or "\\" in model_id # Backslashes are only in local paths + or Path(model_id).is_absolute() # Absolute paths are local + or model_id.startswith("./") + or model_id.startswith("../") # Relative path indicators ) - with open(config_path) as file_pointer: - loaded_config = json.load(file_pointer) + if looks_like_local_path: + # This appears to be a local path that doesn't exist + raise FileNotFoundError(f"Local path '{model_id}' does not exist") + # For Hub repositories, config_filename is mandatory + if config_filename is None: + raise ValueError( + f"When loading from Hugging Face Hub, 'config_filename' must be specified. " + f"Example: DataProcessorPipeline.from_pretrained('{model_id}', config_filename='processor.json')" + ) - # The base path for other files (like state tensors) is the directory of the config file. - base_path = Path(config_path).parent + try: + # Download the configuration file from the Hub + config_path = hf_hub_download( + repo_id=model_id, + filename=config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + with open(config_path) as f: + loaded_config = json.load(f) + + # The base path for other files (like state tensors) is the directory of the config file + base_path = Path(config_path).parent + + except Exception as e: + raise FileNotFoundError( + f"Could not find {config_filename} on the HuggingFace Hub at {model_id}" + ) from e + + # At this point, loaded_config must be loaded successfully + if loaded_config is None: + raise RuntimeError("Failed to load configuration from local directory or Hub") if overrides is None: overrides = {} @@ -556,13 +620,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): # Load the step's state if a state file is specified. if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): - if is_local_path: + # Check if state file exists locally first + if base_path and (base_path / step_entry["state_file"]).exists(): state_path = str(base_path / step_entry["state_file"]) else: # Download the state file from the Hub. state_path = hf_hub_download( - source, - step_entry["state_file"], + repo_id=model_id, + filename=step_entry["state_file"], repo_type="model", force_download=force_download, resume_download=resume_download, diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 3fe483df5..4ee07e3a9 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -1736,12 +1736,14 @@ def test_from_pretrained_nonexistent_path(): with pytest.raises(FileNotFoundError): DataProcessorPipeline.from_pretrained("/path/that/does/not/exist") - # Test with a Hub repo format that would be a local path (too many slashes) + # Test with a path that doesn't exist as a directory with pytest.raises(FileNotFoundError): DataProcessorPipeline.from_pretrained("user/repo/extra/path") - # Test with a non-existent but valid Hub repo format (now requires config_filename) - with pytest.raises(ValueError, match="you must specify the config_filename parameter"): + # Test with a Hub repo without specifying config_filename (should raise ValueError) + with pytest.raises( + ValueError, match="When loading from Hugging Face Hub, 'config_filename' must be specified" + ): DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo") # Test with a non-existent Hub repo when config_filename is provided @@ -1752,7 +1754,8 @@ def test_from_pretrained_nonexistent_path(): # Test with a local directory that exists but has no config files with tempfile.TemporaryDirectory() as tmp_dir: - with pytest.raises(FileNotFoundError, match="No .json configuration files found"): + # Since the directory exists but has no config, it will try Hub and fail + with pytest.raises(FileNotFoundError): DataProcessorPipeline.from_pretrained(tmp_dir)