From 1ccdf365d21f9ef37af20179c381be80a6748179 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Fri, 12 Sep 2025 17:54:27 +0200
Subject: [PATCH] docs(processor): update docstrings pipeline (#1920)

---
 src/lerobot/processor/pipeline.py | 809 +++++++++++++-----------------
 1 file changed, 351 insertions(+), 458 deletions(-)

diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py
index 644fca180..5f940a77d 100644
--- a/src/lerobot/processor/pipeline.py
+++ b/src/lerobot/processor/pipeline.py
@@ -13,11 +13,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""
+This module defines a generic, sequential data processing pipeline framework, primarily designed for
+transforming robotics data (observations, actions, rewards, etc.).
+
+The core components are:
+- ProcessorStep: An abstract base class for a single data transformation operation.
+- ProcessorStepRegistry: A mechanism to register and retrieve ProcessorStep classes by name.
+- DataProcessorPipeline: A class that chains multiple ProcessorStep instances together to form a complete
+  data processing workflow. It integrates with the Hugging Face Hub for easy sharing and versioning of
+  pipelines, including their configuration and state.
+- Specialized abstract ProcessorStep subclasses (e.g., ObservationProcessorStep, ActionProcessorStep)
+  to simplify the creation of steps that target specific parts of a data transition.
+"""
+
 from __future__ import annotations
 
 import importlib
 import json
 import os
+import re
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Sequence
 from copy import deepcopy
@@ -34,30 +50,38 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
 from .converters import batch_to_transition, create_transition, transition_to_batch
 from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
 
-# Type variables for generic processor input and output types
+# Generic type variables for pipeline input and output.
 TInput = TypeVar("TInput")
 TOutput = TypeVar("TOutput")
 
 
 class ProcessorStepRegistry:
-    """Registry for processor steps that enables saving/loading by name instead of module path."""
+    """A registry for ProcessorStep classes to allow instantiation from a string name.
+
+    This class provides a way to map string identifiers to `ProcessorStep` classes,
+    which is useful for deserializing pipelines from configuration files without
+    hardcoding class imports. Registered steps can then be saved and loaded by
+    name instead of by full module path.
+    """
 
     _registry: dict[str, type] = {}
 
     @classmethod
     def register(cls, name: str = None):
-        """Decorator to register a processor step class.
+        """A class decorator to register a ProcessorStep.
 
         Args:
-            name: Optional registration name. If not provided, uses class name.
+            name: The name to register the class under. If None, the class's `__name__` is used.
 
-        Example:
-            @ProcessorStepRegistry.register("adaptive_normalizer")
-            class AdaptiveObservationNormalizer:
-                ...
+        Returns:
+            A decorator function that registers the class and returns it.
+
+        Raises:
+            ValueError: If a step with the same name is already registered.
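+
+        Example:
+            @ProcessorStepRegistry.register("adaptive_normalizer")
+            class AdaptiveObservationNormalizer:
+                ...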
""" def decorator(step_class: type) -> type: + """The actual decorator that performs the registration.""" registration_name = name if name is not None else step_class.__name__ if registration_name in cls._registry: @@ -67,7 +91,7 @@ class ProcessorStepRegistry: ) cls._registry[registration_name] = step_class - # Store the registration name on the class for later reference + # Store the registration name on the class for easy lookup during serialization. step_class._registry_name = registration_name return step_class @@ -75,16 +99,16 @@ class ProcessorStepRegistry: @classmethod def get(cls, name: str) -> type: - """Get a registered processor step class by name. + """Retrieves a processor step class from the registry by its name. Args: - name: The registration name of the step. + name: The name of the step to retrieve. Returns: - The registered step class. + The processor step class corresponding to the given name. Raises: - KeyError: If the step is not registered. + KeyError: If the name is not found in the registry. """ if name not in cls._registry: available = list(cls._registry.keys()) @@ -97,87 +121,113 @@ class ProcessorStepRegistry: @classmethod def unregister(cls, name: str) -> None: - """Remove a step from the registry.""" + """Removes a processor step from the registry. + + Args: + name: The name of the step to unregister. + """ cls._registry.pop(name, None) @classmethod def list(cls) -> list[str]: - """List all registered step names.""" + """Returns a list of all registered processor step names.""" return list(cls._registry.keys()) @classmethod def clear(cls) -> None: - """Clear all registrations.""" + """Clears all processor steps from the registry.""" cls._registry.clear() class ProcessorStep(ABC): - """Structural typing interface for a single processor step. + """Abstract base class for a single step in a data processing pipeline. - A step is any callable accepting a full `EnvTransition` dict and - returning a (possibly modified) dict of the same structure. Implementers - are encouraged—but not required—to expose the optional helper methods - listed below. When present, these hooks let `DataProcessorPipeline` - automatically serialise the step's configuration and learnable state using - a safe-to-share JSON + SafeTensors format. + Each step must implement the `__call__` method to perform its transformation + on a data transition and the `transform_features` method to describe how it + alters the shape or type of data features. - - **Required**: - - ``__call__(transition: EnvTransition) -> EnvTransition`` - - Optional helper protocol: - * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable - configuration and state. YOU decide what to save here. This is where all - non-tensor state goes (e.g., name, counter, threshold, window_size). - The config dict will be passed to your class constructor when loading. - * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. - This is exclusively for torch.Tensor objects (e.g., learned weights, - running statistics as tensors). Never put simple Python types here. - * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict - containing torch tensors only. - * ``reset()`` – Clear internal buffers at episode boundaries. - * ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` - If present, this method will be called to aggregate the dataset features of all steps. 
- - Example separation: - - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} - - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + Subclasses can optionally be stateful by implementing `state_dict` and `load_state_dict`. """ _current_transition: EnvTransition | None = None @property def transition(self) -> EnvTransition: - """The current transition being processed by this step.""" + """Provides access to the most recent transition being processed. + + This is useful for steps that need to access other parts of the transition + data beyond their primary target (e.g., an action processing step that + needs to look at the observation). + + Raises: + ValueError: If accessed before the step has been called with a transition. + """ if self._current_transition is None: raise ValueError("Transition is not set. Make sure to call the step with a transition first.") return self._current_transition @abstractmethod def __call__(self, transition: EnvTransition) -> EnvTransition: + """Processes an environment transition. + + This method should contain the core logic of the processing step. + + Args: + transition: The input data transition to be processed. + + Returns: + The processed transition. + """ return transition def get_config(self) -> dict[str, Any]: + """Returns the configuration of the step for serialization. + + Returns: + A JSON-serializable dictionary of configuration parameters. + """ return {} def state_dict(self) -> dict[str, torch.Tensor]: + """Returns the state of the step (e.g., learned parameters, running means). + + Returns: + A dictionary mapping state names to tensors. + """ return {} def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Loads the step's state from a state dictionary. + + Args: + state: A dictionary of state tensors. + """ return None def reset(self) -> None: + """Resets the internal state of the processor step, if any.""" return None @abstractmethod def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Defines how this step modifies the description of pipeline features. + + This method is used to track changes in data shapes, dtypes, or modalities + as data flows through the pipeline, without needing to process actual data. + + Args: + features: A dictionary describing the input features for observations, actions, etc. + + Returns: + A dictionary describing the output features after this step's transformation. + """ return features class ProcessorKwargs(TypedDict, total=False): - """Keyword arguments for DataProcessorPipeline constructor.""" + """A TypedDict for optional keyword arguments used in pipeline construction.""" to_transition: Callable[[dict[str, Any]], EnvTransition] | None to_output: Callable[[EnvTransition], Any] | None @@ -188,65 +238,19 @@ class ProcessorKwargs(TypedDict, total=False): @dataclass class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): - """ - Composable, debuggable post-processing processor for robot transitions. + """A sequential pipeline for processing data, integrated with the Hugging Face Hub. - The class orchestrates an ordered collection of small, functional transforms—steps—executed - left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts - and batch dictionaries, automatically converting between formats as needed. 
+ This class chains together multiple `ProcessorStep` instances to form a complete + data processing workflow. It's generic, allowing for custom input and output types, + which are handled by the `to_transition` and `to_output` converters. - The processor is generic over its output type TOutput, which provides better type safety - and clarity about what the processor returns. - - Args: - steps: Ordered list of processing steps executed on every call. Defaults to empty list. - name: Human-readable identifier that is persisted inside the JSON config. - Defaults to "DataProcessorPipeline". - to_transition: Function to convert batch dict to EnvTransition dict. - Defaults to _default_batch_to_transition. - to_output: Function to convert EnvTransition dict to the desired output format of type TOutput. - Defaults to _default_transition_to_batch (returns batch dict). - Use identity function (lambda x: x) for EnvTransition output. - before_step_hooks: List of hooks called before each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - after_step_hooks: List of hooks called after each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - - Type Safety Examples: - ```python - # Default behavior - returns batch dict - processor: DataProcessorPipeline[dict[str, Any]] = DataProcessorPipeline( - steps=[some_step1, some_step2] - ) - result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict - - # For EnvTransition output, explicitly specify identity function - transition_processor: DataProcessorPipeline[EnvTransition, EnvTransition] = DataProcessorPipeline( - steps=[some_step1, some_step2], - to_output=lambda x: x, # Identity function - ) - result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition - - # For custom output types - processor: DataProcessorPipeline[dict[str, Any], str] = DataProcessorPipeline( - steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys" - ) - result: str = processor(batch_data) # Type checker knows this is str - ``` - - Hook Semantics: - - Hooks are executed sequentially in the order they were registered. There is no way to - reorder hooks after registration without creating a new pipeline. - - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called - with the step index and current transition for logging, debugging, or monitoring purposes. - - All hooks for a given type (before/after) are executed for every step, or none at all if - an error occurs. There is no partial execution of hooks. - - Hooks should generally be stateless to maintain predictable behavior. If you need stateful - processing, consider implementing a proper ProcessorStep instead. - - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. - - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format - passed to __call__. This ensures consistent hook behavior whether processing batch dicts - or EnvTransition objects. + Attributes: + steps: A sequence of `ProcessorStep` objects that make up the pipeline. + name: A descriptive name for the pipeline. + to_transition: A function to convert raw input data into the standardized `EnvTransition` format. + to_output: A function to convert the final `EnvTransition` into the desired output format. + before_step_hooks: A list of functions to be called before each step is executed. 
+ after_step_hooks: A list of functions to be called after each step is executed. """ steps: Sequence[ProcessorStep] = field(default_factory=list) @@ -256,106 +260,92 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False ) to_output: Callable[[EnvTransition], TOutput] = field( - # Cast is necessary here: Working around Python type-checker limitation. - # _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput - # for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any], - # making this cast safe. default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch), repr=False, ) - # Processor-level hooks for observation/monitoring - # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) def __call__(self, data: TInput) -> TOutput: - """Process data through all steps. - - The method accepts a batch dictionary (like the ones returned by ReplayBuffer or - LeRobotDataset). It is first converted to EnvTransition format using to_transition, - then processed through all steps, and finally converted to the output format using to_output. + """Processes input data through the full pipeline. Args: - data: A batch dictionary to process. + data: The input data to process. Returns: - The processed data in the format specified by to_output. + The processed data in the specified output format. """ - # Always convert input through to_transition transition = self.to_transition(data) - transformed_transition = self._forward(transition) - - # Always use to_output for consistent typing return self.to_output(transformed_transition) def _forward(self, transition: EnvTransition) -> EnvTransition: - # Process through all steps + """Executes all processing steps and hooks in sequence. + + Args: + transition: The initial `EnvTransition` object. + + Returns: + The final `EnvTransition` after all steps have been applied. + """ for idx, processor_step in enumerate(self.steps): - # Apply before hooks + # Execute pre-hooks for hook in self.before_step_hooks: hook(idx, transition) - # Execute step transition = processor_step(transition) - # Apply after hooks + # Execute post-hooks for hook in self.after_step_hooks: hook(idx, transition) return transition def step_through(self, data: TInput) -> Iterable[EnvTransition]: - """Yield the intermediate results after each processor step. + """Processes data step-by-step, yielding the transition at each stage. - This is a low-level method that does NOT apply hooks. It simply executes each step - and yields the intermediate results. This allows users to debug the pipeline or - apply custom logic between steps if needed. - - Note: This method always yields EnvTransition objects regardless of output format. - If you need the results in the output format, you'll need to convert them - using `to_output()`. + This is a generator method useful for debugging and inspecting the intermediate + state of the data as it passes through the pipeline. Args: - data: A batch dictionary to process. + data: The input data. Yields: - The intermediate EnvTransition results after each step. 
+ The `EnvTransition` object, starting with the initial state and then after + each processing step. """ - # Always convert input through to_transition transition = self.to_transition(data) - # Yield initial state + # Yield the initial state before any processing. yield transition - # Process each step WITHOUT hooks (low-level method) for processor_step in self.steps: transition = processor_step(transition) yield transition def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal save method for ModelHubMixin compatibility.""" - # Extract config_filename from kwargs if provided + """Internal method to comply with `ModelHubMixin`'s saving mechanism.""" config_filename = kwargs.pop("config_filename", None) self.save_pretrained(save_directory, config_filename=config_filename) def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): - """Serialize the processor definition and parameters to *save_directory*. + """Saves the pipeline's configuration and state to a directory. + + This method creates a JSON configuration file that defines the pipeline's structure + (name and steps). For each stateful step, it also saves a `.safetensors` file + containing its state dictionary. Args: - save_directory: Directory where the processor will be saved. - config_filename: Optional custom config filename. If not provided, defaults to - "{self.name}.json" where self.name is sanitized for filesystem compatibility. + save_directory: The directory where the pipeline will be saved. + config_filename: The name of the JSON configuration file. If None, a name is + generated from the pipeline's `name` attribute. + **kwargs: Additional arguments (not used, but present for compatibility). """ os.makedirs(str(save_directory), exist_ok=True) - # Sanitize processor name for use in filenames - import re - - # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + # Sanitize the pipeline name to create a valid filename prefix. sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) - # Use sanitized name for config if not provided if config_filename is None: config_filename = f"{sanitized_name}.json" @@ -364,40 +354,31 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): "steps": [], } + # Iterate through each step to build its configuration entry. for step_index, processor_step in enumerate(self.steps): - # Check if step was registered registry_name = getattr(processor_step.__class__, "_registry_name", None) step_entry: dict[str, Any] = {} + # Prefer registry name for portability, otherwise fall back to full class path. if registry_name: - # Use registry name for registered steps step_entry["registry_name"] = registry_name else: - # Fall back to full module path for unregistered steps step_entry["class"] = ( f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" ) + # Save step configuration if `get_config` is implemented. if hasattr(processor_step, "get_config"): step_entry["config"] = processor_step.get_config() + # Save step state if `state_dict` is implemented and returns a non-empty dict. 
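+            # Steps whose state_dict is empty get a config entry but no .safetensors file.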
            if hasattr(processor_step, "state_dict"):
                state = processor_step.state_dict()
                if state:
-                    # Clone tensors to avoid shared memory issues
-                    # This ensures each tensor has its own memory allocation
-                    # The reason is to avoid the following error:
-                    # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk
-                    # and potential differences when loading them again
-                    # ------------------------------------------------------------------------------
-                    # Since the state_dict of processor will be light, we can just clone the tensors
-                    # and save them to the disk.
-                    cloned_state = {}
-                    for key, tensor in state.items():
-                        cloned_state[key] = tensor.clone()
+                    # Clone tensors so each owns its memory; safetensors refuses to save tensors that share storage.
+                    cloned_state = {key: tensor.clone() for key, tensor in state.items()}
 
-                    # Include pipeline name and step index to ensure unique filenames
-                    # This prevents conflicts when multiple processors are saved in the same directory
+                    # Unique filename (pipeline name + step index) so several pipelines can share one directory.
                    if registry_name:
                        state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
                    else:
@@ -408,6 +389,7 @@
        config["steps"].append(step_entry)
 
+        # Write the main configuration JSON file.
        with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
            json.dump(config, file_pointer, indent=2)
 
@@ -429,88 +411,62 @@
        to_output: Callable[[EnvTransition], TOutput] | None = None,
        **kwargs,
    ) -> DataProcessorPipeline[TInput, TOutput]:
-        """Load a serialized processor from source (local path or Hugging Face Hub identifier).
+        """Loads a pipeline from a local directory or a Hugging Face Hub repository.
+
+        This method reconstructs a `DataProcessorPipeline` by:
+        1. Loading the main JSON configuration file.
+        2. Iterating through the steps defined in the config.
+        3. Dynamically importing or looking up each step's class.
+        4. Instantiating each step with its saved configuration, potentially with overrides.
+        5. Loading the step's state from its `.safetensors` file, if it exists.
 
        Args:
-            pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier
-                (e.g., "username/processor-name").
-            config_filename: Optional specific config filename to load. If not provided, will:
-                - For local paths: look for any .json file in the directory (error if multiple found)
-                - For HF Hub: REQUIRED - you must specify the exact config filename
-            overrides: Optional dictionary mapping step names to configuration overrides.
-                Keys must match exact step class names (for unregistered steps) or registry names
-                (for registered steps). Values are dictionaries containing parameter overrides
-                that will be merged with the saved configuration. This is useful for providing
-                non-serializable objects like environment instances.
-            to_transition: Function to convert batch dict to EnvTransition dict.
-                Defaults to _default_batch_to_transition.
-            to_output: Function to convert EnvTransition dict to the desired output format of type T.
-                Defaults to _default_transition_to_batch (returns batch dict).
-                Use identity function (lambda x: x) for EnvTransition output.
+            pretrained_model_name_or_path: The identifier of the repository on the Hugging Face Hub
+                or a path to a local directory.
+            force_download: Whether to force (re)downloading the files.
+            resume_download: Whether to resume a previously interrupted download.
+ proxies: A dictionary of proxy servers to use. + token: The token to use as HTTP bearer authorization for private Hub repositories. + cache_dir: The path to a specific cache folder to store downloaded files. + local_files_only: If True, avoid downloading files from the Hub. + revision: The specific model version to use (e.g., a branch name, tag name, or commit id). + config_filename: The name of the pipeline's JSON configuration file. Required when + loading from the Hub. If loading from a local directory, it's inferred if there's + only one `.json` file. + overrides: A dictionary to override the configuration of specific steps. Keys should + match the step's class name or registry name. + to_transition: A custom function to convert input data to `EnvTransition`. + to_output: A custom function to convert the final `EnvTransition` to the output format. + **kwargs: Additional arguments (not used). Returns: - A DataProcessorPipeline[TInput, TOutput] instance loaded from the saved configuration. + An instance of `DataProcessorPipeline` loaded with the specified configuration and state. Raises: - ImportError: If a processor step class cannot be loaded or imported. - ValueError: If a step cannot be instantiated with the provided configuration. - KeyError: If an override key doesn't match any step in the saved configuration. - - Examples: - Basic loading: - ```python - processor = DataProcessorPipeline.from_pretrained("path/to/processor") - ``` - - Loading from HF Hub (config_filename required): - ```python - processor = DataProcessorPipeline.from_pretrained( - "username/processor-repo", config_filename="processor.json" - ) - ``` - - Loading with overrides for non-serializable objects: - ```python - import gym - - env = gym.make("CartPole-v1") - processor = DataProcessorPipeline.from_pretrained( - "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} - ) - ``` - - Multiple overrides: - ```python - processor = DataProcessorPipeline.from_pretrained( - "path/to/processor", - overrides={ - "CustomStep": {"param1": "new_value"}, - "device_processor": {"device": "cuda:1"}, # For registered steps - }, - ) - ``` + FileNotFoundError: If the config file cannot be found. + ValueError: If configuration is ambiguous or instantiation fails. + ImportError: If a step's class cannot be imported. + KeyError: If an override key doesn't match any step in the pipeline. """ - # Use the local variable name 'source' for clarity source = str(pretrained_model_name_or_path) - # Check if it's a local path (either exists or looks like a filesystem path) - # Hub repositories are typically in the format "username/repo-name" (exactly one slash) - # Local paths are absolute paths, relative paths, or have more complex path structure + # Heuristic to distinguish a local path from a Hub repository ID. is_local_path = ( Path(source).is_dir() or Path(source).is_absolute() or source.startswith("./") or source.startswith("../") - or source.count("/") > 1 # More than one slash suggests local path, not Hub repo - or "\\" in source # Windows-style paths are definitely local + # A simple heuristic: repo IDs usually don't have more than one slash. + or source.count("/") > 1 + or "\\" in source ) + # Load configuration from a local directory. if is_local_path: - # Local path - use it directly base_path = Path(source) + # If config filename is not provided, try to find a unique .json file. 
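+            # Zero or multiple candidate .json files cannot be disambiguated and raise an error below.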
if config_filename is None: - # Look for any .json file in the directory json_files = list(base_path.glob("*.json")) if len(json_files) == 0: raise FileNotFoundError(f"No .json configuration files found in {source}") @@ -523,14 +479,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): with open(base_path / config_filename) as file_pointer: loaded_config: dict[str, Any] = json.load(file_pointer) + # Load configuration from the Hugging Face Hub. else: - # Hugging Face Hub - download specific config file if config_filename is None: raise ValueError( f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. " f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')" ) - + # Download the configuration file from the Hub. config_path = hf_hub_download( source, config_filename, @@ -547,32 +503,28 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): with open(config_path) as file_pointer: loaded_config = json.load(file_pointer) - # Store downloaded files in the same directory as the config + # The base path for other files (like state tensors) is the directory of the config file. base_path = Path(config_path).parent - # Handle None overrides if overrides is None: overrides = {} - # Validate that all override keys will be matched override_keys = set(overrides.keys()) steps: list[ProcessorStep] = [] for step_entry in loaded_config["steps"]: - # Check if step uses registry name or module path + # Determine the step class, prioritizing the registry. if "registry_name" in step_entry: - # Load from registry try: step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) step_key = step_entry["registry_name"] except KeyError as e: raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e else: - # Fall back to module path loading for backward compatibility + # Fallback to dynamic import using the full class path. full_class_path = step_entry["class"] module_path, class_name = full_class_path.rsplit(".", 1) - # Import the module containing the step class try: module = importlib.import_module(module_path) step_class = getattr(module, class_name) @@ -585,14 +537,13 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): f"Error: {str(e)}" ) from e - # Instantiate the step with its config + # Instantiate the step, merging saved config with user-provided overrides. try: saved_cfg = step_entry.get("config", {}) step_overrides = overrides.get(step_key, {}) merged_cfg = {**saved_cfg, **step_overrides} step_instance: ProcessorStep = step_class(**merged_cfg) - # Track which override keys were used if step_key in override_keys: override_keys.discard(step_key) @@ -603,13 +554,12 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): f"Error: {str(e)}" ) from e - # Load state if available + # Load the step's state if a state file is specified. if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): - if Path(source).is_dir(): - # Local path - read directly + if is_local_path: state_path = str(base_path / step_entry["state_file"]) else: - # Hugging Face Hub - download the state file + # Download the state file from the Hub. 
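+                # hf_hub_download returns the local path of the cached file.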
state_path = hf_hub_download( source, step_entry["state_file"], @@ -627,16 +577,12 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): steps.append(step_instance) - # Check for unused override keys + # Check for any unused override keys, which likely indicates a typo by the user. if override_keys: - available_keys = [] - for step_entry in loaded_config["steps"]: - if "registry_name" in step_entry: - available_keys.append(step_entry["registry_name"]) - else: - full_class_path = step_entry["class"] - class_name = full_class_path.rsplit(".", 1)[1] - available_keys.append(class_name) + available_keys = [ + step.get("registry_name") or step["class"].rsplit(".", 1)[1] + for step in loaded_config["steps"] + ] raise KeyError( f"Override keys {list(override_keys)} do not match any step in the saved configuration. " @@ -644,26 +590,30 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): f"Make sure override keys match exact step class names or registry names." ) + # Construct and return the final pipeline instance. return cls( steps=steps, name=loaded_config.get("name", "DataProcessorPipeline"), to_transition=to_transition or batch_to_transition, - # Cast is necessary here: Same type-checker limitation as above. - # When to_output is None, we use the default which returns dict[str, Any]. - # The cast ensures type consistency with the generic TOutput parameter. to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), ) def __len__(self) -> int: - """Return the number of steps in the processor.""" + """Returns the number of steps in the pipeline.""" return len(self.steps) def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]: - """Indexing helper exposing underlying steps. - * ``int`` – returns the idx-th ProcessorStep. - * ``slice`` – returns a new DataProcessorPipeline with the sliced steps. + """Retrieves a step or a sub-pipeline by index or slice. + + Args: + idx: An integer index or a slice object. + + Returns: + A `ProcessorStep` if `idx` is an integer, or a new `DataProcessorPipeline` + containing the sliced steps. """ if isinstance(idx, slice): + # Return a new pipeline instance with the sliced steps. return DataProcessorPipeline( steps=self.steps[idx], name=self.name, @@ -675,17 +625,21 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): return self.steps[idx] def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed before every processor step.""" + """Registers a function to be called before each step. + + Args: + fn: A callable that accepts the step index and the current transition. + """ self.before_step_hooks.append(fn) def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered before_step hook. + """Unregisters a 'before_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.before_step_hooks.remove(fn) @@ -695,17 +649,21 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): ) from None def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed after every processor step.""" + """Registers a function to be called after each step. 
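+
+        Hooks are intended for logging, debugging, or monitoring; their return
+        values are ignored and they should not modify the transition.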
+ + Args: + fn: A callable that accepts the step index and the current transition. + """ self.after_step_hooks.append(fn) def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered after_step hook. + """Unregisters an 'after_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.after_step_hooks.remove(fn) @@ -715,13 +673,13 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): ) from None def reset(self): - """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + """Resets the state of all stateful steps in the pipeline.""" for step in self.steps: if hasattr(step, "reset"): - step.reset() # type: ignore[attr-defined] + step.reset() def __repr__(self) -> str: - """Return a readable string representation of the processor.""" + """Provides a concise string representation of the pipeline.""" step_names = [step.__class__.__name__ for step in self.steps] if not step_names: @@ -729,7 +687,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): elif len(step_names) <= 3: steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]" else: - # Show first 2 and last 1 with ellipsis for long lists + # For long pipelines, show the first, second, and last steps. displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}" steps_repr = f"steps={len(step_names)}: [{displayed}]" @@ -738,6 +696,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): return f"DataProcessorPipeline({', '.join(parts)})" def __post_init__(self): + """Validates that all provided steps are instances of `ProcessorStep`.""" for i, step in enumerate(self.steps): if not isinstance(step, ProcessorStep): raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep") @@ -745,9 +704,17 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): def transform_features( self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - """ - Apply ALL steps in order. Only if a step has a features method, it will be called. - We aggregate the dataset features of all steps. + """Applies feature transformations from all steps sequentially. + + This method propagates a feature description dictionary through each step's + `transform_features` method, allowing the pipeline to statically determine + the output feature specification without processing any real data. + + Args: + initial_features: A dictionary describing the initial features. + + Returns: + The final feature description after all transformations. """ features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features) @@ -756,7 +723,16 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): features = out return features + # Convenience methods for processing individual parts of a transition. def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """Processes only the observation part of a transition through the pipeline. + + Args: + observation: The observation dictionary. + + Returns: + The processed observation dictionary. 
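+
+        Example (illustrative; assumes a built `pipeline` and a state observation):
+            ```python
+            obs = {"observation.state": torch.zeros(6)}
+            processed = pipeline.process_observation(obs)
+            ```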
+ """ transition: EnvTransition = create_transition(observation=observation) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.OBSERVATION] @@ -764,74 +740,106 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): def process_action( self, action: PolicyAction | RobotAction | EnvAction ) -> PolicyAction | RobotAction | EnvAction: + """Processes only the action part of a transition through the pipeline. + + Args: + action: The action data. + + Returns: + The processed action. + """ transition: EnvTransition = create_transition(action=action) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.ACTION] def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor: + """Processes only the reward part of a transition through the pipeline. + + Args: + reward: The reward value. + + Returns: + The processed reward. + """ transition: EnvTransition = create_transition(reward=reward) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.REWARD] def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the done flag of a transition through the pipeline. + + Args: + done: The done flag. + + Returns: + The processed done flag. + """ transition: EnvTransition = create_transition(done=done) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.DONE] def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the truncated flag of a transition through the pipeline. + + Args: + truncated: The truncated flag. + + Returns: + The processed truncated flag. + """ transition: EnvTransition = create_transition(truncated=truncated) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.TRUNCATED] def process_info(self, info: dict[str, Any]) -> dict[str, Any]: + """Processes only the info dictionary of a transition through the pipeline. + + Args: + info: The info dictionary. + + Returns: + The processed info dictionary. + """ transition: EnvTransition = create_transition(info=info) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.INFO] def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]: + """Processes only the complementary data part of a transition through the pipeline. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The processed complementary data dictionary. + """ transition: EnvTransition = create_transition(complementary_data=complementary_data) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.COMPLEMENTARY_DATA] +# Type aliases for semantic clarity. RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] class ObservationProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the observation component of a transition. - - Subclasses should override the `observation` method to implement custom observation processing. - This class handles the boilerplate of extracting and reinserting the processed observation - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
- - Example: - ```python - class MyObservationScaler(ObservationProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def observation(self, observation): - return observation * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific observation processing logic. - """ + """An abstract `ProcessorStep` that specifically targets the observation in a transition.""" @abstractmethod def observation(self, observation) -> dict[str, Any]: - """Process the observation component. + """Processes an observation dictionary. Subclasses must implement this method. Args: - observation: The observation to process + observation: The input observation dictionary from the transition. Returns: - The processed observation + The processed observation dictionary. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `observation` method to the transition's observation.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -845,42 +853,24 @@ class ObservationProcessorStep(ProcessorStep, ABC): class ActionProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the action component of a transition. - - Subclasses should override the `action` method to implement custom action processing. - This class handles the boilerplate of extracting and reinserting the processed action - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class ActionClipping(ActionProcessor): - def __init__(self, min_val, max_val): - self.min_val = min_val - self.max_val = max_val - - def action(self, action): - return np.clip(action, self.min_val, self.max_val) - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific action processing logic. - """ + """An abstract `ProcessorStep` that specifically targets the action in a transition.""" @abstractmethod def action( self, action: PolicyAction | RobotAction | EnvAction ) -> PolicyAction | RobotAction | EnvAction: - """Process the action component. + """Processes an action. Subclasses must implement this method. Args: - action: The action to process + action: The input action from the transition. Returns: - The processed action + The processed action. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `action` method to the transition's action.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -890,42 +880,32 @@ class ActionProcessorStep(ProcessorStep, ABC): processed_action = self.action(action) new_transition[TransitionKey.ACTION] = processed_action - raise ValueError("ActionProcessorStep requires an action in the transition.") + return new_transition class RobotActionProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the robot action component of a transition. - - Subclasses should override the `action` method to implement custom robot action processing. - This class handles the boilerplate of extracting and reinserting the processed action - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
-
-
-    By inheriting from this class, you avoid writing repetitive code to handle transition dict
-    manipulation, focusing only on the specific robot action processing logic.
-    """
+    """An abstract `ProcessorStep` for processing a `RobotAction` (a dictionary)."""
 
     @abstractmethod
     def action(self, action: RobotAction) -> RobotAction:
-        """Process the robot action component.
+        """Processes a `RobotAction`. Subclasses must implement this method.
 
         Args:
-            action: The robot action to process
+            action: The input `RobotAction` dictionary.
 
         Returns:
-            The processed robot action
+            The processed `RobotAction`.
         """
         ...
 
     def __call__(self, transition: EnvTransition) -> EnvTransition:
+        """Applies the `action` method to the transition's action, ensuring it's a `RobotAction`."""
         self._current_transition = transition.copy()
         new_transition = self._current_transition
 
         action = new_transition.get(TransitionKey.ACTION)
-        # NOTE: We can't use isinstance(action, RobotAction) because RobotAction is a dict[str, Any]
-        # because Any is generic
         if not isinstance(action, dict):
-            raise ValueError(f"Action should be a RobotAction type got {type(action)}")
+            raise ValueError(f"Action should be a RobotAction type (dict), but got {type(action)}")
 
         processed_action = self.action(action=action)
         new_transition[TransitionKey.ACTION] = processed_action
@@ -933,36 +913,28 @@
 
 
 class PolicyActionProcessorStep(ProcessorStep, ABC):
-    """Base class for processors that modify only the policy action component of a transition.
-
-    Subclasses should override the `action` method to implement custom policy action processing.
-    This class handles the boilerplate of extracting and reinserting the processed action
-    into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
-
-
-    By inheriting from this class, you avoid writing repetitive code to handle transition dict
-    manipulation, focusing only on the specific policy action processing logic.
-    """
+    """An abstract `ProcessorStep` for processing a `PolicyAction` (a tensor)."""
 
     @abstractmethod
     def action(self, action: PolicyAction) -> PolicyAction:
-        """Process the policy action component.
+        """Processes a `PolicyAction`. Subclasses must implement this method.
 
         Args:
-            action: The policy action to process
+            action: The input `PolicyAction`.
 
         Returns:
-            The processed policy action
+            The processed `PolicyAction`.
         """
         ...
 
     def __call__(self, transition: EnvTransition) -> EnvTransition:
+        """Applies the `action` method to the transition's action, ensuring it's a `PolicyAction`."""
         self._current_transition = transition.copy()
         new_transition = self._current_transition
 
         action = new_transition.get(TransitionKey.ACTION)
         if not isinstance(action, PolicyAction):
-            raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+            raise ValueError(f"Action should be a PolicyAction type (tensor), but got {type(action)}")
 
         processed_action = self.action(action)
         new_transition[TransitionKey.ACTION] = processed_action
@@ -970,39 +942,22 @@
 
 
 class RewardProcessorStep(ProcessorStep, ABC):
-    """Base class for processors that modify only the reward component of a transition.
-
-    Subclasses should override the `reward` method to implement custom reward processing.
-    This class handles the boilerplate of extracting and reinserting the processed reward
-    into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
- - Example: - ```python - class RewardScaler(RewardProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def reward(self, reward): - return reward * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific reward processing logic. - """ + """An abstract `ProcessorStep` that specifically targets the reward in a transition.""" @abstractmethod def reward(self, reward) -> float | torch.Tensor: - """Process the reward component. + """Processes a reward. Subclasses must implement this method. Args: - reward: The reward to process + reward: The input reward from the transition. Returns: - The processed reward + The processed reward. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `reward` method to the transition's reward.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -1016,44 +971,22 @@ class RewardProcessorStep(ProcessorStep, ABC): class DoneProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the done flag of a transition. - - Subclasses should override the `done` method to implement custom done flag processing. - This class handles the boilerplate of extracting and reinserting the processed done flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class TimeoutDone(DoneProcessor): - def __init__(self, max_steps): - self.steps = 0 - self.max_steps = max_steps - - def done(self, done): - self.steps += 1 - return done or self.steps >= self.max_steps - - def reset(self): - self.steps = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific done flag processing logic. - """ + """An abstract `ProcessorStep` that specifically targets the 'done' flag in a transition.""" @abstractmethod def done(self, done) -> bool | torch.Tensor: - """Process the done flag. + """Processes a 'done' flag. Subclasses must implement this method. Args: - done: The done flag to process + done: The input 'done' flag from the transition. Returns: - The processed done flag + The processed 'done' flag. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `done` method to the transition's 'done' flag.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -1067,40 +1000,22 @@ class DoneProcessorStep(ProcessorStep, ABC): class TruncatedProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the truncated flag of a transition. - - Subclasses should override the `truncated` method to implement custom truncated flag processing. - This class handles the boilerplate of extracting and reinserting the processed truncated flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class EarlyTruncation(TruncatedProcessor): - def __init__(self, threshold): - self.threshold = threshold - - def truncated(self, truncated): - # Additional truncation condition - return truncated or some_condition > self.threshold - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific truncated flag processing logic. 
- """ + """An abstract `ProcessorStep` that specifically targets the 'truncated' flag in a transition.""" @abstractmethod def truncated(self, truncated) -> bool | torch.Tensor: - """Process the truncated flag. + """Processes a 'truncated' flag. Subclasses must implement this method. Args: - truncated: The truncated flag to process + truncated: The input 'truncated' flag from the transition. Returns: - The processed truncated flag + The processed 'truncated' flag. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `truncated` method to the transition's 'truncated' flag.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -1114,45 +1029,22 @@ class TruncatedProcessorStep(ProcessorStep, ABC): class InfoProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the info dictionary of a transition. - - Subclasses should override the `info` method to implement custom info processing. - This class handles the boilerplate of extracting and reinserting the processed info - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class InfoAugmenter(InfoProcessor): - def __init__(self): - self.step_count = 0 - - def info(self, info): - info = info.copy() # Create a copy to avoid modifying the original - info["steps"] = self.step_count - self.step_count += 1 - return info - - def reset(self): - self.step_count = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific info dictionary processing logic. - """ + """An abstract `ProcessorStep` that specifically targets the 'info' dictionary in a transition.""" @abstractmethod def info(self, info) -> dict[str, Any]: - """Process the info dictionary. + """Processes an 'info' dictionary. Subclasses must implement this method. Args: - info: The info dictionary to process + info: The input 'info' dictionary from the transition. Returns: - The processed info dictionary + The processed 'info' dictionary. """ ... def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `info` method to the transition's 'info' dictionary.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -1166,26 +1058,22 @@ class InfoProcessorStep(ProcessorStep, ABC): class ComplementaryDataProcessorStep(ProcessorStep, ABC): - """Base class for processors that modify only the complementary data of a transition. - - Subclasses should override the `complementary_data` method to implement custom complementary data processing. - This class handles the boilerplate of extracting and reinserting the processed complementary data - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - """ + """An abstract `ProcessorStep` that targets the 'complementary_data' in a transition.""" @abstractmethod def complementary_data(self, complementary_data) -> dict[str, Any]: - """Process the complementary data. + """Processes a 'complementary_data' dictionary. Subclasses must implement this method. Args: - complementary_data: The complementary data to process + complementary_data: The input 'complementary_data' from the transition. Returns: - The processed complementary data + The processed 'complementary_data' dictionary. """ ... 
def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `complementary_data` method to the transition's data.""" self._current_transition = transition.copy() new_transition = self._current_transition @@ -1199,12 +1087,17 @@ class ComplementaryDataProcessorStep(ProcessorStep, ABC): class IdentityProcessorStep(ProcessorStep): - """Identity processor that does nothing.""" + """A no-op processor step that returns the input transition and features unchanged. + + This can be useful as a placeholder or for debugging purposes. + """ def __call__(self, transition: EnvTransition) -> EnvTransition: + """Returns the transition without modification.""" return transition def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Returns the features without modification.""" return features