mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
88f7bf01c1
* Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * docs(pipeline): Clarify transition handling and hook behavior - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps. * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * refactor(pipeline): Remove model card generation and streamline processor methods - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability. * refactor(observation): Streamline observation preprocessing and remove unused processor methods - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script. * refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names. - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. * refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
1265 lines
48 KiB
Python
1265 lines
48 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
from __future__ import annotations
|
||
|
||
import importlib
|
||
import json
|
||
import os
|
||
from collections.abc import Callable, Iterable, Sequence
|
||
from copy import deepcopy
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from pathlib import Path
|
||
from typing import Any, Protocol, TypedDict
|
||
|
||
import torch
|
||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||
from huggingface_hub.errors import HfHubHTTPError
|
||
from safetensors.torch import load_file, save_file
|
||
|
||
from lerobot.configs.types import PolicyFeature
|
||
|
||
|
||
class TransitionKey(str, Enum):
|
||
"""Keys for accessing EnvTransition dictionary components."""
|
||
|
||
# TODO(Steven): Use consts
|
||
OBSERVATION = "observation"
|
||
ACTION = "action"
|
||
REWARD = "reward"
|
||
DONE = "done"
|
||
TRUNCATED = "truncated"
|
||
INFO = "info"
|
||
COMPLEMENTARY_DATA = "complementary_data"
|
||
|
||
|
||
EnvTransition = TypedDict(
|
||
"EnvTransition",
|
||
{
|
||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||
},
|
||
)
|
||
|
||
|
||
class ProcessorStepRegistry:
|
||
"""Registry for processor steps that enables saving/loading by name instead of module path."""
|
||
|
||
_registry: dict[str, type] = {}
|
||
|
||
@classmethod
|
||
def register(cls, name: str = None):
|
||
"""Decorator to register a processor step class.
|
||
|
||
Args:
|
||
name: Optional registration name. If not provided, uses class name.
|
||
|
||
Example:
|
||
@ProcessorStepRegistry.register("adaptive_normalizer")
|
||
class AdaptiveObservationNormalizer:
|
||
...
|
||
"""
|
||
|
||
def decorator(step_class: type) -> type:
|
||
registration_name = name if name is not None else step_class.__name__
|
||
|
||
if registration_name in cls._registry:
|
||
raise ValueError(
|
||
f"Processor step '{registration_name}' is already registered. "
|
||
f"Use a different name or unregister the existing one first."
|
||
)
|
||
|
||
cls._registry[registration_name] = step_class
|
||
# Store the registration name on the class for later reference
|
||
step_class._registry_name = registration_name
|
||
return step_class
|
||
|
||
return decorator
|
||
|
||
@classmethod
|
||
def get(cls, name: str) -> type:
|
||
"""Get a registered processor step class by name.
|
||
|
||
Args:
|
||
name: The registration name of the step.
|
||
|
||
Returns:
|
||
The registered step class.
|
||
|
||
Raises:
|
||
KeyError: If the step is not registered.
|
||
"""
|
||
if name not in cls._registry:
|
||
available = list(cls._registry.keys())
|
||
raise KeyError(
|
||
f"Processor step '{name}' not found in registry. "
|
||
f"Available steps: {available}. "
|
||
f"Make sure the step is registered using @ProcessorStepRegistry.register()"
|
||
)
|
||
return cls._registry[name]
|
||
|
||
@classmethod
|
||
def unregister(cls, name: str) -> None:
|
||
"""Remove a step from the registry."""
|
||
cls._registry.pop(name, None)
|
||
|
||
@classmethod
|
||
def list(cls) -> list[str]:
|
||
"""List all registered step names."""
|
||
return list(cls._registry.keys())
|
||
|
||
@classmethod
|
||
def clear(cls) -> None:
|
||
"""Clear all registrations."""
|
||
cls._registry.clear()
|
||
|
||
|
||
class ProcessorStep(Protocol):
|
||
"""Structural typing interface for a single processor step.
|
||
|
||
A step is any callable accepting a full `EnvTransition` dict and
|
||
returning a (possibly modified) dict of the same structure. Implementers
|
||
are encouraged—but not required—to expose the optional helper methods
|
||
listed below. When present, these hooks let `RobotProcessor`
|
||
automatically serialise the step's configuration and learnable state using
|
||
a safe-to-share JSON + SafeTensors format.
|
||
|
||
|
||
**Required**:
|
||
- ``__call__(transition: EnvTransition) -> EnvTransition``
|
||
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
|
||
|
||
Optional helper protocol:
|
||
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||
configuration and state. YOU decide what to save here. This is where all
|
||
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
||
The config dict will be passed to your class constructor when loading.
|
||
* ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
|
||
This is exclusively for torch.Tensor objects (e.g., learned weights,
|
||
running statistics as tensors). Never put simple Python types here.
|
||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||
containing torch tensors only.
|
||
* ``reset()`` – Clear internal buffers at episode boundaries.
|
||
|
||
Example separation:
|
||
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
|
||
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
|
||
"""
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
|
||
|
||
def get_config(self) -> dict[str, Any]: ...
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]: ...
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
|
||
|
||
def reset(self) -> None: ...
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||
|
||
|
||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||
``EnvTransition`` dictionary.
|
||
|
||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||
filled with sane defaults (``None`` or ``0.0``/``False``).
|
||
|
||
Keys recognised (case-sensitive):
|
||
|
||
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||
* "action"
|
||
* "next.reward"
|
||
* "next.done"
|
||
* "next.truncated"
|
||
* "info"
|
||
|
||
Additional keys are ignored so that existing dataloaders can carry extra
|
||
metadata without breaking the processor.
|
||
"""
|
||
|
||
# Extract observation keys
|
||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||
observation = observation_keys if observation_keys else None
|
||
|
||
# Extract padding and task keys for complementary data
|
||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||
|
||
transition: EnvTransition = {
|
||
TransitionKey.OBSERVATION: observation,
|
||
TransitionKey.ACTION: batch.get("action"),
|
||
TransitionKey.REWARD: batch.get("next.reward", 0.0),
|
||
TransitionKey.DONE: batch.get("next.done", False),
|
||
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
|
||
TransitionKey.INFO: batch.get("info", {}),
|
||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||
}
|
||
return transition
|
||
|
||
|
||
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
|
||
"""Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with
|
||
the canonical field names used throughout *LeRobot*.
|
||
"""
|
||
|
||
batch = {
|
||
"action": transition.get(TransitionKey.ACTION),
|
||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||
"next.done": transition.get(TransitionKey.DONE, False),
|
||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||
"info": transition.get(TransitionKey.INFO, {}),
|
||
}
|
||
|
||
# Add padding and task data from complementary_data
|
||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||
if complementary_data:
|
||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||
batch.update(pad_data)
|
||
|
||
if "task" in complementary_data:
|
||
batch["task"] = complementary_data["task"]
|
||
|
||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||
observation = transition.get(TransitionKey.OBSERVATION)
|
||
if isinstance(observation, dict):
|
||
batch.update(observation)
|
||
|
||
return batch
|
||
|
||
|
||
@dataclass
|
||
class RobotProcessor(ModelHubMixin):
|
||
"""
|
||
Composable, debuggable post-processing processor for robot transitions.
|
||
|
||
The class orchestrates an ordered collection of small, functional transforms—steps—executed
|
||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
|
||
and batch dictionaries, automatically converting between formats as needed.
|
||
|
||
Args:
|
||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||
name: Human-readable identifier that is persisted inside the JSON config.
|
||
Defaults to "RobotProcessor".
|
||
to_transition: Function to convert batch dict to EnvTransition dict.
|
||
Defaults to _default_batch_to_transition.
|
||
to_output: Function to convert EnvTransition dict to the desired output format.
|
||
Usually it is a batch dict or EnvTransition dict.
|
||
Defaults to _default_transition_to_batch.
|
||
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
||
index and transition, and can optionally return a modified transition.
|
||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||
index and transition, and can optionally return a modified transition.
|
||
|
||
Hook Semantics:
|
||
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||
reorder hooks after registration without creating a new pipeline.
|
||
- Hooks are for observation/monitoring only and DO NOT modify transitions. They are called
|
||
with the step index and current transition for logging, debugging, or monitoring purposes.
|
||
- All hooks for a given type (before/after) are executed for every step, or none at all if
|
||
an error occurs. There is no partial execution of hooks.
|
||
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
||
processing, consider implementing a proper ProcessorStep instead.
|
||
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
|
||
- Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format
|
||
passed to __call__. This ensures consistent hook behavior whether processing batch dicts
|
||
or EnvTransition objects.
|
||
"""
|
||
|
||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||
name: str = "RobotProcessor"
|
||
|
||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||
default_factory=lambda: _default_batch_to_transition, repr=False
|
||
)
|
||
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
|
||
default_factory=lambda: _default_transition_to_batch, repr=False
|
||
)
|
||
|
||
# Processor-level hooks for observation/monitoring
|
||
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
|
||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||
|
||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||
"""Process data through all steps.
|
||
|
||
The method accepts either the classic EnvTransition dict or a batch dictionary
|
||
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
|
||
it is first converted to the internal dict format using to_transition; after all
|
||
steps are executed the dict is transformed back into a batch dict with to_batch and the
|
||
result is returned – thereby preserving the caller's original data type.
|
||
|
||
Args:
|
||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||
|
||
Returns:
|
||
The processed data in the same format as the input (EnvTransition or batch dict).
|
||
|
||
Raises:
|
||
ValueError: If the transition is not a valid EnvTransition format.
|
||
"""
|
||
# Check if we need to convert back to batch format at the end
|
||
_, called_with_batch = self._prepare_transition(data)
|
||
|
||
# Use step_through to get the iterator
|
||
step_iterator = self.step_through(data)
|
||
|
||
# Get initial state (before any steps)
|
||
current_transition = next(step_iterator)
|
||
|
||
# Process each step with hooks
|
||
for idx, next_transition in enumerate(step_iterator):
|
||
# Apply before hooks with current state (before step execution)
|
||
for hook in self.before_step_hooks:
|
||
hook(idx, current_transition)
|
||
|
||
# Move to next state (after step execution)
|
||
current_transition = next_transition
|
||
|
||
# Apply after hooks with updated state
|
||
for hook in self.after_step_hooks:
|
||
hook(idx, current_transition)
|
||
|
||
# Convert back to original format if needed
|
||
return self.to_output(current_transition) if called_with_batch else current_transition
|
||
|
||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||
"""Prepare and validate transition data for processing.
|
||
|
||
Args:
|
||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||
|
||
Returns:
|
||
A tuple of (prepared_transition, called_with_batch_flag)
|
||
|
||
Raises:
|
||
ValueError: If the transition is not a valid EnvTransition format.
|
||
"""
|
||
# Check if data is already an EnvTransition or needs conversion
|
||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||
# It's a batch dict, convert it
|
||
called_with_batch = True
|
||
transition = self.to_transition(data)
|
||
else:
|
||
# It's already an EnvTransition
|
||
called_with_batch = False
|
||
transition = data
|
||
|
||
# Basic validation
|
||
if not isinstance(transition, dict):
|
||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||
|
||
return transition, called_with_batch
|
||
|
||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
|
||
"""Yield the intermediate results after each processor step.
|
||
|
||
This is a low-level method that does NOT apply hooks. It simply executes each step
|
||
and yields the intermediate results. This allows users to debug the pipeline or
|
||
apply custom logic between steps if needed.
|
||
|
||
Note: This method always yields EnvTransition objects regardless of input format.
|
||
If you need the results in the original input format, you'll need to convert them
|
||
using `to_output()`.
|
||
|
||
Args:
|
||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||
|
||
Yields:
|
||
The intermediate EnvTransition results after each step.
|
||
"""
|
||
transition, _ = self._prepare_transition(data)
|
||
|
||
# Yield initial state
|
||
yield transition
|
||
|
||
# Process each step WITHOUT hooks (low-level method)
|
||
for processor_step in self.steps:
|
||
transition = processor_step(transition)
|
||
yield transition
|
||
|
||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||
"""Internal save method for ModelHubMixin compatibility."""
|
||
# Extract config_filename from kwargs if provided
|
||
config_filename = kwargs.pop("config_filename", None)
|
||
self.save_pretrained(save_directory, config_filename=config_filename)
|
||
|
||
def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs):
|
||
"""Serialize the processor definition and parameters to *save_directory*.
|
||
|
||
Args:
|
||
save_directory: Directory where the processor will be saved.
|
||
config_filename: Optional custom config filename. If not provided, defaults to
|
||
"{self.name}.json" where self.name is sanitized for filesystem compatibility.
|
||
"""
|
||
os.makedirs(str(save_directory), exist_ok=True)
|
||
|
||
# Sanitize processor name for use in filenames
|
||
import re
|
||
|
||
# The huggingface hub does not allow special characters in the repo name, so we sanitize the name
|
||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||
|
||
# Use sanitized name for config if not provided
|
||
if config_filename is None:
|
||
config_filename = f"{sanitized_name}.json"
|
||
|
||
config: dict[str, Any] = {
|
||
"name": self.name,
|
||
"steps": [],
|
||
}
|
||
|
||
for step_index, processor_step in enumerate(self.steps):
|
||
# Check if step was registered
|
||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||
|
||
step_entry: dict[str, Any] = {}
|
||
if registry_name:
|
||
# Use registry name for registered steps
|
||
step_entry["registry_name"] = registry_name
|
||
else:
|
||
# Fall back to full module path for unregistered steps
|
||
step_entry["class"] = (
|
||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||
)
|
||
|
||
if hasattr(processor_step, "get_config"):
|
||
step_entry["config"] = processor_step.get_config()
|
||
|
||
if hasattr(processor_step, "state_dict"):
|
||
state = processor_step.state_dict()
|
||
if state:
|
||
# Clone tensors to avoid shared memory issues
|
||
# This ensures each tensor has its own memory allocation
|
||
# The reason is to avoid the following error:
|
||
# RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk
|
||
# and potential differences when loading them again
|
||
# ------------------------------------------------------------------------------
|
||
# Since the state_dict of processor will be light, we can just clone the tensors
|
||
# and save them to the disk.
|
||
cloned_state = {}
|
||
for key, tensor in state.items():
|
||
cloned_state[key] = tensor.clone()
|
||
|
||
# Include pipeline name and step index to ensure unique filenames
|
||
# This prevents conflicts when multiple processors are saved in the same directory
|
||
if registry_name:
|
||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||
else:
|
||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||
|
||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||
step_entry["state_file"] = state_filename
|
||
|
||
config["steps"].append(step_entry)
|
||
|
||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||
json.dump(config, file_pointer, indent=2)
|
||
|
||
@classmethod
|
||
def from_pretrained(
|
||
cls,
|
||
pretrained_model_name_or_path: str | Path,
|
||
*,
|
||
force_download: bool = False,
|
||
resume_download: bool | None = None,
|
||
proxies: dict[str, str] | None = None,
|
||
token: str | bool | None = None,
|
||
cache_dir: str | Path | None = None,
|
||
local_files_only: bool = False,
|
||
revision: str | None = None,
|
||
config_filename: str | None = None,
|
||
overrides: dict[str, Any] | None = None,
|
||
**kwargs,
|
||
) -> RobotProcessor:
|
||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||
|
||
Args:
|
||
pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier
|
||
(e.g., "username/processor-name").
|
||
config_filename: Optional specific config filename to load. If not provided, will:
|
||
- For local paths: look for any .json file in the directory (error if multiple found)
|
||
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json")
|
||
overrides: Optional dictionary mapping step names to configuration overrides.
|
||
Keys must match exact step class names (for unregistered steps) or registry names
|
||
(for registered steps). Values are dictionaries containing parameter overrides
|
||
that will be merged with the saved configuration. This is useful for providing
|
||
non-serializable objects like environment instances.
|
||
|
||
Returns:
|
||
A RobotProcessor instance loaded from the saved configuration.
|
||
|
||
Raises:
|
||
ImportError: If a processor step class cannot be loaded or imported.
|
||
ValueError: If a step cannot be instantiated with the provided configuration.
|
||
KeyError: If an override key doesn't match any step in the saved configuration.
|
||
|
||
Examples:
|
||
Basic loading:
|
||
```python
|
||
processor = RobotProcessor.from_pretrained("path/to/processor")
|
||
```
|
||
|
||
Loading specific config file:
|
||
```python
|
||
processor = RobotProcessor.from_pretrained(
|
||
"username/multi-processor-repo", config_filename="preprocessor.json"
|
||
)
|
||
```
|
||
|
||
Loading with overrides for non-serializable objects:
|
||
```python
|
||
import gym
|
||
|
||
env = gym.make("CartPole-v1")
|
||
processor = RobotProcessor.from_pretrained(
|
||
"username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}}
|
||
)
|
||
```
|
||
|
||
Multiple overrides:
|
||
```python
|
||
processor = RobotProcessor.from_pretrained(
|
||
"path/to/processor",
|
||
overrides={
|
||
"CustomStep": {"param1": "new_value"},
|
||
"device_processor": {"device": "cuda:1"}, # For registered steps
|
||
},
|
||
)
|
||
```
|
||
"""
|
||
# Use the local variable name 'source' for clarity
|
||
source = str(pretrained_model_name_or_path)
|
||
|
||
if Path(source).is_dir():
|
||
# Local path - use it directly
|
||
base_path = Path(source)
|
||
|
||
if config_filename is None:
|
||
# Look for any .json file in the directory
|
||
json_files = list(base_path.glob("*.json"))
|
||
if len(json_files) == 0:
|
||
raise FileNotFoundError(f"No .json configuration files found in {source}")
|
||
elif len(json_files) > 1:
|
||
raise ValueError(
|
||
f"Multiple .json files found in {source}: {[f.name for f in json_files]}. "
|
||
f"Please specify which one to load using the config_filename parameter."
|
||
)
|
||
config_filename = json_files[0].name
|
||
|
||
with open(base_path / config_filename) as file_pointer:
|
||
loaded_config: dict[str, Any] = json.load(file_pointer)
|
||
else:
|
||
# Hugging Face Hub - download all required files
|
||
if config_filename is None:
|
||
# Try common config names
|
||
common_names = [
|
||
"processor.json",
|
||
"preprocessor.json",
|
||
"postprocessor.json",
|
||
"robotprocessor.json",
|
||
]
|
||
config_path = None
|
||
for name in common_names:
|
||
try:
|
||
config_path = hf_hub_download(
|
||
source,
|
||
name,
|
||
repo_type="model",
|
||
force_download=force_download,
|
||
resume_download=resume_download,
|
||
proxies=proxies,
|
||
token=token,
|
||
cache_dir=cache_dir,
|
||
local_files_only=local_files_only,
|
||
revision=revision,
|
||
)
|
||
config_filename = name
|
||
break
|
||
except (FileNotFoundError, OSError, HfHubHTTPError):
|
||
# FileNotFoundError: local file issues
|
||
# OSError: network/system errors
|
||
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
|
||
continue
|
||
|
||
if config_path is None:
|
||
raise FileNotFoundError(
|
||
f"No processor configuration file found in {source}. "
|
||
f"Tried: {common_names}. Please specify the config_filename parameter."
|
||
)
|
||
else:
|
||
# Download specific config file
|
||
config_path = hf_hub_download(
|
||
source,
|
||
config_filename,
|
||
repo_type="model",
|
||
force_download=force_download,
|
||
resume_download=resume_download,
|
||
proxies=proxies,
|
||
token=token,
|
||
cache_dir=cache_dir,
|
||
local_files_only=local_files_only,
|
||
revision=revision,
|
||
)
|
||
|
||
with open(config_path) as file_pointer:
|
||
loaded_config = json.load(file_pointer)
|
||
|
||
# Store downloaded files in the same directory as the config
|
||
base_path = Path(config_path).parent
|
||
|
||
# Handle None overrides
|
||
if overrides is None:
|
||
overrides = {}
|
||
|
||
# Validate that all override keys will be matched
|
||
override_keys = set(overrides.keys())
|
||
|
||
steps: list[ProcessorStep] = []
|
||
for step_entry in loaded_config["steps"]:
|
||
# Check if step uses registry name or module path
|
||
if "registry_name" in step_entry:
|
||
# Load from registry
|
||
try:
|
||
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
||
step_key = step_entry["registry_name"]
|
||
except KeyError as e:
|
||
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
||
else:
|
||
# Fall back to module path loading for backward compatibility
|
||
full_class_path = step_entry["class"]
|
||
module_path, class_name = full_class_path.rsplit(".", 1)
|
||
|
||
# Import the module containing the step class
|
||
try:
|
||
module = importlib.import_module(module_path)
|
||
step_class = getattr(module, class_name)
|
||
step_key = class_name
|
||
except (ImportError, AttributeError) as e:
|
||
raise ImportError(
|
||
f"Failed to load processor step '{full_class_path}'. "
|
||
f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. "
|
||
f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. "
|
||
f"Error: {str(e)}"
|
||
) from e
|
||
|
||
# Instantiate the step with its config
|
||
try:
|
||
saved_cfg = step_entry.get("config", {})
|
||
step_overrides = overrides.get(step_key, {})
|
||
merged_cfg = {**saved_cfg, **step_overrides}
|
||
step_instance: ProcessorStep = step_class(**merged_cfg)
|
||
|
||
# Track which override keys were used
|
||
if step_key in override_keys:
|
||
override_keys.discard(step_key)
|
||
|
||
except Exception as e:
|
||
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
||
raise ValueError(
|
||
f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. "
|
||
f"Error: {str(e)}"
|
||
) from e
|
||
|
||
# Load state if available
|
||
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
||
if Path(source).is_dir():
|
||
# Local path - read directly
|
||
state_path = str(base_path / step_entry["state_file"])
|
||
else:
|
||
# Hugging Face Hub - download the state file
|
||
state_path = hf_hub_download(
|
||
source,
|
||
step_entry["state_file"],
|
||
repo_type="model",
|
||
force_download=force_download,
|
||
resume_download=resume_download,
|
||
proxies=proxies,
|
||
token=token,
|
||
cache_dir=cache_dir,
|
||
local_files_only=local_files_only,
|
||
revision=revision,
|
||
)
|
||
|
||
step_instance.load_state_dict(load_file(state_path))
|
||
|
||
steps.append(step_instance)
|
||
|
||
# Check for unused override keys
|
||
if override_keys:
|
||
available_keys = []
|
||
for step_entry in loaded_config["steps"]:
|
||
if "registry_name" in step_entry:
|
||
available_keys.append(step_entry["registry_name"])
|
||
else:
|
||
full_class_path = step_entry["class"]
|
||
class_name = full_class_path.rsplit(".", 1)[1]
|
||
available_keys.append(class_name)
|
||
|
||
raise KeyError(
|
||
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
|
||
f"Available step keys: {available_keys}. "
|
||
f"Make sure override keys match exact step class names or registry names."
|
||
)
|
||
|
||
return cls(steps, loaded_config.get("name", "RobotProcessor"))
|
||
|
||
def __len__(self) -> int:
|
||
"""Return the number of steps in the processor."""
|
||
return len(self.steps)
|
||
|
||
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
|
||
"""Indexing helper exposing underlying steps.
|
||
* ``int`` – returns the idx-th ProcessorStep.
|
||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||
"""
|
||
if isinstance(idx, slice):
|
||
return RobotProcessor(self.steps[idx], self.name)
|
||
return self.steps[idx]
|
||
|
||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||
"""Attach fn to be executed before every processor step."""
|
||
self.before_step_hooks.append(fn)
|
||
|
||
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||
"""Remove a previously registered before_step hook.
|
||
|
||
Args:
|
||
fn: The exact function reference that was registered. Must be the same object.
|
||
|
||
Raises:
|
||
ValueError: If the hook is not found in the registered hooks.
|
||
"""
|
||
try:
|
||
self.before_step_hooks.remove(fn)
|
||
except ValueError:
|
||
raise ValueError(
|
||
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
|
||
) from None
|
||
|
||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||
"""Attach fn to be executed after every processor step."""
|
||
self.after_step_hooks.append(fn)
|
||
|
||
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||
"""Remove a previously registered after_step hook.
|
||
|
||
Args:
|
||
fn: The exact function reference that was registered. Must be the same object.
|
||
|
||
Raises:
|
||
ValueError: If the hook is not found in the registered hooks.
|
||
"""
|
||
try:
|
||
self.after_step_hooks.remove(fn)
|
||
except ValueError:
|
||
raise ValueError(
|
||
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
||
) from None
|
||
|
||
def reset(self):
|
||
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
||
for step in self.steps:
|
||
if hasattr(step, "reset"):
|
||
step.reset() # type: ignore[attr-defined]
|
||
|
||
def __repr__(self) -> str:
|
||
"""Return a readable string representation of the processor."""
|
||
step_names = [step.__class__.__name__ for step in self.steps]
|
||
|
||
if not step_names:
|
||
steps_repr = "steps=0: []"
|
||
elif len(step_names) <= 3:
|
||
steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]"
|
||
else:
|
||
# Show first 2 and last 1 with ellipsis for long lists
|
||
displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}"
|
||
steps_repr = f"steps={len(step_names)}: [{displayed}]"
|
||
|
||
parts = [f"name='{self.name}'", steps_repr]
|
||
|
||
return f"RobotProcessor({', '.join(parts)})"
|
||
|
||
def __post_init__(self):
|
||
for i, step in enumerate(self.steps):
|
||
if not callable(step):
|
||
raise TypeError(
|
||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||
)
|
||
|
||
fc = getattr(step, "feature_contract", None)
|
||
if not callable(fc):
|
||
raise TypeError(
|
||
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
|
||
)
|
||
|
||
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
"""
|
||
Apply ALL steps in order. Each step must implement
|
||
feature_contract(features) and return a dict (full or incremental schema).
|
||
"""
|
||
features: dict[str, PolicyFeature] = deepcopy(initial_features)
|
||
|
||
for _, step in enumerate(self.steps):
|
||
out = step.feature_contract(features)
|
||
if not isinstance(out, dict):
|
||
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
|
||
features = out
|
||
return features
|
||
|
||
|
||
class ObservationProcessor:
|
||
"""Base class for processors that modify only the observation component of a transition.
|
||
|
||
Subclasses should override the `observation` method to implement custom observation processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed observation
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class MyObservationScaler(ObservationProcessor):
|
||
def __init__(self, scale_factor):
|
||
self.scale_factor = scale_factor
|
||
|
||
def observation(self, observation):
|
||
return observation * self.scale_factor
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific observation processing logic.
|
||
"""
|
||
|
||
def observation(self, observation):
|
||
"""Process the observation component.
|
||
|
||
Args:
|
||
observation: The observation to process
|
||
|
||
Returns:
|
||
The processed observation
|
||
"""
|
||
return observation
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
observation = transition.get(TransitionKey.OBSERVATION)
|
||
if observation is None:
|
||
return transition
|
||
|
||
processed_observation = self.observation(observation)
|
||
# Create a new transition dict with the processed observation
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class ActionProcessor:
|
||
"""Base class for processors that modify only the action component of a transition.
|
||
|
||
Subclasses should override the `action` method to implement custom action processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed action
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class ActionClipping(ActionProcessor):
|
||
def __init__(self, min_val, max_val):
|
||
self.min_val = min_val
|
||
self.max_val = max_val
|
||
|
||
def action(self, action):
|
||
return np.clip(action, self.min_val, self.max_val)
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific action processing logic.
|
||
"""
|
||
|
||
def action(self, action):
|
||
"""Process the action component.
|
||
|
||
Args:
|
||
action: The action to process
|
||
|
||
Returns:
|
||
The processed action
|
||
"""
|
||
return action
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
action = transition.get(TransitionKey.ACTION)
|
||
if action is None:
|
||
return transition
|
||
|
||
processed_action = self.action(action)
|
||
# Create a new transition dict with the processed action
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.ACTION] = processed_action
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class RewardProcessor:
|
||
"""Base class for processors that modify only the reward component of a transition.
|
||
|
||
Subclasses should override the `reward` method to implement custom reward processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed reward
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class RewardScaler(RewardProcessor):
|
||
def __init__(self, scale_factor):
|
||
self.scale_factor = scale_factor
|
||
|
||
def reward(self, reward):
|
||
return reward * self.scale_factor
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific reward processing logic.
|
||
"""
|
||
|
||
def reward(self, reward):
|
||
"""Process the reward component.
|
||
|
||
Args:
|
||
reward: The reward to process
|
||
|
||
Returns:
|
||
The processed reward
|
||
"""
|
||
return reward
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
reward = transition.get(TransitionKey.REWARD)
|
||
if reward is None:
|
||
return transition
|
||
|
||
processed_reward = self.reward(reward)
|
||
# Create a new transition dict with the processed reward
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.REWARD] = processed_reward
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class DoneProcessor:
|
||
"""Base class for processors that modify only the done flag of a transition.
|
||
|
||
Subclasses should override the `done` method to implement custom done flag processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed done flag
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class TimeoutDone(DoneProcessor):
|
||
def __init__(self, max_steps):
|
||
self.steps = 0
|
||
self.max_steps = max_steps
|
||
|
||
def done(self, done):
|
||
self.steps += 1
|
||
return done or self.steps >= self.max_steps
|
||
|
||
def reset(self):
|
||
self.steps = 0
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific done flag processing logic.
|
||
"""
|
||
|
||
def done(self, done):
|
||
"""Process the done flag.
|
||
|
||
Args:
|
||
done: The done flag to process
|
||
|
||
Returns:
|
||
The processed done flag
|
||
"""
|
||
return done
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
done = transition.get(TransitionKey.DONE)
|
||
if done is None:
|
||
return transition
|
||
|
||
processed_done = self.done(done)
|
||
# Create a new transition dict with the processed done flag
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.DONE] = processed_done
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class TruncatedProcessor:
|
||
"""Base class for processors that modify only the truncated flag of a transition.
|
||
|
||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed truncated flag
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class EarlyTruncation(TruncatedProcessor):
|
||
def __init__(self, threshold):
|
||
self.threshold = threshold
|
||
|
||
def truncated(self, truncated):
|
||
# Additional truncation condition
|
||
return truncated or some_condition > self.threshold
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific truncated flag processing logic.
|
||
"""
|
||
|
||
def truncated(self, truncated):
|
||
"""Process the truncated flag.
|
||
|
||
Args:
|
||
truncated: The truncated flag to process
|
||
|
||
Returns:
|
||
The processed truncated flag
|
||
"""
|
||
return truncated
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||
if truncated is None:
|
||
return transition
|
||
|
||
processed_truncated = self.truncated(truncated)
|
||
# Create a new transition dict with the processed truncated flag
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class InfoProcessor:
|
||
"""Base class for processors that modify only the info dictionary of a transition.
|
||
|
||
Subclasses should override the `info` method to implement custom info processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed info
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
|
||
Example:
|
||
```python
|
||
class InfoAugmenter(InfoProcessor):
|
||
def __init__(self):
|
||
self.step_count = 0
|
||
|
||
def info(self, info):
|
||
info = info.copy() # Create a copy to avoid modifying the original
|
||
info["steps"] = self.step_count
|
||
self.step_count += 1
|
||
return info
|
||
|
||
def reset(self):
|
||
self.step_count = 0
|
||
```
|
||
|
||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||
manipulation, focusing only on the specific info dictionary processing logic.
|
||
"""
|
||
|
||
def info(self, info):
|
||
"""Process the info dictionary.
|
||
|
||
Args:
|
||
info: The info dictionary to process
|
||
|
||
Returns:
|
||
The processed info dictionary
|
||
"""
|
||
return info
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
info = transition.get(TransitionKey.INFO)
|
||
if info is None:
|
||
return transition
|
||
|
||
processed_info = self.info(info)
|
||
# Create a new transition dict with the processed info
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.INFO] = processed_info
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class ComplementaryDataProcessor:
|
||
"""Base class for processors that modify only the complementary data of a transition.
|
||
|
||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||
This class handles the boilerplate of extracting and reinserting the processed complementary data
|
||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||
"""
|
||
|
||
def complementary_data(self, complementary_data):
|
||
"""Process the complementary data.
|
||
|
||
Args:
|
||
complementary_data: The complementary data to process
|
||
|
||
Returns:
|
||
The processed complementary data
|
||
"""
|
||
return complementary_data
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||
if complementary_data is None:
|
||
return transition
|
||
|
||
processed_complementary_data = self.complementary_data(complementary_data)
|
||
# Create a new transition dict with the processed complementary data
|
||
new_transition = transition.copy()
|
||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||
return new_transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|
||
|
||
|
||
class IdentityProcessor:
|
||
"""Identity processor that does nothing."""
|
||
|
||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||
return transition
|
||
|
||
def get_config(self) -> dict[str, Any]:
|
||
return {}
|
||
|
||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||
return {}
|
||
|
||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
pass
|
||
|
||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||
return features
|