mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
Refactor processing architecture to use RobotProcessor
- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation.
This commit is contained in:
@@ -37,14 +37,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
|
||||
# Create pipeline with observation processor
|
||||
pipeline = RobotPipeline([ObservationProcessor()])
|
||||
# Create processor with observation processor
|
||||
processor = RobotProcessor([ObservationProcessor()])
|
||||
|
||||
# Create transition tuple and process
|
||||
transition = (observations, None, None, None, None, None, None)
|
||||
processed_transition = pipeline(transition)
|
||||
processed_transition = processor(transition)
|
||||
|
||||
# Return processed observations
|
||||
return processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
@@ -18,11 +18,11 @@ from .observation_processor import (
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
)
|
||||
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
|
||||
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
|
||||
|
||||
__all__ = [
|
||||
"RobotPipeline",
|
||||
"PipelineStep",
|
||||
"RobotProcessor",
|
||||
"ProcessorStep",
|
||||
"EnvTransition",
|
||||
"ImageProcessor",
|
||||
"StateProcessor",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
@@ -51,13 +52,85 @@ EnvTransition = Tuple[
|
||||
]
|
||||
|
||||
|
||||
class PipelineStep(Protocol):
|
||||
"""Structural typing interface for a single pipeline step.
|
||||
class ProcessorStepRegistry:
|
||||
"""Registry for processor steps that enables saving/loading by name instead of module path."""
|
||||
|
||||
_registry: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str = None):
|
||||
"""Decorator to register a processor step class.
|
||||
|
||||
Args:
|
||||
name: Optional registration name. If not provided, uses class name.
|
||||
|
||||
Example:
|
||||
@ProcessorStepRegistry.register("adaptive_normalizer")
|
||||
class AdaptiveObservationNormalizer:
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(step_class: type) -> type:
|
||||
registration_name = name if name is not None else step_class.__name__
|
||||
|
||||
if registration_name in cls._registry:
|
||||
raise ValueError(
|
||||
f"Processor step '{registration_name}' is already registered. "
|
||||
f"Use a different name or unregister the existing one first."
|
||||
)
|
||||
|
||||
cls._registry[registration_name] = step_class
|
||||
# Store the registration name on the class for later reference
|
||||
step_class._registry_name = registration_name
|
||||
return step_class
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> type:
|
||||
"""Get a registered processor step class by name.
|
||||
|
||||
Args:
|
||||
name: The registration name of the step.
|
||||
|
||||
Returns:
|
||||
The registered step class.
|
||||
|
||||
Raises:
|
||||
KeyError: If the step is not registered.
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
available = list(cls._registry.keys())
|
||||
raise KeyError(
|
||||
f"Processor step '{name}' not found in registry. "
|
||||
f"Available steps: {available}. "
|
||||
f"Make sure the step is registered using @ProcessorStepRegistry.register()"
|
||||
)
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, name: str) -> None:
|
||||
"""Remove a step from the registry."""
|
||||
cls._registry.pop(name, None)
|
||||
|
||||
@classmethod
|
||||
def list(cls) -> list[str]:
|
||||
"""List all registered step names."""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""Clear all registrations."""
|
||||
cls._registry.clear()
|
||||
|
||||
|
||||
class ProcessorStep(Protocol):
|
||||
"""Structural typing interface for a single processor step.
|
||||
|
||||
A step is any callable accepting a full `EnvTransition` tuple and
|
||||
returning a (possibly modified) tuple of the same structure. Implementers
|
||||
are encouraged—but not required—to expose the optional helper methods
|
||||
listed below. When present, these hooks let `RobotPipeline`
|
||||
listed below. When present, these hooks let `RobotProcessor`
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
@@ -90,44 +163,44 @@ class PipelineStep(Protocol):
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotPipeline(ModelHubMixin):
|
||||
class RobotProcessor(ModelHubMixin):
|
||||
"""
|
||||
Composable, debuggable post-processing pipeline for RL transitions.
|
||||
Composable, debuggable post-processing processor for robot transitions.
|
||||
The class orchestrates an ordered collection of small, functional
|
||||
transforms—steps—executed left-to-right on each incoming
|
||||
`EnvTransition`.
|
||||
Parameters:
|
||||
steps : Sequence[PipelineStep], optional
|
||||
steps : Sequence[ProcessorStep], optional
|
||||
Ordered list executed on every call
|
||||
name : str, default="RobotPipeline"
|
||||
name : str, default="RobotProcessor"
|
||||
Human-readable identifier that is persisted inside the JSON config.
|
||||
seed : int | None, optional
|
||||
Global seed forwarded to steps that choose to consume it.
|
||||
Examples:
|
||||
Basic usage::
|
||||
env = gym.make("CartPole-v1")
|
||||
pipe = RobotPipeline([
|
||||
proc = RobotProcessor([
|
||||
ObservationNormalizer(),
|
||||
IntrinsicVelocity(),
|
||||
VelocityBonus(0.02),
|
||||
])
|
||||
obs, info = env.reset(seed=0)
|
||||
tr = (obs, None, 0.0, False, False, info, {})
|
||||
obs, *_ = pipe(tr) # agent sees a normalised observation
|
||||
obs, *_ = proc(tr) # agent sees a normalised observation
|
||||
Inspecting intermediate results::
|
||||
for idx, step_tr in enumerate(pipe.step_through(tr)):
|
||||
for idx, step_tr in enumerate(proc.step_through(tr)):
|
||||
print(idx, step_tr)
|
||||
Serialization to the Hugging Face Hub::
|
||||
pipe.save_pretrained("chkpt")
|
||||
pipe.push_to_hub("my-org/cartpole_pipe")
|
||||
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
|
||||
proc.save_pretrained("chkpt")
|
||||
proc.push_to_hub("my-org/cartpole_proc")
|
||||
loaded = RobotProcessor.from_pretrained("my-org/cartpole_proc")
|
||||
"""
|
||||
|
||||
steps: Sequence[PipelineStep] = field(default_factory=list)
|
||||
name: str = "RobotPipeline"
|
||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||
name: str = "RobotProcessor"
|
||||
seed: int | None = None
|
||||
|
||||
# Pipeline-level hooks
|
||||
# Processor-level hooks
|
||||
# A hook can optionally return a modified transition. If it returns
|
||||
# ``None`` the current value is left untouched.
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
@@ -148,13 +221,13 @@ class RobotPipeline(ModelHubMixin):
|
||||
f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}"
|
||||
)
|
||||
|
||||
for idx, pipeline_step in enumerate(self.steps):
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
transition = pipeline_step(transition)
|
||||
transition = processor_step(transition)
|
||||
|
||||
for hook in self.after_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
@@ -164,20 +237,37 @@ class RobotPipeline(ModelHubMixin):
|
||||
return transition
|
||||
|
||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate Transition instances after each pipeline step."""
|
||||
"""Yield the intermediate Transition instances after each processor step."""
|
||||
yield transition
|
||||
for pipeline_step in self.steps:
|
||||
transition = pipeline_step(transition)
|
||||
for processor_step in self.steps:
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
_CFG_NAME = "pipeline.json"
|
||||
_CFG_NAME = "processor.json"
|
||||
|
||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Internal save method for ModelHubMixin compatibility."""
|
||||
self.save_pretrained(destination_path)
|
||||
|
||||
def _generate_model_card(self, destination_path: str) -> None:
|
||||
"""Generate README.md from the RobotProcessor model card template."""
|
||||
# Read the template
|
||||
template_path = Path(__file__).parent.parent / "templates" / "robotprocessor_modelcard_template.md"
|
||||
|
||||
if not template_path.exists():
|
||||
# Fallback: if template doesn't exist, skip model card generation
|
||||
return
|
||||
|
||||
with open(template_path) as f:
|
||||
model_card_content = f.read()
|
||||
|
||||
# Write the README.md
|
||||
readme_path = os.path.join(destination_path, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(model_card_content)
|
||||
|
||||
def save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Serialize the pipeline definition and parameters to *destination_path*."""
|
||||
"""Serialize the processor definition and parameters to *destination_path*."""
|
||||
os.makedirs(destination_path, exist_ok=True)
|
||||
|
||||
config: dict[str, Any] = {
|
||||
@@ -186,19 +276,41 @@ class RobotPipeline(ModelHubMixin):
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
for step_index, pipeline_step in enumerate(self.steps):
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
|
||||
}
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
# Check if step was registered
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
if hasattr(pipeline_step, "get_config"):
|
||||
step_entry["config"] = pipeline_step.get_config()
|
||||
if registry_name:
|
||||
# Use registry name for registered steps
|
||||
step_entry: dict[str, Any] = {
|
||||
"registry_name": registry_name,
|
||||
}
|
||||
else:
|
||||
# Fall back to full module path for unregistered steps
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}",
|
||||
}
|
||||
|
||||
if hasattr(pipeline_step, "state_dict"):
|
||||
state = pipeline_step.state_dict()
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid shared memory issues
|
||||
# This ensures each tensor has its own memory allocation
|
||||
# The reason is to avoid the following error:
|
||||
# RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk
|
||||
# and potential differences when loading them again
|
||||
# ------------------------------------------------------------------------------
|
||||
# Since the state_dict of processor will be light, we can just clone the tensors
|
||||
# and save them to the disk.
|
||||
cloned_state = {}
|
||||
for key, tensor in state.items():
|
||||
cloned_state[key] = tensor.clone()
|
||||
|
||||
state_filename = f"step_{step_index}.safetensors"
|
||||
save_file(state, os.path.join(destination_path, state_filename))
|
||||
save_file(cloned_state, os.path.join(destination_path, state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
@@ -206,9 +318,30 @@ class RobotPipeline(ModelHubMixin):
|
||||
with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
# Generate README.md from template
|
||||
self._generate_model_card(destination_path)
|
||||
|
||||
def to(self, device: str | torch.device):
|
||||
"""Move all tensor states inside each step to device and return self.
|
||||
|
||||
Uses a generic mechanism: fetch each step's state dict, move every tensor
|
||||
to the target device, and reload it. Only works for steps that implement
|
||||
both state_dict() and load_state_dict() methods.
|
||||
"""
|
||||
device = torch.device(device)
|
||||
|
||||
for step in self.steps:
|
||||
if hasattr(step, "state_dict") and hasattr(step, "load_state_dict"):
|
||||
state = step.state_dict()
|
||||
if state: # Only process if there's actual state
|
||||
moved_state = {k: v.to(device) for k, v in state.items()}
|
||||
step.load_state_dict(moved_state)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, source: str) -> RobotPipeline:
|
||||
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
|
||||
def from_pretrained(cls, source: str) -> RobotProcessor:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier)."""
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
@@ -224,12 +357,43 @@ class RobotPipeline(ModelHubMixin):
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
|
||||
steps: list[PipelineStep] = []
|
||||
steps: list[ProcessorStep] = []
|
||||
for step_entry in config["steps"]:
|
||||
module_path, class_name = step_entry["class"].rsplit(".", 1)
|
||||
step_class = getattr(__import__(module_path, fromlist=[class_name]), class_name)
|
||||
step_instance: PipelineStep = step_class(**step_entry.get("config", {}))
|
||||
# Check if step uses registry name or module path
|
||||
if "registry_name" in step_entry:
|
||||
# Load from registry
|
||||
try:
|
||||
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
||||
except KeyError as e:
|
||||
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
||||
else:
|
||||
# Fall back to module path loading for backward compatibility
|
||||
full_class_path = step_entry["class"]
|
||||
module_path, class_name = full_class_path.rsplit(".", 1)
|
||||
|
||||
# Import the module containing the step class
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
step_class = getattr(module, class_name)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ImportError(
|
||||
f"Failed to load processor step '{full_class_path}'. "
|
||||
f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. "
|
||||
f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Instantiate the step with its config
|
||||
try:
|
||||
step_instance: ProcessorStep = step_class(**step_entry.get("config", {}))
|
||||
except Exception as e:
|
||||
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
||||
raise ValueError(
|
||||
f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Load state if available
|
||||
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
|
||||
if Path(source).is_dir():
|
||||
# Local path - read directly
|
||||
@@ -242,27 +406,27 @@ class RobotPipeline(ModelHubMixin):
|
||||
|
||||
steps.append(step_instance)
|
||||
|
||||
return cls(steps, config.get("name", "RobotPipeline"), config.get("seed"))
|
||||
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of steps in the pipeline."""
|
||||
"""Return the number of steps in the processor."""
|
||||
return len(self.steps)
|
||||
|
||||
def __getitem__(self, idx: int | slice) -> PipelineStep | RobotPipeline:
|
||||
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
|
||||
"""Indexing helper exposing underlying steps.
|
||||
* ``int`` – returns the idx-th PipelineStep.
|
||||
* ``slice`` – returns a new RobotPipeline with the sliced steps.
|
||||
* ``int`` – returns the idx-th ProcessorStep.
|
||||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||||
"""
|
||||
if isinstance(idx, slice):
|
||||
return RobotPipeline(self.steps[idx], self.name, self.seed)
|
||||
return RobotProcessor(self.steps[idx], self.name, self.seed)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed before every pipeline step."""
|
||||
"""Attach fn to be executed before every processor step."""
|
||||
self.before_step_hooks.append(fn)
|
||||
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed after every pipeline step."""
|
||||
"""Attach fn to be executed after every processor step."""
|
||||
self.after_step_hooks.append(fn)
|
||||
|
||||
def register_reset_hook(self, fn: Callable[[], None]):
|
||||
@@ -283,17 +447,17 @@ class RobotPipeline(ModelHubMixin):
|
||||
|
||||
profile_results = {}
|
||||
|
||||
for idx, pipeline_step in enumerate(self.steps):
|
||||
step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
|
||||
|
||||
# Warm up
|
||||
for _ in range(5):
|
||||
_ = pipeline_step(transition)
|
||||
_ = processor_step(transition)
|
||||
|
||||
# Time the step
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_runs):
|
||||
transition = pipeline_step(transition)
|
||||
transition = processor_step(transition)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
|
||||
|
||||
@@ -73,7 +73,7 @@ from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -130,8 +130,8 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
# Create observation processing pipeline
|
||||
obs_pipeline = RobotPipeline([ObservationProcessor()])
|
||||
# Create observation processing processor
|
||||
obs_processor = RobotProcessor([ObservationProcessor()])
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
@@ -153,7 +153,7 @@ def rollout(
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_pipeline(transition)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
@@ -203,7 +203,7 @@ def rollout(
|
||||
# Track the final observation.
|
||||
if return_observations:
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_pipeline(transition)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
---
|
||||
library_name: lerobot
|
||||
tags:
|
||||
- robotics
|
||||
- lerobot
|
||||
- safetensors
|
||||
pipeline_tag: robotics
|
||||
---
|
||||
|
||||
# RobotProcessor
|
||||
|
||||
## Overview
|
||||
|
||||
RobotProcessor is a composable, debuggable post-processing pipeline for robot transitions in the LeRobot framework. It orchestrates an ordered collection of small, functional transforms (steps) that are executed left-to-right on each incoming `EnvTransition`.
|
||||
|
||||
## Architecture
|
||||
|
||||
The RobotProcessor provides a modular architecture for processing robot environment transitions through a sequence of composable steps. Each step is a callable that accepts a full `EnvTransition` tuple and returns a potentially modified tuple of the same structure.
|
||||
|
||||
### EnvTransition Structure
|
||||
|
||||
An `EnvTransition` is a 7-tuple containing:
|
||||
1. **observation**: Current state observation
|
||||
2. **action**: Action taken (can be None)
|
||||
3. **reward**: Reward received (float or None)
|
||||
4. **done**: Episode termination flag (bool or None)
|
||||
5. **truncated**: Episode truncation flag (bool or None)
|
||||
6. **info**: Additional information dictionary
|
||||
7. **complementary_data**: Extra data dictionary
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Composable Pipeline**: Chain multiple processing steps in a specific order
|
||||
- **State Persistence**: Save and load processor state using SafeTensors format
|
||||
- **Hugging Face Hub Integration**: Easy sharing and loading via `save_pretrained()` and `from_pretrained()`
|
||||
- **Debugging Support**: Step-through functionality to inspect intermediate transformations
|
||||
- **Hook System**: Before/after step hooks for additional processing or monitoring
|
||||
- **Device Support**: Move tensor states to different devices (CPU/GPU)
|
||||
- **Performance Profiling**: Built-in profiling to identify bottlenecks
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from your_steps import ObservationNormalizer, VelocityCalculator
|
||||
|
||||
# Create a processor with multiple steps
|
||||
processor = RobotProcessor(
|
||||
steps=[
|
||||
ObservationNormalizer(mean=0, std=1),
|
||||
VelocityCalculator(window_size=5),
|
||||
],
|
||||
name="my_robot_processor",
|
||||
seed=42
|
||||
)
|
||||
|
||||
# Process a transition
|
||||
obs, info = env.reset()
|
||||
transition = (obs, None, 0.0, False, False, info, {})
|
||||
processed_transition = processor(transition)
|
||||
|
||||
# Extract processed observation
|
||||
processed_obs = processed_transition[0]
|
||||
```
|
||||
|
||||
### Saving and Loading
|
||||
|
||||
```python
|
||||
# Save locally
|
||||
processor.save_pretrained("./my_processor")
|
||||
|
||||
# Push to Hugging Face Hub
|
||||
processor.push_to_hub("username/my-robot-processor")
|
||||
|
||||
# Load from Hub
|
||||
loaded_processor = RobotProcessor.from_pretrained("username/my-robot-processor")
|
||||
```
|
||||
|
||||
### Debugging with Step-Through
|
||||
|
||||
```python
|
||||
# Inspect intermediate results
|
||||
for idx, intermediate_transition in enumerate(processor.step_through(transition)):
|
||||
print(f"After step {idx}: {intermediate_transition[0]}") # Print observation
|
||||
```
|
||||
|
||||
### Using Hooks
|
||||
|
||||
```python
|
||||
# Add monitoring hook
|
||||
def log_observation(step_idx, transition):
|
||||
print(f"Step {step_idx}: obs shape = {transition[0].shape}")
|
||||
return None # Don't modify transition
|
||||
|
||||
processor.register_before_step_hook(log_observation)
|
||||
```
|
||||
|
||||
## Creating Custom Steps
|
||||
|
||||
To create a custom processor step, implement the `ProcessorStep` protocol:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import ProcessorStepRegistry, EnvTransition
|
||||
|
||||
@ProcessorStepRegistry.register("my_custom_step")
|
||||
class MyCustomStep:
|
||||
def __init__(self, param1=1.0):
|
||||
self.param1 = param1
|
||||
self.buffer = []
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
# Process observation
|
||||
processed_obs = obs * self.param1
|
||||
return (processed_obs, action, reward, done, truncated, info, comp_data)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
return {"param1": self.param1}
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
# Return only torch.Tensor state
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
# Load tensor state
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
# Clear buffers at episode boundaries
|
||||
self.buffer.clear()
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Device Management
|
||||
|
||||
```python
|
||||
# Move all tensor states to GPU
|
||||
processor = processor.to("cuda")
|
||||
|
||||
# Move to specific device
|
||||
processor = processor.to(torch.device("cuda:1"))
|
||||
```
|
||||
|
||||
### Performance Profiling
|
||||
|
||||
```python
|
||||
# Profile step execution times
|
||||
profile_results = processor.profile_steps(transition, num_runs=100)
|
||||
for step_name, time_ms in profile_results.items():
|
||||
print(f"{step_name}: {time_ms:.3f} ms")
|
||||
```
|
||||
|
||||
### Processor Slicing
|
||||
|
||||
```python
|
||||
# Get a single step
|
||||
first_step = processor[0]
|
||||
|
||||
# Create a sub-processor with steps 1-3
|
||||
sub_processor = processor[1:4]
|
||||
```
|
||||
|
||||
## Model Card Specifications
|
||||
|
||||
- **Pipeline Tag**: robotics
|
||||
- **Library**: lerobot
|
||||
- **Format**: safetensors
|
||||
- **License**: Apache 2.0
|
||||
|
||||
## Limitations
|
||||
|
||||
- Steps must maintain the 7-tuple structure of EnvTransition
|
||||
- All tensor state must be separated from configuration for proper serialization
|
||||
- Steps are executed sequentially (no parallel processing within a single transition)
|
||||
|
||||
## Citation
|
||||
|
||||
If you use RobotProcessor in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
}
|
||||
```
|
||||
@@ -23,7 +23,7 @@ from gymnasium.utils.env_checker import check_env
|
||||
import lerobot
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
@@ -50,10 +50,10 @@ def test_factory(env_name):
|
||||
env = make_env(cfg, n_envs=1)
|
||||
obs, _ = env.reset()
|
||||
|
||||
# Process observation using pipeline
|
||||
obs_pipeline = RobotPipeline([ObservationProcessor()])
|
||||
# Process observation using processor
|
||||
obs_processor = RobotProcessor([ObservationProcessor()])
|
||||
transition = (obs, None, None, None, None, None, None)
|
||||
processed_transition = obs_pipeline(transition)
|
||||
processed_transition = obs_processor(transition)
|
||||
obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.policies.factory import (
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
@@ -186,9 +186,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
observation, _ = env.reset(seed=train_cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
obs_pipeline = RobotPipeline([ObservationProcessor()])
|
||||
obs_processor = RobotProcessor([ObservationProcessor()])
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_pipeline(transition)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
# send observation to device/gpu
|
||||
|
||||
@@ -20,10 +20,12 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -139,7 +141,7 @@ class MockStepWithTensorState:
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = RobotPipeline()
|
||||
pipeline = RobotProcessor()
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
@@ -151,7 +153,7 @@ def test_empty_pipeline():
|
||||
def test_single_step_pipeline():
|
||||
"""Test pipeline with a single step."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
@@ -168,7 +170,7 @@ def test_multiple_steps_pipeline():
|
||||
"""Test pipeline with multiple steps."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
@@ -180,7 +182,7 @@ def test_multiple_steps_pipeline():
|
||||
|
||||
def test_invalid_transition_format():
|
||||
"""Test pipeline with invalid transition format."""
|
||||
pipeline = RobotPipeline([MockStep()])
|
||||
pipeline = RobotProcessor([MockStep()])
|
||||
|
||||
# Test with wrong number of elements
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
@@ -195,7 +197,7 @@ def test_step_through():
|
||||
"""Test step_through method."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
|
||||
@@ -211,7 +213,7 @@ def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
# Test integer indexing
|
||||
assert pipeline[0] is step1
|
||||
@@ -219,7 +221,7 @@ def test_indexing():
|
||||
|
||||
# Test slice indexing
|
||||
sub_pipeline = pipeline[0:1]
|
||||
assert isinstance(sub_pipeline, RobotPipeline)
|
||||
assert isinstance(sub_pipeline, RobotProcessor)
|
||||
assert len(sub_pipeline) == 1
|
||||
assert sub_pipeline[0] is step1
|
||||
|
||||
@@ -227,7 +229,7 @@ def test_indexing():
|
||||
def test_hooks():
|
||||
"""Test before/after step hooks."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
before_calls = []
|
||||
after_calls = []
|
||||
@@ -253,7 +255,7 @@ def test_hooks():
|
||||
def test_hook_modification():
|
||||
"""Test that hooks can modify transitions."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
def modify_reward_hook(idx: int, transition: EnvTransition):
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
@@ -270,7 +272,7 @@ def test_hook_modification():
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
reset_called = []
|
||||
|
||||
@@ -297,7 +299,7 @@ def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
|
||||
@@ -322,14 +324,14 @@ def test_save_and_load_pretrained():
|
||||
step1.counter = 5
|
||||
step2.counter = 10
|
||||
|
||||
pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)
|
||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "pipeline.json"
|
||||
config_path = Path(tmp_dir) / "processor.json"
|
||||
assert config_path.exists()
|
||||
|
||||
# Check config content
|
||||
@@ -345,7 +347,7 @@ def test_save_and_load_pretrained():
|
||||
assert config["steps"][1]["config"]["counter"] == 10
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "TestPipeline"
|
||||
assert loaded_pipeline.seed == 42
|
||||
@@ -359,7 +361,7 @@ def test_save_and_load_pretrained():
|
||||
def test_step_without_optional_methods():
|
||||
"""Test pipeline with steps that don't implement optional methods."""
|
||||
step = MockStepWithoutOptionalMethods(multiplier=3.0)
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = (None, None, 2.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
@@ -372,14 +374,14 @@ def test_step_without_optional_methods():
|
||||
# Save/load should work even without optional methods
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
|
||||
def test_mixed_json_and_tensor_state():
|
||||
"""Test step with both JSON attributes and tensor state."""
|
||||
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
|
||||
pipeline = RobotPipeline([step])
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
# Process some transitions with rewards
|
||||
for i in range(10):
|
||||
@@ -395,13 +397,13 @@ def test_mixed_json_and_tensor_state():
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check that both config and state files were created
|
||||
config_path = Path(tmp_dir) / "pipeline.json"
|
||||
config_path = Path(tmp_dir) / "processor.json"
|
||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
||||
assert config_path.exists()
|
||||
assert state_path.exists()
|
||||
|
||||
# Load and verify
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
# Check JSON attributes were restored
|
||||
@@ -412,3 +414,512 @@ def test_mixed_json_and_tensor_state():
|
||||
# Check tensor state was restored
|
||||
assert loaded_step.running_count.item() == 10
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
class MockModuleStep(nn.Module):
    """Mock step that inherits from nn.Module to test state_dict handling of module parameters."""

    def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)
        # Frozen parameter: travels with state_dict()/device moves but is never trained.
        self.running_mean = nn.Parameter(torch.zeros(hidden_dim), requires_grad=False)
        # Plain Python attribute — non-tensor state, surfaced via get_config().
        self.counter = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Process a transition: project the observation and update the running mean."""
        obs = transition[0]

        if isinstance(obs, torch.Tensor):
            features = self.forward(obs[:, : self.input_dim])

            # EMA update done in-place so the Parameter object itself is preserved.
            with torch.no_grad():
                self.running_mean.mul_(0.9).add_(features.mean(dim=0), alpha=0.1)

            self.counter += 1

        return transition

    def get_config(self) -> dict[str, Any]:
        """JSON-serializable (non-tensor) state."""
        return {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "counter": self.counter,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Expose the full nn.Module state (all parameters and buffers)."""
        return super().state_dict()

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Restore the full nn.Module state via the stock loader."""
        super().load_state_dict(state)

    def reset(self) -> None:
        """Zero the running statistics; learned weights are kept."""
        self.running_mean.zero_()
        self.counter = 0
|
||||
|
||||
|
||||
def test_to_device_with_state_dict():
    """Test moving pipeline to device for steps with state_dict.

    Covers string device specs, torch.device objects, self-returning behavior,
    and (when available) a CUDA round-trip.
    """
    step = MockStepWithTensorState(name="device_test", window_size=5)
    pipeline = RobotProcessor([step])

    # Process some transitions to populate state
    for i in range(10):
        transition = (None, None, float(i), False, False, {}, {})
        pipeline(transition)

    # Check initial device (should be CPU)
    assert step.running_mean.device.type == "cpu"
    assert step.running_count.device.type == "cpu"

    # Move to same device (CPU) — a no-op, but it must still return the processor
    result = pipeline.to("cpu")
    assert result is pipeline  # Check it returns self
    assert step.running_mean.device.type == "cpu"
    assert step.running_count.device.type == "cpu"

    # Test with torch.device object (not only a string spec)
    result = pipeline.to(torch.device("cpu"))
    assert result is pipeline
    assert step.running_mean.device.type == "cpu"

    # If CUDA is available, test GPU transfer
    if torch.cuda.is_available():
        result = pipeline.to("cuda")
        assert result is pipeline
        assert step.running_mean.device.type == "cuda"
        assert step.running_count.device.type == "cuda"

        # Move back to CPU
        pipeline.to("cpu")
        assert step.running_mean.device.type == "cpu"
        assert step.running_count.device.type == "cpu"
|
||||
|
||||
|
||||
def test_to_device_with_module():
    """Test moving pipeline to device for steps that inherit from nn.Module.

    Even though the step inherits from nn.Module, the pipeline will use the
    state_dict/load_state_dict approach to move tensors to the device.
    """
    module_step = MockModuleStep(input_dim=5, hidden_dim=3)
    pipeline = RobotProcessor([module_step])

    # Process some data
    obs = torch.randn(2, 5)
    transition = (obs, None, 1.0, False, False, {}, {})
    pipeline(transition)

    # Check initial device
    assert module_step.linear.weight.device.type == "cpu"
    assert module_step.running_mean.device.type == "cpu"

    # Move to same device
    result = pipeline.to("cpu")
    assert result is pipeline
    assert module_step.linear.weight.device.type == "cpu"
    assert module_step.running_mean.device.type == "cpu"

    # If CUDA is available, test GPU transfer
    if torch.cuda.is_available():
        # Use an explicit device index to check it is honored, not just the type
        result = pipeline.to("cuda:0")
        assert result is pipeline
        assert module_step.linear.weight.device.type == "cuda"
        assert module_step.linear.weight.device.index == 0
        assert module_step.running_mean.device.type == "cuda"
        assert module_step.running_mean.device.index == 0

        # Verify the module still works after transfer
        obs_cuda = torch.randn(2, 5, device="cuda:0")
        transition = (obs_cuda, None, 1.0, False, False, {}, {})
        pipeline(transition)  # Should not raise an error
|
||||
|
||||
|
||||
def test_to_device_mixed_steps():
    """Moving a processor with heterogeneous step types should relocate every tensor."""
    module_step = MockModuleStep()
    state_dict_step = MockStepWithTensorState()
    simple_step = MockStepWithoutOptionalMethods()  # carries no tensor state at all

    pipeline = RobotProcessor([module_step, state_dict_step, simple_step])

    # Push a handful of transitions through so the stateful steps accumulate data.
    for reward in (0.0, 1.0, 2.0, 3.0, 4.0):
        pipeline((torch.randn(2, 10), None, reward, False, False, {}, {}))

    # Everything starts out on the CPU.
    assert module_step.linear.weight.device.type == "cpu"
    assert state_dict_step.running_mean.device.type == "cpu"

    # A same-device move is a no-op but must still hand back the processor itself.
    assert pipeline.to("cpu") is pipeline

    if torch.cuda.is_available():
        pipeline.to("cuda")
        # Both the nn.Module step and the plain state_dict step must have moved.
        assert module_step.linear.weight.device.type == "cuda"
        assert module_step.running_mean.device.type == "cuda"
        assert state_dict_step.running_mean.device.type == "cuda"
        assert state_dict_step.running_count.device.type == "cuda"
|
||||
|
||||
|
||||
def test_to_device_empty_state():
    """Steps whose state_dict is empty must not break device transfer."""
    pipeline = RobotProcessor([MockStep("empty_state")])  # MockStep has an empty state_dict

    # Nothing to move, but the call should still succeed and return self.
    assert pipeline.to("cpu") is pipeline

    if torch.cuda.is_available():
        assert pipeline.to("cuda") is pipeline
|
||||
|
||||
|
||||
def test_to_device_preserves_functionality():
    """Test that pipeline functionality is preserved after device transfer."""
    step = MockStepWithTensorState(window_size=3)
    pipeline = RobotProcessor([step])

    # Process initial data — exactly fills the window of size 3
    rewards = [1.0, 2.0, 3.0]
    for r in rewards:
        transition = (None, None, r, False, False, {}, {})
        pipeline(transition)

    # Check state before transfer
    initial_mean = step.running_mean.clone()
    initial_count = step.running_count.clone()

    # Move to device (CPU to CPU in this case, but tests the mechanism)
    pipeline.to("cpu")

    # Verify state is preserved
    assert torch.allclose(step.running_mean, initial_mean)
    assert step.running_count == initial_count

    # Process more data to ensure functionality
    transition = (None, None, 4.0, False, False, {}, {})
    _ = pipeline(transition)

    assert step.running_count == 4
    # With window_size=3, the 4th reward wraps around into the first slot
    assert step.running_mean[0] == 4.0  # First slot should have been overwritten with 4.0
|
||||
|
||||
|
||||
def test_to_device_invalid_device():
    """An unrecognized device spec should surface PyTorch's RuntimeError."""
    pipeline = RobotProcessor([MockStep()])

    # PyTorch itself rejects the bogus device name; the processor just passes it through.
    with pytest.raises(RuntimeError):
        pipeline.to("invalid_device")
|
||||
|
||||
|
||||
def test_to_device_chaining():
    """to() must return self so calls can be chained fluently."""
    pipeline = RobotProcessor([MockStepWithTensorState(), MockModuleStep()])

    # reset() deliberately returns None, so a fluent chain ends there.
    assert pipeline.to("cpu").reset() is None

    # Repeated to() calls keep handing back the very same processor object.
    first = pipeline.to("cpu")
    second = first.to("cpu")
    assert first is pipeline
    assert second is pipeline
|
||||
|
||||
|
||||
class MockNonModuleStepWithState:
    """Mock step that explicitly does NOT inherit from nn.Module but has tensor state.

    This tests the state_dict/load_state_dict path for regular classes.
    """

    def __init__(self, name: str = "non_module_step", feature_dim: int = 10):
        self.name = name
        self.feature_dim = feature_dim

        # Plain tensors (not nn.Parameters) — the processor must still move/serialize them.
        self.weights = torch.randn(feature_dim, feature_dim)
        self.bias = torch.zeros(feature_dim)
        self.running_stats = torch.zeros(feature_dim)
        self.step_count = torch.tensor(0)

        # Non-tensor state, deliberately excluded from state_dict().
        self.config_value = 42
        self.history = []

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Apply a linear transform to the observation and track running stats."""
        obs, action, reward, done, truncated, info, comp_data = transition

        usable = isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim
        if obs is not None and usable:
            # Linear projection of the first feature_dim observation values.
            projected = self.weights.T @ obs.flatten()[: self.feature_dim] + self.bias

            # Exponential moving average of the projection output.
            self.running_stats = 0.9 * self.running_stats + 0.1 * projected
            self.step_count += 1

            # Copy before mutating so the caller's dict is left untouched.
            comp_data = {} if comp_data is None else dict(comp_data)
            comp_data[f"{self.name}_mean_output"] = projected.mean().item()
            comp_data[f"{self.name}_steps"] = self.step_count.item()

        return (obs, action, reward, done, truncated, info, comp_data)

    def get_config(self) -> dict[str, Any]:
        """JSON-serializable (non-tensor) state."""
        return {
            "name": self.name,
            "feature_dim": self.feature_dim,
            "config_value": self.config_value,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Expose only tensor state."""
        return {
            "weights": self.weights,
            "bias": self.bias,
            "running_stats": self.running_stats,
            "step_count": self.step_count,
        }

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Adopt the provided tensors wholesale."""
        self.weights = state["weights"]
        self.bias = state["bias"]
        self.running_stats = state["running_stats"]
        self.step_count = state["step_count"]

    def reset(self) -> None:
        """Reset accumulated statistics but keep learned parameters."""
        self.running_stats.zero_()
        self.step_count.zero_()
        self.history.clear()
|
||||
|
||||
|
||||
def test_to_device_non_module_class():
    """Test moving pipeline to device for regular classes (non nn.Module) with tensor state.

    This ensures the state_dict/load_state_dict approach works for classes that
    don't inherit from nn.Module but still have tensor state to manage.
    """
    # Create a non-module step with tensor state
    non_module_step = MockNonModuleStepWithState(name="device_test", feature_dim=5)
    pipeline = RobotProcessor([non_module_step])

    # Process some data to populate state (each call should record into comp_data)
    for i in range(3):
        obs = torch.randn(2, 5)
        transition = (obs, None, float(i), False, False, {}, {})
        result = pipeline(transition)
        comp_data = result[6]
        assert f"{non_module_step.name}_steps" in comp_data

    # Verify all tensors are on CPU initially
    assert non_module_step.weights.device.type == "cpu"
    assert non_module_step.bias.device.type == "cpu"
    assert non_module_step.running_stats.device.type == "cpu"
    assert non_module_step.step_count.device.type == "cpu"

    # Verify step count
    assert non_module_step.step_count.item() == 3

    # Store initial values for comparison after the transfer round-trips
    initial_weights = non_module_step.weights.clone()
    initial_bias = non_module_step.bias.clone()
    initial_stats = non_module_step.running_stats.clone()

    # Move to same device (CPU)
    result = pipeline.to("cpu")
    assert result is pipeline

    # Verify tensors are still on CPU and values unchanged
    assert non_module_step.weights.device.type == "cpu"
    assert torch.allclose(non_module_step.weights, initial_weights)
    assert torch.allclose(non_module_step.bias, initial_bias)
    assert torch.allclose(non_module_step.running_stats, initial_stats)

    # If CUDA is available, test GPU transfer
    if torch.cuda.is_available():
        # Move to GPU
        pipeline.to("cuda")

        # Verify all tensors moved to GPU
        assert non_module_step.weights.device.type == "cuda"
        assert non_module_step.bias.device.type == "cuda"
        assert non_module_step.running_stats.device.type == "cuda"
        assert non_module_step.step_count.device.type == "cuda"

        # Verify values are preserved (compare back on CPU)
        assert torch.allclose(non_module_step.weights.cpu(), initial_weights)
        assert torch.allclose(non_module_step.bias.cpu(), initial_bias)
        assert torch.allclose(non_module_step.running_stats.cpu(), initial_stats)
        assert non_module_step.step_count.item() == 3

        # Test that step still works on GPU
        obs_gpu = torch.randn(2, 5, device="cuda")
        transition = (obs_gpu, None, 1.0, False, False, {}, {})
        result = pipeline(transition)
        comp_data = result[6]

        # Verify processing worked
        assert comp_data[f"{non_module_step.name}_steps"] == 4

        # Move back to CPU
        pipeline.to("cpu")
        assert non_module_step.weights.device.type == "cpu"
        assert non_module_step.step_count.item() == 4
|
||||
|
||||
|
||||
def test_to_device_module_vs_non_module():
    """Test that both nn.Module and non-Module steps work with the same state_dict approach."""
    # Create both types of steps
    module_step = MockModuleStep(input_dim=5, hidden_dim=3)
    non_module_step = MockNonModuleStepWithState(name="non_module", feature_dim=5)

    # Create pipeline with both
    pipeline = RobotProcessor([module_step, non_module_step])

    # Process some data
    obs = torch.randn(2, 5)
    transition = (obs, None, 1.0, False, False, {}, {})
    _ = pipeline(transition)

    # Check initial devices
    assert module_step.linear.weight.device.type == "cpu"
    assert module_step.running_mean.device.type == "cpu"
    assert non_module_step.weights.device.type == "cpu"
    assert non_module_step.running_stats.device.type == "cpu"

    # Both should have been called once each
    assert module_step.counter == 1
    assert non_module_step.step_count.item() == 1

    if torch.cuda.is_available():
        # Move to GPU
        pipeline.to("cuda")

        # Verify both types of steps moved correctly
        assert module_step.linear.weight.device.type == "cuda"
        assert module_step.running_mean.device.type == "cuda"
        assert non_module_step.weights.device.type == "cuda"
        assert non_module_step.running_stats.device.type == "cuda"

        # Process data on GPU
        obs_gpu = torch.randn(2, 5, device="cuda")
        transition = (obs_gpu, None, 2.0, False, False, {}, {})
        _ = pipeline(transition)

        # Verify both steps processed the data
        assert module_step.counter == 2
        assert non_module_step.step_count.item() == 2

        # Move back to CPU and verify
        pipeline.to("cpu")
        assert module_step.linear.weight.device.type == "cpu"
        assert non_module_step.weights.device.type == "cpu"
|
||||
|
||||
|
||||
class MockStepWithMixedState:
    """Mock step demonstrating proper separation of tensor and non-tensor state.

    Non-tensor state should go in get_config(), only tensors in state_dict().
    """

    def __init__(self, name: str = "mixed_state"):
        self.name = name
        self.tensor_data = torch.randn(5)
        # Everything below is non-tensor state and belongs in get_config().
        self.numpy_data = np.array([1, 2, 3, 4, 5])
        self.scalar_value = 42
        self.list_value = [1, 2, 3]

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Pure pass-through: this step exists only to exercise save/move logic.
        return transition

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return ONLY tensor state as per the type contract."""
        return {"tensor_data": self.tensor_data}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Load tensor state only."""
        self.tensor_data = state["tensor_data"]

    def get_config(self) -> dict[str, Any]:
        """Non-tensor state goes here."""
        return {
            "name": self.name,
            "numpy_data": self.numpy_data.tolist(),  # list form is JSON-serializable
            "scalar_value": self.scalar_value,
            "list_value": self.list_value,
        }
|
||||
|
||||
|
||||
def test_to_device_with_mixed_state_types():
    """Test that to() only moves tensor state, while non-tensor state remains in config."""
    step = MockStepWithMixedState()
    pipeline = RobotProcessor([step])

    # Store initial values (copies, so later mutation cannot mask a change)
    initial_numpy = step.numpy_data.copy()
    initial_scalar = step.scalar_value
    initial_list = step.list_value.copy()

    # Check initial state
    assert step.tensor_data.device.type == "cpu"
    assert isinstance(step.numpy_data, np.ndarray)
    assert isinstance(step.scalar_value, int)
    assert isinstance(step.list_value, list)

    # Verify state_dict only contains tensors — non-tensor state must stay out
    state = step.state_dict()
    assert all(isinstance(v, torch.Tensor) for v in state.values())
    assert "tensor_data" in state
    assert "numpy_data" not in state

    # Move to same device
    pipeline.to("cpu")

    # Verify tensor moved and non-tensor attributes unchanged
    assert step.tensor_data.device.type == "cpu"
    assert np.array_equal(step.numpy_data, initial_numpy)
    assert step.scalar_value == initial_scalar
    assert step.list_value == initial_list

    if torch.cuda.is_available():
        # Move to GPU
        pipeline.to("cuda")

        # Only tensor should move to GPU
        assert step.tensor_data.device.type == "cuda"

        # Non-tensor values should remain unchanged (same type, same contents)
        assert isinstance(step.numpy_data, np.ndarray)
        assert np.array_equal(step.numpy_data, initial_numpy)
        assert step.scalar_value == initial_scalar
        assert step.list_value == initial_list

        # Move back to CPU
        pipeline.to("cpu")
        assert step.tensor_data.device.type == "cpu"
|
||||
|
||||
Reference in New Issue
Block a user