From f6c7287ae7e3055b8cb9998f37675415140729e6 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 2 Jul 2025 17:29:58 +0200 Subject: [PATCH] Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. --- src/lerobot/envs/utils.py | 60 +-- src/lerobot/processor/__init__.py | 30 ++ .../processor/observation_processor.py | 224 +++++++++ src/lerobot/processor/pipeline.py | 299 +++++++++++ src/lerobot/scripts/eval.py | 16 +- tests/envs/test_envs.py | 11 +- tests/policies/test_policies.py | 8 +- tests/processor/test_observation_processor.py | 466 ++++++++++++++++++ tests/processor/test_pipeline.py | 408 +++++++++++++++ 9 files changed, 1472 insertions(+), 50 deletions(-) create mode 100644 src/lerobot/processor/__init__.py create mode 100644 src/lerobot/processor/observation_processor.py create mode 100644 src/lerobot/processor/pipeline.py create mode 100644 tests/processor/test_observation_processor.py create mode 100644 tests/processor/test_pipeline.py diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 00676a011..c90113b36 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -28,62 +28,36 @@ from lerobot.utils.utils import get_channel_first_image_shape def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: - # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding) """Convert environment observation to LeRobot format observation. + + This function uses the new pipeline system internally but maintains + backward compatibility with the original interface. 
+ Args: observation: Dictionary of observation batches from a Gym vector environment. Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ - # map to expected inputs for the policy - return_observations = {} - if "pixels" in observations: - if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} - else: - imgs = {"observation.image": observations["pixels"]} + from lerobot.processor.pipeline import RobotPipeline, TransitionIndex + from lerobot.processor.observation_processor import ObservationProcessor + + # Create pipeline with observation processor + pipeline = RobotPipeline([ObservationProcessor()]) + + # Create transition tuple and process + transition = (observations, None, None, None, None, None, None) + processed_transition = pipeline(transition) + + # Return processed observations + return processed_transition[TransitionIndex.OBSERVATION] - for imgkey, img in imgs.items(): - # TODO(aliberts, rcadene): use transforms.ToTensor()? - img = torch.from_numpy(img) - # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension. - # This is the case for human-in-the-loop RL where there is only one environment. 
- if img.ndim == 3: - img = img.unsqueeze(0) - # sanity check that images are channel last - _, h, w, c = img.shape - assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" - # sanity check that images are uint8 - assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" - - # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w").contiguous() - img = img.type(torch.float32) - img /= 255 - - return_observations[imgkey] = img - - if "environment_state" in observations: - env_state = torch.from_numpy(observations["environment_state"]).float() - if env_state.dim() == 1: - env_state = env_state.unsqueeze(0) - - return_observations["observation.environment_state"] = env_state - - # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing - agent_pos = torch.from_numpy(observations["agent_pos"]).float() - if agent_pos.dim() == 1: - agent_pos = agent_pos.unsqueeze(0) - return_observations["observation.state"] = agent_pos - - return return_observations def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is - # (need to also refactor preprocess_observation and externalize normalization from policies) + # (need to externalize normalization from policies) policy_features = {} for key, ft in env_cfg.features.items(): if ft.type is FeatureType.VISUAL: diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py new file mode 100644 index 000000000..76fb86b8f --- /dev/null +++ b/src/lerobot/processor/__init__.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .pipeline import RobotPipeline, PipelineStep, EnvTransition +from .observation_processor import ( + ImageProcessor, + StateProcessor, + ObservationProcessor, +) + +__all__ = [ + "RobotPipeline", + "PipelineStep", + "EnvTransition", + "ImageProcessor", + "StateProcessor", + "ObservationProcessor", +] diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py new file mode 100644 index 000000000..76189a47f --- /dev/null +++ b/src/lerobot/processor/observation_processor.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any +from dataclasses import dataclass, field +import numpy as np +import torch +import einops +from torch import Tensor + +from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex + + +@dataclass +class ImageProcessor: + """Process image observations from environment format to policy format. 
+ + Converts images from: + - Channel-last (H, W, C) to channel-first (C, H, W) + - uint8 [0, 255] to float32 [0, 1] + - Adds batch dimension if needed + - Handles both single images and dictionaries of images + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition[TransitionIndex.OBSERVATION] + + if observation is None: + return transition + + processed_obs = {} + + # Handle pixels key + if "pixels" in observation: + if isinstance(observation["pixels"], dict): + imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()} + else: + imgs = {"observation.image": observation["pixels"]} + + for imgkey, img in imgs.items(): + processed_img = self._process_single_image(img) + processed_obs[imgkey] = processed_img + + # Copy other observations unchanged + for key, value in observation.items(): + if key != "pixels": + processed_obs[key] = value + + # Return new transition with processed observation + return ( + processed_obs, + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + transition[TransitionIndex.DONE], + transition[TransitionIndex.TRUNCATED], + transition[TransitionIndex.INFO], + transition[TransitionIndex.COMPLEMENTARY_DATA], + ) + + def _process_single_image(self, img: np.ndarray) -> Tensor: + """Process a single image array.""" + # Convert to tensor + img_tensor = torch.from_numpy(img) + + # Add batch dimension if needed + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) + + # Validate image format + _, h, w, c = img_tensor.shape + if not (c < h and c < w): + raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}") + + if img_tensor.dtype != torch.uint8: + raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}") + + # Convert to channel-first format + img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() + + # Convert to float32 and normalize to [0, 1] + img_tensor = 
img_tensor.type(torch.float32) / 255.0 + + return img_tensor + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass + + +@dataclass +class StateProcessor: + """Process state observations from environment format to policy format. + + Handles: + - environment_state -> observation.environment_state + - agent_pos -> observation.state + - Converts numpy arrays to tensors + - Adds batch dimension if needed + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition[TransitionIndex.OBSERVATION] + + if observation is None: + return transition + + processed_obs = dict(observation) # Copy existing observations + + # Process environment_state + if "environment_state" in observation: + env_state = torch.from_numpy(observation["environment_state"]).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + processed_obs["observation.environment_state"] = env_state + # Remove original key + del processed_obs["environment_state"] + + # Process agent_pos + if "agent_pos" in observation: + agent_pos = torch.from_numpy(observation["agent_pos"]).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + processed_obs["observation.state"] = agent_pos + # Remove original key + del processed_obs["agent_pos"] + + # Return new transition with processed observation + return ( + processed_obs, + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + transition[TransitionIndex.DONE], + transition[TransitionIndex.TRUNCATED], + transition[TransitionIndex.INFO], + 
transition[TransitionIndex.COMPLEMENTARY_DATA], + ) + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass + + +@dataclass +class ObservationProcessor: + """Complete observation processor that combines image and state processing. + + This processor replicates the functionality of the original preprocess_observation + function but in a modular, composable way that fits into the pipeline architecture. + """ + + image_processor: ImageProcessor = field(default_factory=ImageProcessor) + state_processor: StateProcessor = field(default_factory=StateProcessor) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # First process images + transition = self.image_processor(transition) + # Then process state + transition = self.state_processor(transition) + return transition + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return { + "image_processor": self.image_processor.get_config(), + "state_processor": self.state_processor.get_config(), + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary.""" + state = {} + state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()}) + state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()}) + return state + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary.""" + image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")} + state_state = {k.replace("state_processor.", ""): v 
for k, v in state.items() if k.startswith("state_processor.")} + + self.image_processor.load_state_dict(image_state) + self.state_processor.load_state_dict(state_state) + + def reset(self) -> None: + """Reset processor state.""" + self.image_processor.reset() + self.state_processor.reset() diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py new file mode 100644 index 000000000..2451ed99a --- /dev/null +++ b/src/lerobot/processor/pipeline.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations +import os, json +from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union +from dataclasses import dataclass, field +from pathlib import Path +from enum import IntEnum +import numpy as np +import torch +from huggingface_hub import hf_hub_download, ModelHubMixin +from safetensors.torch import save_file, load_file + + +class TransitionIndex(IntEnum): + """Explicit indices for EnvTransition tuple components.""" + OBSERVATION = 0 + ACTION = 1 + REWARD = 2 + DONE = 3 + TRUNCATED = 4 + INFO = 5 + COMPLEMENTARY_DATA = 6 + + +# (observation, action, reward, done, truncated, info, complementary_data) +EnvTransition = Tuple[ + Any| None, # observation + Any| None, # action + float| None, # reward + bool| None, # done + bool| None, # truncated + Dict[str, Any]| None, # info + Dict[str, Any]| None, # complementary_data +] + + + +class PipelineStep(Protocol): + """Structural typing interface for a single pipeline step. + + A step is any callable accepting a full `EnvTransition` tuple and + returning a (possibly modified) tuple of the same structure. Implementers + are encouraged—but not required—to expose the optional helper methods + listed below. When present, these hooks let `RobotPipeline` + automatically serialise the step's configuration and learnable state using + a safe-to-share JSON + SafeTensors format. + + Optional helper protocol: + * ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable + configuration and state. YOU decide what to save here. This is where all + non-tensor state goes (e.g., name, counter, threshold, window_size). + The config dict will be passed to your class constructor when loading. + * ``state_dict() -> Dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. + This is exclusively for torch.Tensor objects (e.g., learned weights, + running statistics as tensors). Never put simple Python types here. + * ``load_state_dict(state)`` – Inverse of ``state_dict``. 
Receives a dict + containing torch tensors only. + * ``reset()`` – Clear internal buffers at episode boundaries. + + Example separation: + - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} + - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: ... + + def get_config(self) -> Dict[str, Any]: ... + + def state_dict(self) -> Dict[str, torch.Tensor]: ... + + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ... + + def reset(self) -> None: ... + + +@dataclass +class RobotPipeline(ModelHubMixin): + """ + Composable, debuggable post-processing pipeline for RL transitions. + The class orchestrates an ordered collection of small, functional + transforms—steps—executed left-to-right on each incoming + `EnvTransition`. + Parameters: + steps : Sequence[PipelineStep], optional + Ordered list executed on every call + name : str, default="RobotPipeline" + Human-readable identifier that is persisted inside the JSON config. + seed : int | None, optional + Global seed forwarded to steps that choose to consume it. + Examples: + Basic usage:: + env = gym.make("CartPole-v1") + pipe = RobotPipeline([ + ObservationNormalizer(), + IntrinsicVelocity(), + VelocityBonus(0.02), + ]) + obs, info = env.reset(seed=0) + tr = (obs, None, 0.0, False, False, info, {}) + obs, *_ = pipe(tr) # agent sees a normalised observation + Inspecting intermediate results:: + for idx, step_tr in enumerate(pipe.step_through(tr)): + print(idx, step_tr) + Serialization to the Hugging Face Hub:: + pipe.save_pretrained("chkpt") + pipe.push_to_hub("my-org/cartpole_pipe") + loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe") + """ + steps: Sequence[PipelineStep] = field(default_factory=list) + name: str = "RobotPipeline" + seed: Optional[int] = None + + # Pipeline-level hooks + # A hook can optionally return a modified transition. 
If it returns + # ``None`` the current value is left untouched. + before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field( + default_factory=list, repr=False + ) + after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field( + default_factory=list, repr=False + ) + reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Run *transition* through every step, firing hooks on the way.""" + + # Basic validation with helpful error message + if not isinstance(transition, tuple) or len(transition) != 7: + raise ValueError( + f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), " + f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}" + ) + + for idx, pipeline_step in enumerate(self.steps): + for hook in self.before_step_hooks: + updated = hook(idx, transition) + if updated is not None: + transition = updated + + transition = pipeline_step(transition) + + for hook in self.after_step_hooks: + updated = hook(idx, transition) + if updated is not None: + transition = updated + + return transition + + def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]: + """Yield the intermediate Transition instances after each pipeline step.""" + yield transition + for pipeline_step in self.steps: + transition = pipeline_step(transition) + yield transition + + _CFG_NAME = "pipeline.json" + + def _save_pretrained(self, destination_path: str, **kwargs): + """Internal save method for ModelHubMixin compatibility.""" + self.save_pretrained(destination_path) + + def save_pretrained(self, destination_path: str, **kwargs): + """Serialize the pipeline definition and parameters to *destination_path*.""" + os.makedirs(destination_path, exist_ok=True) + + config: Dict[str, Any] = { + "name": self.name, + 
"seed": self.seed, + "steps": [], + } + + for step_index, pipeline_step in enumerate(self.steps): + step_entry: Dict[str, Any] = { + "class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}", + } + + if hasattr(pipeline_step, "get_config"): + step_entry["config"] = pipeline_step.get_config() + + if hasattr(pipeline_step, "state_dict"): + state = pipeline_step.state_dict() + if state: + state_filename = f"step_{step_index}.safetensors" + save_file(state, os.path.join(destination_path, state_filename)) + step_entry["state_file"] = state_filename + + config["steps"].append(step_entry) + + with open(os.path.join(destination_path, self._CFG_NAME), "w") as file_pointer: + json.dump(config, file_pointer, indent=2) + + @classmethod + def from_pretrained(cls, source: str) -> "RobotPipeline": + """Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier).""" + if Path(source).is_dir(): + # Local path - use it directly + base_path = Path(source) + with open(base_path / cls._CFG_NAME) as file_pointer: + config: Dict[str, Any] = json.load(file_pointer) + else: + # Hugging Face Hub - download all required files + # First download the config file + config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model") + with open(config_path) as file_pointer: + config: Dict[str, Any] = json.load(file_pointer) + + # Store downloaded files in the same directory as the config + base_path = Path(config_path).parent + + steps: list[PipelineStep] = [] + for step_entry in config["steps"]: + module_path, class_name = step_entry["class"].rsplit(".", 1) + step_class = getattr(__import__(module_path, fromlist=[class_name]), class_name) + step_instance: PipelineStep = step_class(**step_entry.get("config", {})) + + if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): + if Path(source).is_dir(): + # Local path - read directly + state_path = str(base_path / step_entry["state_file"]) + else: + # Hugging Face Hub - 
download the state file + state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model") + + step_instance.load_state_dict(load_file(state_path)) + + steps.append(step_instance) + + return cls(steps, config.get("name", "RobotPipeline"), config.get("seed")) + + def __len__(self) -> int: + """Return the number of steps in the pipeline.""" + return len(self.steps) + + def __getitem__(self, idx: int | slice) -> PipelineStep | RobotPipeline: + """Indexing helper exposing underlying steps. + * ``int`` – returns the idx-th PipelineStep. + * ``slice`` – returns a new RobotPipeline with the sliced steps. + """ + if isinstance(idx, slice): + return RobotPipeline(self.steps[idx], self.name, self.seed) + return self.steps[idx] + + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]): + """Attach fn to be executed before every pipeline step.""" + self.before_step_hooks.append(fn) + + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]): + """Attach fn to be executed after every pipeline step.""" + self.after_step_hooks.append(fn) + + def register_reset_hook(self, fn: Callable[[], None]): + """Attach fn to be executed when reset is called.""" + self.reset_hooks.append(fn) + + def reset(self): + """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + for step in self.steps: + if hasattr(step, "reset"): + step.reset() # type: ignore[attr-defined] + for fn in self.reset_hooks: + fn() + + def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]: + """Profile the execution time of each step for performance optimization.""" + import time + + profile_results = {} + + for idx, pipeline_step in enumerate(self.steps): + step_name = f"step_{idx}_{pipeline_step.__class__.__name__}" + + # Warm up + for _ in range(5): + _ = pipeline_step(transition) + + # Time the step + start_time = time.perf_counter() + for _ in 
range(num_runs): + transition = pipeline_step(transition) + end_time = time.perf_counter() + + avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds + profile_results[step_name] = avg_time + + return profile_results \ No newline at end of file diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 7c5aec48a..e60e2eb43 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -68,7 +68,10 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env -from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation +from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types + +from lerobot.processor.pipeline import RobotPipeline, TransitionIndex +from lerobot.processor.observation_processor import ObservationProcessor from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters @@ -127,6 +130,9 @@ def rollout( observation, info = env.reset(seed=seeds) if render_callback is not None: render_callback(env) + + # Create observation processing pipeline + obs_pipeline = RobotPipeline([ObservationProcessor()]) all_observations = [] all_actions = [] @@ -147,7 +153,9 @@ def rollout( check_env_attributes_and_types(env) while not np.all(done): # Numpy array to tensor and changing dictionary keys to LeRobot policy format. - observation = preprocess_observation(observation) + transition = (observation, None, None, None, None, None, None) + processed_transition = obs_pipeline(transition) + observation = processed_transition[TransitionIndex.OBSERVATION] if return_observations: all_observations.append(deepcopy(observation)) @@ -195,7 +203,9 @@ def rollout( # Track the final observation. 
if return_observations: - observation = preprocess_observation(observation) + transition = (observation, None, None, None, None, None, None) + processed_transition = obs_pipeline(transition) + observation = processed_transition[TransitionIndex.OBSERVATION] all_observations.append(deepcopy(observation)) # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 140e9dfb9..b14f7b6a7 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -22,7 +22,9 @@ from gymnasium.utils.env_checker import check_env import lerobot from lerobot.envs.factory import make_env, make_env_config -from lerobot.envs.utils import preprocess_observation + +from lerobot.processor.pipeline import RobotPipeline, TransitionIndex +from lerobot.processor.observation_processor import ObservationProcessor from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] @@ -48,7 +50,12 @@ def test_factory(env_name): cfg = make_env_config(env_name) env = make_env(cfg, n_envs=1) obs, _ = env.reset() - obs = preprocess_observation(obs) + + # Process observation using pipeline + obs_pipeline = RobotPipeline([ObservationProcessor()]) + transition = (obs, None, None, None, None, None, None) + processed_transition = obs_pipeline(transition) + obs = processed_transition[TransitionIndex.OBSERVATION] # test image keys are float32 in range [0,1] for key in obs: diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ed37fedd6..c48256214 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -30,7 +30,8 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config -from lerobot.envs.utils import preprocess_observation 
+from lerobot.processor.pipeline import RobotPipeline, TransitionIndex +from lerobot.processor.observation_processor import ObservationProcessor from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.factory import ( @@ -185,7 +186,10 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): observation, _ = env.reset(seed=train_cfg.seed) # apply transform to normalize the observations - observation = preprocess_observation(observation) + obs_pipeline = RobotPipeline([ObservationProcessor()]) + transition = (observation, None, None, None, None, None, None) + processed_transition = obs_pipeline(transition) + observation = processed_transition[TransitionIndex.OBSERVATION] # send observation to device/gpu observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py new file mode 100644 index 000000000..d49d3874d --- /dev/null +++ b/tests/processor/test_observation_processor.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for the observation processors in ``lerobot.processor.observation_processor``.

Each test builds a raw 7-tuple transition ``(observation, action, reward, done,
truncated, info, complementary_data)``, runs it through a processor, and checks
the processed observation.

NOTE: test function names are unique per processor section. Defining the same
test name more than once in a module makes later definitions shadow earlier
ones, so pytest would silently collect only the last one.
"""

import numpy as np
import pytest
import torch

from lerobot.processor.observation_processor import (
    ImageProcessor,
    StateProcessor,
    ObservationProcessor,
)


def test_process_single_image():
    """Test processing a single image."""
    processor = ImageProcessor()

    # Create a mock image (H, W, C) format, uint8
    image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

    observation = {"pixels": image}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that the image was processed correctly
    assert "observation.image" in processed_obs
    processed_img = processed_obs["observation.image"]

    # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
    assert processed_img.shape == (1, 3, 64, 64)

    # Check dtype and range
    assert processed_img.dtype == torch.float32
    assert processed_img.min() >= 0.0
    assert processed_img.max() <= 1.0


def test_process_image_dict():
    """Test processing multiple images in a dictionary."""
    processor = ImageProcessor()

    # Create mock images
    image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)

    observation = {"pixels": {"camera1": image1, "camera2": image2}}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that both images were processed
    assert "observation.images.camera1" in processed_obs
    assert "observation.images.camera2" in processed_obs

    # Check shapes
    assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
    assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)


def test_process_batched_image():
    """Test processing already batched images."""
    processor = ImageProcessor()

    # Create a batched image (B, H, W, C)
    image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)

    observation = {"pixels": image}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that batch dimension is preserved
    assert processed_obs["observation.image"].shape == (2, 3, 64, 64)


def test_invalid_image_format():
    """Test error handling for invalid image formats."""
    processor = ImageProcessor()

    # Test wrong channel order (channels first)
    image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
    observation = {"pixels": image}
    transition = (observation, None, None, None, None, None, None)

    with pytest.raises(ValueError, match="Expected channel-last images"):
        processor(transition)


def test_invalid_image_dtype():
    """Test error handling for invalid image dtype."""
    processor = ImageProcessor()

    # Test wrong dtype
    image = np.random.rand(64, 64, 3).astype(np.float32)
    observation = {"pixels": image}
    transition = (observation, None, None, None, None, None, None)

    with pytest.raises(ValueError, match="Expected torch.uint8 images"):
        processor(transition)


def test_no_pixels_in_observation():
    """Test processor when no pixels are in observation."""
    processor = ImageProcessor()

    observation = {"other_data": np.array([1, 2, 3])}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Should preserve other data unchanged
    assert "other_data" in processed_obs
    np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))


def test_image_processor_none_observation():
    """Test ImageProcessor with None observation."""
    processor = ImageProcessor()

    transition = (None, None, None, None, None, None, None)
    result = processor(transition)

    assert result == transition


def test_image_processor_serialization_methods():
    """Test ImageProcessor serialization methods."""
    processor = ImageProcessor()

    # Test get_config
    config = processor.get_config()
    assert isinstance(config, dict)

    # Test state_dict
    state = processor.state_dict()
    assert isinstance(state, dict)

    # Test load_state_dict (should not raise)
    processor.load_state_dict(state)

    # Test reset (should not raise)
    processor.reset()


def test_process_environment_state():
    """Test processing environment_state."""
    processor = StateProcessor()

    env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    observation = {"environment_state": env_state}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that environment_state was renamed and processed
    assert "observation.environment_state" in processed_obs
    assert "environment_state" not in processed_obs

    processed_state = processed_obs["observation.environment_state"]
    assert processed_state.shape == (1, 3)  # Batch dimension added
    assert processed_state.dtype == torch.float32
    torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))


def test_process_agent_pos():
    """Test processing agent_pos."""
    processor = StateProcessor()

    agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
    observation = {"agent_pos": agent_pos}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that agent_pos was renamed and processed
    assert "observation.state" in processed_obs
    assert "agent_pos" not in processed_obs

    processed_state = processed_obs["observation.state"]
    assert processed_state.shape == (1, 3)  # Batch dimension added
    assert processed_state.dtype == torch.float32
    torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))


def test_process_batched_states():
    """Test processing already batched states."""
    processor = StateProcessor()

    env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)

    observation = {"environment_state": env_state, "agent_pos": agent_pos}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that batch dimensions are preserved
    assert processed_obs["observation.environment_state"].shape == (2, 2)
    assert processed_obs["observation.state"].shape == (2, 2)


def test_process_both_states():
    """Test processing both environment_state and agent_pos."""
    processor = StateProcessor()

    env_state = np.array([1.0, 2.0], dtype=np.float32)
    agent_pos = np.array([0.5, -0.5], dtype=np.float32)

    observation = {
        "environment_state": env_state,
        "agent_pos": agent_pos,
        "other_data": "keep_me",
    }
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that both states were processed
    assert "observation.environment_state" in processed_obs
    assert "observation.state" in processed_obs

    # Check that original keys were removed
    assert "environment_state" not in processed_obs
    assert "agent_pos" not in processed_obs

    # Check that other data was preserved
    assert processed_obs["other_data"] == "keep_me"


def test_no_states_in_observation():
    """Test processor when no states are in observation."""
    processor = StateProcessor()

    observation = {"other_data": np.array([1, 2, 3])}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Should preserve data unchanged
    assert processed_obs == observation


def test_state_processor_none_observation():
    """Test StateProcessor with None observation."""
    processor = StateProcessor()

    transition = (None, None, None, None, None, None, None)
    result = processor(transition)

    assert result == transition


def test_state_processor_serialization_methods():
    """Test StateProcessor serialization methods."""
    processor = StateProcessor()

    # Test get_config
    config = processor.get_config()
    assert isinstance(config, dict)

    # Test state_dict
    state = processor.state_dict()
    assert isinstance(state, dict)

    # Test load_state_dict (should not raise)
    processor.load_state_dict(state)

    # Test reset (should not raise)
    processor.reset()


def test_complete_observation_processing():
    """Test processing a complete observation with both images and states."""
    processor = ObservationProcessor()

    # Create mock data
    image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)

    observation = {
        "pixels": image,
        "environment_state": env_state,
        "agent_pos": agent_pos,
        "other_data": "preserve_me",
    }
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    # Check that image was processed
    assert "observation.image" in processed_obs
    assert processed_obs["observation.image"].shape == (1, 3, 32, 32)

    # Check that states were processed
    assert "observation.environment_state" in processed_obs
    assert "observation.state" in processed_obs

    # Check that original keys were removed
    assert "pixels" not in processed_obs
    assert "environment_state" not in processed_obs
    assert "agent_pos" not in processed_obs

    # Check that other data was preserved
    assert processed_obs["other_data"] == "preserve_me"


def test_image_only_processing():
    """Test processing observation with only images."""
    processor = ObservationProcessor()

    image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
    observation = {"pixels": image}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    assert "observation.image" in processed_obs
    assert len(processed_obs) == 1


def test_state_only_processing():
    """Test processing observation with only states."""
    processor = ObservationProcessor()

    agent_pos = np.array([1.0, 2.0], dtype=np.float32)
    observation = {"agent_pos": agent_pos}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    assert "observation.state" in processed_obs
    assert "agent_pos" not in processed_obs


def test_empty_observation():
    """Test processing empty observation."""
    processor = ObservationProcessor()

    observation = {}
    transition = (observation, None, None, None, None, None, None)

    result = processor(transition)
    processed_obs = result[0]

    assert processed_obs == {}


def test_observation_processor_none_observation():
    """Test ObservationProcessor with None observation."""
    processor = ObservationProcessor()

    transition = (None, None, None, None, None, None, None)
    result = processor(transition)

    assert result == transition


def test_observation_processor_serialization_methods():
    """Test ObservationProcessor serialization methods."""
    processor = ObservationProcessor()

    # Test get_config
    config = processor.get_config()
    assert isinstance(config, dict)
    assert "image_processor" in config
    assert "state_processor" in config

    # Test state_dict
    state = processor.state_dict()
    assert isinstance(state, dict)

    # Test load_state_dict (should not raise)
    processor.load_state_dict(state)

    # Test reset (should not raise)
    processor.reset()


def test_custom_sub_processors():
    """Test ObservationProcessor with custom sub-processors."""
    image_proc = ImageProcessor()
    state_proc = StateProcessor()
    processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)

    # Should use the provided processors
    assert processor.image_processor is image_proc
    assert processor.state_processor is state_proc


def test_equivalent_to_original_function():
    """Test that ObservationProcessor produces equivalent results to preprocess_observation."""
    # Import the original function for comparison
    from lerobot.envs.utils import preprocess_observation

    processor = ObservationProcessor()

    # Create test data similar to what the original function expects
    image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
    env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)

    observation = {
        "pixels": image,
        "environment_state": env_state,
        "agent_pos": agent_pos,
    }

    # Process with original function
    original_result = preprocess_observation(observation)

    # Process with new processor
    transition = (observation, None, None, None, None, None, None)
    processor_result = processor(transition)[0]

    # Compare results
    assert set(original_result.keys()) == set(processor_result.keys())

    for key in original_result:
        torch.testing.assert_close(original_result[key], processor_result[key])


def test_equivalent_with_image_dict():
    """Test equivalence with dictionary of images."""
    from lerobot.envs.utils import preprocess_observation

    processor = ObservationProcessor()

    # Create test data with multiple cameras
    image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
    agent_pos = np.array([1.0, 2.0], dtype=np.float32)

    observation = {
        "pixels": {"cam1": image1, "cam2": image2},
        "agent_pos": agent_pos,
    }

    # Process with original function
    original_result = preprocess_observation(observation)

    # Process with new processor
    transition = (observation, None, None, None, None, None, None)
    processor_result = processor(transition)[0]

    # Compare results
    assert set(original_result.keys()) == set(processor_result.keys())

    for key in original_result:
        torch.testing.assert_close(original_result[key], processor_result[key])
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ``RobotPipeline``: composition, hooks, reset, profiling, and
save/load round-trips of both JSON config and tensor state."""

import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pytest
import torch

from lerobot.processor.pipeline import EnvTransition, RobotPipeline


@dataclass
class MockStep:
    """Mock pipeline step for testing - demonstrates best practices.

    This example shows the proper separation:
    - JSON-serializable attributes (name, counter) go in get_config()
    - Only torch tensors go in state_dict()

    Note: The counter is part of the configuration, so it will be restored
    when the step is recreated from config during loading.
    """

    name: str = "mock_step"
    counter: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Add a counter to the complementary_data."""
        obs, action, reward, done, truncated, info, comp_data = transition

        # Copy before mutating so the caller's dict is never modified in place.
        if comp_data is None:
            comp_data = {}
        else:
            comp_data = dict(comp_data)  # Make a copy

        comp_data[f"{self.name}_counter"] = self.counter
        self.counter += 1

        return (obs, action, reward, done, truncated, info, comp_data)

    def get_config(self) -> dict[str, Any]:
        # Return all JSON-serializable attributes that should be persisted
        # These will be passed to __init__ when loading
        return {"name": self.name, "counter": self.counter}

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Only return torch tensors (empty in this case since we have no tensor state)
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        # No tensor state to load
        pass

    def reset(self) -> None:
        self.counter = 0


@dataclass
class MockStepWithoutOptionalMethods:
    """Mock step that only implements the required __call__ method."""

    multiplier: float = 2.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Multiply reward by multiplier."""
        obs, action, reward, done, truncated, info, comp_data = transition

        if reward is not None:
            reward = reward * self.multiplier

        return (obs, action, reward, done, truncated, info, comp_data)


@dataclass
class MockStepWithTensorState:
    """Mock step demonstrating mixed JSON attributes and tensor state.

    The tensor buffers are derived from the JSON-serializable fields, so they
    are (re)built in ``__post_init__`` instead of a hand-written ``__init__``
    (a custom ``__init__`` on a dataclass is discarded by the decorator and
    duplicates every field default).
    """

    name: str = "tensor_step"
    learning_rate: float = 0.01
    window_size: int = 10

    def __post_init__(self) -> None:
        # Tensor state - rebuilt from config on load, restored via state_dict.
        self.running_mean = torch.zeros(self.window_size)
        self.running_count = torch.tensor(0)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Update running statistics."""
        obs, action, reward, done, truncated, info, comp_data = transition

        if reward is not None:
            # Update running mean (circular buffer indexed modulo window_size)
            idx = self.running_count % self.window_size
            self.running_mean[idx] = reward
            self.running_count += 1

        return transition

    def get_config(self) -> dict[str, Any]:
        # Only JSON-serializable attributes
        return {
            "name": self.name,
            "learning_rate": self.learning_rate,
            "window_size": self.window_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Only tensor state
        return {
            "running_mean": self.running_mean,
            "running_count": self.running_count,
        }

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        self.running_mean = state["running_mean"]
        self.running_count = state["running_count"]

    def reset(self) -> None:
        self.running_mean.zero_()
        self.running_count.zero_()


def test_empty_pipeline():
    """Test pipeline with no steps."""
    pipeline = RobotPipeline()

    transition = (None, None, 0.0, False, False, {}, {})
    result = pipeline(transition)

    assert result == transition
    assert len(pipeline) == 0


def test_single_step_pipeline():
    """Test pipeline with a single step."""
    step = MockStep("test_step")
    pipeline = RobotPipeline([step])

    transition = (None, None, 0.0, False, False, {}, {})
    result = pipeline(transition)

    assert len(pipeline) == 1
    assert result[6]["test_step_counter"] == 0  # complementary_data

    # Call again to test counter increment
    result = pipeline(transition)
    assert result[6]["test_step_counter"] == 1


def test_multiple_steps_pipeline():
    """Test pipeline with multiple steps."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotPipeline([step1, step2])

    transition = (None, None, 0.0, False, False, {}, {})
    result = pipeline(transition)

    assert len(pipeline) == 2
    assert result[6]["step1_counter"] == 0
    assert result[6]["step2_counter"] == 0


def test_invalid_transition_format():
    """Test pipeline with invalid transition format."""
    pipeline = RobotPipeline([MockStep()])

    # Test with wrong number of elements
    with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
        pipeline((None, None, 0.0))  # Only 3 elements

    # Test with wrong type
    with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
        pipeline("not a tuple")


def test_step_through():
    """Test step_through method."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotPipeline([step1, step2])

    transition = (None, None, 0.0, False, False, {}, {})

    results = list(pipeline.step_through(transition))

    assert len(results) == 3  # Original + 2 steps
    assert results[0] == transition  # Original
    assert "step1_counter" in results[1][6]  # After step1
    assert "step2_counter" in results[2][6]  # After step2


def test_indexing():
    """Test pipeline indexing."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotPipeline([step1, step2])

    # Test integer indexing
    assert pipeline[0] is step1
    assert pipeline[1] is step2

    # Test slice indexing
    sub_pipeline = pipeline[0:1]
    assert isinstance(sub_pipeline, RobotPipeline)
    assert len(sub_pipeline) == 1
    assert sub_pipeline[0] is step1


def test_hooks():
    """Test before/after step hooks."""
    step = MockStep("test_step")
    pipeline = RobotPipeline([step])

    before_calls = []
    after_calls = []

    def before_hook(idx: int, transition: EnvTransition):
        before_calls.append(idx)
        return transition

    def after_hook(idx: int, transition: EnvTransition):
        after_calls.append(idx)
        return transition

    pipeline.register_before_step_hook(before_hook)
    pipeline.register_after_step_hook(after_hook)

    transition = (None, None, 0.0, False, False, {}, {})
    pipeline(transition)

    assert before_calls == [0]
    assert after_calls == [0]


def test_hook_modification():
    """Test that hooks can modify transitions."""
    step = MockStep("test_step")
    pipeline = RobotPipeline([step])

    def modify_reward_hook(idx: int, transition: EnvTransition):
        obs, action, reward, done, truncated, info, comp_data = transition
        return (obs, action, 42.0, done, truncated, info, comp_data)

    pipeline.register_before_step_hook(modify_reward_hook)

    transition = (None, None, 0.0, False, False, {}, {})
    result = pipeline(transition)

    assert result[2] == 42.0  # reward modified by hook


def test_reset():
    """Test pipeline reset functionality."""
    step = MockStep("test_step")
    pipeline = RobotPipeline([step])

    reset_called = []

    def reset_hook():
        reset_called.append(True)

    pipeline.register_reset_hook(reset_hook)

    # Make some calls to increment counter
    transition = (None, None, 0.0, False, False, {}, {})
    pipeline(transition)
    pipeline(transition)

    assert step.counter == 2

    # Reset should reset step and call hook
    pipeline.reset()

    assert step.counter == 0
    assert len(reset_called) == 1


def test_profile_steps():
    """Test step profiling functionality."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotPipeline([step1, step2])

    transition = (None, None, 0.0, False, False, {}, {})

    profile_results = pipeline.profile_steps(transition, num_runs=10)

    assert len(profile_results) == 2
    assert "step_0_MockStep" in profile_results
    assert "step_1_MockStep" in profile_results
    assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())


def test_save_and_load_pretrained():
    """Test saving and loading pipeline.

    This test demonstrates that JSON-serializable attributes (like counter)
    are saved in the config and restored when the step is recreated.
    """
    step1 = MockStep("step1")
    step2 = MockStep("step2")

    # Increment counters to have some state
    step1.counter = 5
    step2.counter = 10

    pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save pipeline
        pipeline.save_pretrained(tmp_dir)

        # Check files were created
        config_path = Path(tmp_dir) / "pipeline.json"
        assert config_path.exists()

        # Check config content
        with open(config_path) as f:
            config = json.load(f)

        assert config["name"] == "TestPipeline"
        assert config["seed"] == 42
        assert len(config["steps"]) == 2

        # Verify counters are saved in config, not in separate state files
        assert config["steps"][0]["config"]["counter"] == 5
        assert config["steps"][1]["config"]["counter"] == 10

        # Load pipeline
        loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)

        assert loaded_pipeline.name == "TestPipeline"
        assert loaded_pipeline.seed == 42
        assert len(loaded_pipeline) == 2

        # Check that counter was restored from config
        assert loaded_pipeline.steps[0].counter == 5
        assert loaded_pipeline.steps[1].counter == 10


def test_step_without_optional_methods():
    """Test pipeline with steps that don't implement optional methods."""
    step = MockStepWithoutOptionalMethods(multiplier=3.0)
    pipeline = RobotPipeline([step])

    transition = (None, None, 2.0, False, False, {}, {})
    result = pipeline(transition)

    assert result[2] == 6.0  # 2.0 * 3.0

    # Reset should work even if step doesn't implement reset
    pipeline.reset()

    # Save/load should work even without optional methods
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
        assert len(loaded_pipeline) == 1


def test_mixed_json_and_tensor_state():
    """Test step with both JSON attributes and tensor state."""
    step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
    pipeline = RobotPipeline([step])

    # Process some transitions with rewards
    for i in range(10):
        transition = (None, None, float(i), False, False, {}, {})
        pipeline(transition)

    # Check state
    assert step.running_count.item() == 10
    assert step.learning_rate == 0.05

    # Save and load
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Check that both config and state files were created
        config_path = Path(tmp_dir) / "pipeline.json"
        state_path = Path(tmp_dir) / "step_0.safetensors"
        assert config_path.exists()
        assert state_path.exists()

        # Load and verify
        loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
        loaded_step = loaded_pipeline.steps[0]

        # Check JSON attributes were restored
        assert loaded_step.name == "stats"
        assert loaded_step.learning_rate == 0.05
        assert loaded_step.window_size == 5

        # Check tensor state was restored
        assert loaded_step.running_count.item() == 10
        assert torch.allclose(loaded_step.running_mean, step.running_mean)