Refactor observation preprocessing to use a modular pipeline system

- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.
This commit is contained in:
Adil Zouitine
2025-07-02 17:29:58 +02:00
parent 945e1ff266
commit f6c7287ae7
9 changed files with 1472 additions and 50 deletions
+17 -43
View File
@@ -28,62 +28,36 @@ from lerobot.utils.utils import get_channel_first_image_shape
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
This function uses the new pipeline system internally but maintains
backward compatibility with the original interface.
Args: Args:
observation: Dictionary of observation batches from a Gym vector environment. observation: Dictionary of observation batches from a Gym vector environment.
Returns: Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
""" """
# map to expected inputs for the policy from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
return_observations = {} from lerobot.processor.observation_processor import ObservationProcessor
if "pixels" in observations:
if isinstance(observations["pixels"], dict): # Create pipeline with observation processor
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} pipeline = RobotPipeline([ObservationProcessor()])
else:
imgs = {"observation.image": observations["pixels"]} # Create transition tuple and process
transition = (observations, None, None, None, None, None, None)
processed_transition = pipeline(transition)
# Return processed observations
return processed_transition[TransitionIndex.OBSERVATION]
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies) # (need to externalize normalization from policies)
policy_features = {} policy_features = {}
for key, ft in env_cfg.features.items(): for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL: if ft.type is FeatureType.VISUAL:
+30
View File
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
from .observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
)
__all__ = [
"RobotPipeline",
"PipelineStep",
"EnvTransition",
"ImageProcessor",
"StateProcessor",
"ObservationProcessor",
]
@@ -0,0 +1,224 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from dataclasses import dataclass, field
import numpy as np
import torch
import einops
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
@dataclass
class ImageProcessor:
"""Process image observations from environment format to policy format.
Converts images from:
- Channel-last (H, W, C) to channel-first (C, H, W)
- uint8 [0, 255] to float32 [0, 1]
- Adds batch dimension if needed
- Handles both single images and dictionaries of images
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = {}
# Handle pixels key
if "pixels" in observation:
if isinstance(observation["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Copy other observations unchanged
for key, value in observation.items():
if key != "pixels":
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
# Convert to tensor
img_tensor = torch.from_numpy(img)
# Add batch dimension if needed
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
# Validate image format
_, h, w, c = img_tensor.shape
if not (c < h and c < w):
raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
if img_tensor.dtype != torch.uint8:
raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
# Convert to channel-first format
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
# Convert to float32 and normalize to [0, 1]
img_tensor = img_tensor.type(torch.float32) / 255.0
return img_tensor
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
@dataclass
class StateProcessor:
"""Process state observations from environment format to policy format.
Handles:
- environment_state -> observation.environment_state
- agent_pos -> observation.state
- Converts numpy arrays to tensors
- Adds batch dimension if needed
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = dict(observation) # Copy existing observations
# Process environment_state
if "environment_state" in observation:
env_state = torch.from_numpy(observation["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
processed_obs["observation.environment_state"] = env_state
# Remove original key
del processed_obs["environment_state"]
# Process agent_pos
if "agent_pos" in observation:
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs["observation.state"] = agent_pos
# Remove original key
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
@dataclass
class ObservationProcessor:
"""Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation
function but in a modular, composable way that fits into the pipeline architecture.
"""
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
state_processor: StateProcessor = field(default_factory=StateProcessor)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# First process images
transition = self.image_processor(transition)
# Then process state
transition = self.state_processor(transition)
return transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {
"image_processor": self.image_processor.get_config(),
"state_processor": self.state_processor.get_config(),
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary."""
state = {}
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
return state
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary."""
image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")}
state_state = {k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")}
self.image_processor.load_state_dict(image_state)
self.state_processor.load_state_dict(state_state)
def reset(self) -> None:
"""Reset processor state."""
self.image_processor.reset()
self.state_processor.reset()
+299
View File
@@ -0,0 +1,299 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os, json
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
from dataclasses import dataclass, field
from pathlib import Path
from enum import IntEnum
import numpy as np
import torch
from huggingface_hub import hf_hub_download, ModelHubMixin
from safetensors.torch import save_file, load_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
DONE = 3
TRUNCATED = 4
INFO = 5
COMPLEMENTARY_DATA = 6
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
Any| None, # observation
Any| None, # action
float| None, # reward
bool| None, # done
bool| None, # truncated
Dict[str, Any]| None, # info
Dict[str, Any]| None, # complementary_data
]
class PipelineStep(Protocol):
"""Structural typing interface for a single pipeline step.
A step is any callable accepting a full `EnvTransition` tuple and
returning a (possibly modified) tuple of the same structure. Implementers
are encouragedbut not requiredto expose the optional helper methods
listed below. When present, these hooks let `RobotPipeline`
automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format.
Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading.
* ``state_dict() -> Dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
This is exclusively for torch.Tensor objects (e.g., learned weights,
running statistics as tensors). Never put simple Python types here.
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
containing torch tensors only.
* ``reset()`` Clear internal buffers at episode boundaries.
Example separation:
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
"""
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
def get_config(self) -> Dict[str, Any]: ...
def state_dict(self) -> Dict[str, torch.Tensor]: ...
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
def reset(self) -> None: ...
@dataclass
class RobotPipeline(ModelHubMixin):
"""
Composable, debuggable post-processing pipeline for RL transitions.
The class orchestrates an ordered collection of small, functional
transformsstepsexecuted left-to-right on each incoming
`EnvTransition`.
Parameters:
steps : Sequence[PipelineStep], optional
Ordered list executed on every call
name : str, default="RobotPipeline"
Human-readable identifier that is persisted inside the JSON config.
seed : int | None, optional
Global seed forwarded to steps that choose to consume it.
Examples:
Basic usage::
env = gym.make("CartPole-v1")
pipe = RobotPipeline([
ObservationNormalizer(),
IntrinsicVelocity(),
VelocityBonus(0.02),
])
obs, info = env.reset(seed=0)
tr = (obs, None, 0.0, False, False, info, {})
obs, *_ = pipe(tr) # agent sees a normalised observation
Inspecting intermediate results::
for idx, step_tr in enumerate(pipe.step_through(tr)):
print(idx, step_tr)
Serialization to the Hugging Face Hub::
pipe.save_pretrained("chkpt")
pipe.push_to_hub("my-org/cartpole_pipe")
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
"""
steps: Sequence[PipelineStep] = field(default_factory=list)
name: str = "RobotPipeline"
seed: Optional[int] = None
# Pipeline-level hooks
# A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched.
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
default_factory=list, repr=False
)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Run *transition* through every step, firing hooks on the way."""
# Basic validation with helpful error message
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), "
f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}"
)
for idx, pipeline_step in enumerate(self.steps):
for hook in self.before_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
transition = pipeline_step(transition)
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
return transition
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
"""Yield the intermediate Transition instances after each pipeline step."""
yield transition
for pipeline_step in self.steps:
transition = pipeline_step(transition)
yield transition
_CFG_NAME = "pipeline.json"
def _save_pretrained(self, destination_path: str, **kwargs):
"""Internal save method for ModelHubMixin compatibility."""
self.save_pretrained(destination_path)
def save_pretrained(self, destination_path: str, **kwargs):
"""Serialize the pipeline definition and parameters to *destination_path*."""
os.makedirs(destination_path, exist_ok=True)
config: Dict[str, Any] = {
"name": self.name,
"seed": self.seed,
"steps": [],
}
for step_index, pipeline_step in enumerate(self.steps):
step_entry: Dict[str, Any] = {
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
}
if hasattr(pipeline_step, "get_config"):
step_entry["config"] = pipeline_step.get_config()
if hasattr(pipeline_step, "state_dict"):
state = pipeline_step.state_dict()
if state:
state_filename = f"step_{step_index}.safetensors"
save_file(state, os.path.join(destination_path, state_filename))
step_entry["state_file"] = state_filename
config["steps"].append(step_entry)
with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
@classmethod
def from_pretrained(cls, source: str) -> "RobotPipeline":
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
with open(base_path / cls._CFG_NAME) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
# First download the config file
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
with open(config_path) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
steps: list[PipelineStep] = []
for step_entry in config["steps"]:
module_path, class_name = step_entry["class"].rsplit(".", 1)
step_class = getattr(__import__(module_path, fromlist=[class_name]), class_name)
step_instance: PipelineStep = step_class(**step_entry.get("config", {}))
if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"):
if Path(source).is_dir():
# Local path - read directly
state_path = str(base_path / step_entry["state_file"])
else:
# Hugging Face Hub - download the state file
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
step_instance.load_state_dict(load_file(state_path))
steps.append(step_instance)
return cls(steps, config.get("name", "RobotPipeline"), config.get("seed"))
def __len__(self) -> int:
"""Return the number of steps in the pipeline."""
return len(self.steps)
def __getitem__(self, idx: int | slice) -> PipelineStep | RobotPipeline:
"""Indexing helper exposing underlying steps.
* ``int`` returns the idx-th PipelineStep.
* ``slice`` returns a new RobotPipeline with the sliced steps.
"""
if isinstance(idx, slice):
return RobotPipeline(self.steps[idx], self.name, self.seed)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
"""Attach fn to be executed before every pipeline step."""
self.before_step_hooks.append(fn)
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
"""Attach fn to be executed after every pipeline step."""
self.after_step_hooks.append(fn)
def register_reset_hook(self, fn: Callable[[], None]):
"""Attach fn to be executed when reset is called."""
self.reset_hooks.append(fn)
def reset(self):
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
for step in self.steps:
if hasattr(step, "reset"):
step.reset() # type: ignore[attr-defined]
for fn in self.reset_hooks:
fn()
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import time
profile_results = {}
for idx, pipeline_step in enumerate(self.steps):
step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"
# Warm up
for _ in range(5):
_ = pipeline_step(transition)
# Time the step
start_time = time.perf_counter()
for _ in range(num_runs):
transition = pipeline_step(transition)
end_time = time.perf_counter()
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
profile_results[step_name] = avg_time
return profile_results
+13 -3
View File
@@ -68,7 +68,10 @@ from tqdm import trange
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
@@ -127,6 +130,9 @@ def rollout(
observation, info = env.reset(seed=seeds) observation, info = env.reset(seed=seeds)
if render_callback is not None: if render_callback is not None:
render_callback(env) render_callback(env)
# Create observation processing pipeline
obs_pipeline = RobotPipeline([ObservationProcessor()])
all_observations = [] all_observations = []
all_actions = [] all_actions = []
@@ -147,7 +153,9 @@ def rollout(
check_env_attributes_and_types(env) check_env_attributes_and_types(env)
while not np.all(done): while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format. # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation) transition = (observation, None, None, None, None, None, None)
processed_transition = obs_pipeline(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
if return_observations: if return_observations:
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
@@ -195,7 +203,9 @@ def rollout(
# Track the final observation. # Track the final observation.
if return_observations: if return_observations:
observation = preprocess_observation(observation) transition = (observation, None, None, None, None, None, None)
processed_transition = obs_pipeline(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
+9 -2
View File
@@ -22,7 +22,9 @@ from gymnasium.utils.env_checker import check_env
import lerobot import lerobot
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from tests.utils import require_env from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@@ -48,7 +50,12 @@ def test_factory(env_name):
cfg = make_env_config(env_name) cfg = make_env_config(env_name)
env = make_env(cfg, n_envs=1) env = make_env(cfg, n_envs=1)
obs, _ = env.reset() obs, _ = env.reset()
obs = preprocess_observation(obs)
# Process observation using pipeline
obs_pipeline = RobotPipeline([ObservationProcessor()])
transition = (obs, None, None, None, None, None, None)
processed_transition = obs_pipeline(transition)
obs = processed_transition[TransitionIndex.OBSERVATION]
# test image keys are float32 in range [0,1] # test image keys are float32 in range [0,1]
for key in obs: for key in obs:
+6 -2
View File
@@ -30,7 +30,8 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.factory import make_dataset from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import ( from lerobot.policies.factory import (
@@ -185,7 +186,10 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation, _ = env.reset(seed=train_cfg.seed) observation, _ = env.reset(seed=train_cfg.seed)
# apply transform to normalize the observations # apply transform to normalize the observations
observation = preprocess_observation(observation) obs_pipeline = RobotPipeline([ObservationProcessor()])
transition = (observation, None, None, None, None, None, None)
processed_transition = obs_pipeline(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
# send observation to device/gpu # send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
@@ -0,0 +1,466 @@
#!/usr/bin/env python
# Copyright 202 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.processor.observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
)
from lerobot.processor.pipeline import EnvTransition
def test_process_single_image():
"""Test processing a single image."""
processor = ImageProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
processed_img = processed_obs["observation.image"]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
# Check dtype and range
assert processed_img.dtype == torch.float32
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = ImageProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {
"pixels": {
"camera1": image1,
"camera2": image2
}
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs
# Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
"""Test processing already batched images."""
processor = ImageProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = ImageProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = ImageProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = ImageProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Should preserve other data unchanged
assert "other_data" in processed_obs
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ImageProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_process_environment_state():
"""Test processing environment_state."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = StateProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
def test_process_batched_states():
"""Test processing already batched states."""
processor = StateProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "keep_me"
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "keep_me"
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = StateProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Should preserve data unchanged
assert processed_obs == observation
def test_none_observation():
"""Test processor with None observation."""
processor = StateProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = StateProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me"
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
def test_image_only_processing():
"""Test processing observation with only images."""
processor = ObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
def test_state_only_processing():
"""Test processing observation with only states."""
processor = ObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
def test_empty_observation():
"""Test processing empty observation."""
processor = ObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert processed_obs == {}
def test_none_observation():
"""Test processing None observation."""
processor = ObservationProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
assert "image_processor" in config
assert "state_processor" in config
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos
}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {
"pixels": {"cam1": image1, "cam2": image2},
"agent_pos": agent_pos
}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
+408
View File
@@ -0,0 +1,408 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
from pathlib import Path
from typing import Any
from dataclasses import dataclass
import numpy as np
import pytest
import torch
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
@dataclass
class MockStep:
"""Mock pipeline step for testing - demonstrates best practices.
This example shows the proper separation:
- JSON-serializable attributes (name, counter) go in get_config()
- Only torch tensors go in state_dict()
Note: The counter is part of the configuration, so it will be restored
when the step is recreated from config during loading.
"""
name: str = "mock_step"
counter: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
if comp_data is None:
comp_data = {}
else:
comp_data = dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
# These will be passed to __init__ when loading
return {"name": self.name, "counter": self.counter}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only return torch tensors (empty in this case since we have no tensor state)
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
# No tensor state to load
pass
def reset(self) -> None:
self.counter = 0
@dataclass
class MockStepWithoutOptionalMethods:
"""Mock step that only implements the required __call__ method."""
multiplier: float = 2.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
reward = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
@dataclass
class MockStepWithTensorState:
"""Mock step demonstrating mixed JSON attributes and tensor state."""
name: str = "tensor_step"
learning_rate: float = 0.01
window_size: int = 10
def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10):
self.name = name
self.learning_rate = learning_rate
self.window_size = window_size
# Tensor state
self.running_mean = torch.zeros(window_size)
self.running_count = torch.tensor(0)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
# Update running mean
idx = self.running_count % self.window_size
self.running_mean[idx] = reward
self.running_count += 1
return transition
def get_config(self) -> dict[str, Any]:
# Only JSON-serializable attributes
return {
"name": self.name,
"learning_rate": self.learning_rate,
"window_size": self.window_size,
}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only tensor state
return {
"running_mean": self.running_mean,
"running_count": self.running_count,
}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
self.running_mean = state["running_mean"]
self.running_count = state["running_count"]
def reset(self) -> None:
self.running_mean.zero_()
self.running_count.zero_()
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotPipeline()
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result == transition
assert len(pipeline) == 0
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotPipeline([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
def test_step_through():
"""Test step_through method."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
# Test integer indexing
assert pipeline[0] is step1
assert pipeline[1] is step2
# Test slice indexing
sub_pipeline = pipeline[0:1]
assert isinstance(sub_pipeline, RobotPipeline)
assert len(sub_pipeline) == 1
assert sub_pipeline[0] is step1
def test_hooks():
"""Test before/after step hooks."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
before_calls = []
after_calls = []
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
assert before_calls == [0]
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
obs, action, reward, done, truncated, info, comp_data = transition
return (obs, action, 42.0, done, truncated, info, comp_data)
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
reset_called = []
def reset_hook():
reset_called.append(True)
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
pipeline(transition)
assert step.counter == 2
# Reset should reset step and call hook
pipeline.reset()
assert step.counter == 0
assert len(reset_called) == 1
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
profile_results = pipeline.profile_steps(transition, num_runs=10)
assert len(profile_results) == 2
assert "step_0_MockStep" in profile_results
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained():
"""Test saving and loading pipeline.
This test demonstrates that JSON-serializable attributes (like counter)
are saved in the config and restored when the step is recreated.
"""
step1 = MockStep("step1")
step2 = MockStep("step2")
# Increment counters to have some state
step1.counter = 5
step2.counter = 10
pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "pipeline.json"
assert config_path.exists()
# Check config content
with open(config_path) as f:
config = json.load(f)
assert config["name"] == "TestPipeline"
assert config["seed"] == 42
assert len(config["steps"]) == 2
# Verify counters are saved in config, not in separate state files
assert config["steps"][0]["config"]["counter"] == 5
assert config["steps"][1]["config"]["counter"] == 10
# Load pipeline
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestPipeline"
assert loaded_pipeline.seed == 42
assert len(loaded_pipeline) == 2
# Check that counter was restored from config
assert loaded_pipeline.steps[0].counter == 5
assert loaded_pipeline.steps[1].counter == 10
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotPipeline([step])
transition = (None, None, 2.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
# Save/load should work even without optional methods
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
def test_mixed_json_and_tensor_state():
"""Test step with both JSON attributes and tensor state."""
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
pipeline = RobotPipeline([step])
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
pipeline(transition)
# Check state
assert step.running_count.item() == 10
assert step.learning_rate == 0.05
# Save and load
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Check that both config and state files were created
config_path = Path(tmp_dir) / "pipeline.json"
state_path = Path(tmp_dir) / "step_0.safetensors"
assert config_path.exists()
assert state_path.exists()
# Load and verify
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
loaded_step = loaded_pipeline.steps[0]
# Check JSON attributes were restored
assert loaded_step.name == "stats"
assert loaded_step.learning_rate == 0.05
assert loaded_step.window_size == 5
# Check tensor state was restored
assert loaded_step.running_count.item() == 10
assert torch.allclose(loaded_step.running_mean, step.running_mean)