From 769f5316036451f72d4e60a415d983c8198f4cd3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 15:31:15 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lerobot/envs/utils.py | 17 +- src/lerobot/processor/__init__.py | 6 +- .../processor/observation_processor.py | 91 +-- src/lerobot/processor/pipeline.py | 91 +-- src/lerobot/scripts/eval.py | 7 +- tests/envs/test_envs.py | 5 +- tests/policies/test_policies.py | 4 +- tests/processor/test_observation_processor.py | 554 +++++++++--------- tests/processor/test_pipeline.py | 185 +++--- 9 files changed, 485 insertions(+), 475 deletions(-) diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index c90113b36..428d74c08 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -16,10 +16,8 @@ import warnings from typing import Any -import einops import gymnasium as gym import numpy as np -import torch from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature @@ -29,32 +27,29 @@ from lerobot.utils.utils import get_channel_first_image_shape def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: """Convert environment observation to LeRobot format observation. - + This function uses the new pipeline system internally but maintains backward compatibility with the original interface. - + Args: observation: Dictionary of observation batches from a Gym vector environment. Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ - from lerobot.processor.pipeline import RobotPipeline, TransitionIndex from lerobot.processor.observation_processor import ObservationProcessor - + from lerobot.processor.pipeline import RobotPipeline, TransitionIndex + # Create pipeline with observation processor pipeline = RobotPipeline([ObservationProcessor()]) - + # Create transition tuple and process transition = (observations, None, None, None, None, None, None) processed_transition = pipeline(transition) - + # Return processed observations return processed_transition[TransitionIndex.OBSERVATION] - - - def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # (need to externalize normalization from policies) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 76fb86b8f..0fe1cd3a5 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -13,16 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .pipeline import RobotPipeline, PipelineStep, EnvTransition from .observation_processor import ( ImageProcessor, - StateProcessor, ObservationProcessor, + StateProcessor, ) +from .pipeline import EnvTransition, PipelineStep, RobotPipeline __all__ = [ "RobotPipeline", - "PipelineStep", + "PipelineStep", "EnvTransition", "ImageProcessor", "StateProcessor", diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 76189a47f..c2a1c88a7 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -15,35 +15,36 @@ # limitations under the License. from __future__ import annotations -from typing import Any from dataclasses import dataclass, field +from typing import Any + +import einops import numpy as np import torch -import einops from torch import Tensor -from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex +from lerobot.processor.pipeline import EnvTransition, TransitionIndex @dataclass class ImageProcessor: """Process image observations from environment format to policy format. - + Converts images from: - Channel-last (H, W, C) to channel-first (C, H, W) - uint8 [0, 255] to float32 [0, 1] - Adds batch dimension if needed - Handles both single images and dictionaries of images """ - + def __call__(self, transition: EnvTransition) -> EnvTransition: observation = transition[TransitionIndex.OBSERVATION] - + if observation is None: return transition - + processed_obs = {} - + # Handle pixels key if "pixels" in observation: if isinstance(observation["pixels"], dict): @@ -54,12 +55,12 @@ class ImageProcessor: for imgkey, img in imgs.items(): processed_img = self._process_single_image(img) processed_obs[imgkey] = processed_img - + # Copy other observations unchanged for key, value in observation.items(): if key != "pixels": processed_obs[key] = value - + # Return new transition with processed observation return ( processed_obs, @@ -70,44 +71,44 @@ class ImageProcessor: transition[TransitionIndex.INFO], transition[TransitionIndex.COMPLEMENTARY_DATA], ) - + def _process_single_image(self, img: np.ndarray) -> Tensor: """Process a single image array.""" # Convert to tensor img_tensor = torch.from_numpy(img) - + # Add batch dimension if needed if img_tensor.ndim == 3: img_tensor = img_tensor.unsqueeze(0) - + # Validate image format _, h, w, c = img_tensor.shape if not (c < h and c < w): raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}") - + if img_tensor.dtype != torch.uint8: raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}") - + # Convert to channel-first format img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() - + # Convert to float32 and normalize to [0, 1] img_tensor = img_tensor.type(torch.float32) / 255.0 - + return img_tensor - + def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {} - + def state_dict(self) -> dict[str, torch.Tensor]: """Return state dictionary (empty for this processor).""" return {} - + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Load state dictionary (no-op for this processor).""" pass - + def reset(self) -> None: """Reset processor state (no-op for this processor).""" pass @@ -116,22 +117,22 @@ class ImageProcessor: @dataclass class StateProcessor: """Process state observations from environment format to policy format. - + Handles: - environment_state -> observation.environment_state - agent_pos -> observation.state - Converts numpy arrays to tensors - Adds batch dimension if needed """ - + def __call__(self, transition: EnvTransition) -> EnvTransition: observation = transition[TransitionIndex.OBSERVATION] - + if observation is None: return transition - + processed_obs = dict(observation) # Copy existing observations - + # Process environment_state if "environment_state" in observation: env_state = torch.from_numpy(observation["environment_state"]).float() @@ -140,16 +141,16 @@ class StateProcessor: processed_obs["observation.environment_state"] = env_state # Remove original key del processed_obs["environment_state"] - + # Process agent_pos if "agent_pos" in observation: agent_pos = torch.from_numpy(observation["agent_pos"]).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) processed_obs["observation.state"] = agent_pos - # Remove original key + # Remove original key del processed_obs["agent_pos"] - + # Return new transition with processed observation return ( processed_obs, @@ -160,19 +161,19 @@ class StateProcessor: transition[TransitionIndex.INFO], transition[TransitionIndex.COMPLEMENTARY_DATA], ) - + def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {} - + def state_dict(self) -> dict[str, torch.Tensor]: """Return state dictionary (empty for this processor).""" return {} - + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Load state dictionary (no-op for this processor).""" pass - + def reset(self) -> None: """Reset processor state (no-op for this processor).""" pass @@ -181,43 +182,47 @@ class StateProcessor: @dataclass class ObservationProcessor: """Complete observation processor that combines image and state processing. - + This processor replicates the functionality of the original preprocess_observation function but in a modular, composable way that fits into the pipeline architecture. """ - + image_processor: ImageProcessor = field(default_factory=ImageProcessor) state_processor: StateProcessor = field(default_factory=StateProcessor) - + def __call__(self, transition: EnvTransition) -> EnvTransition: # First process images transition = self.image_processor(transition) # Then process state transition = self.state_processor(transition) return transition - + def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return { "image_processor": self.image_processor.get_config(), "state_processor": self.state_processor.get_config(), } - + def state_dict(self) -> dict[str, torch.Tensor]: """Return state dictionary.""" state = {} state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()}) state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()}) return state - + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Load state dictionary.""" - image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")} - state_state = {k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")} - + image_state = { + k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.") + } + state_state = { + k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.") + } + self.image_processor.load_state_dict(image_state) self.state_processor.load_state_dict(state_state) - + def reset(self) -> None: """Reset processor state.""" self.image_processor.reset() diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 2451ed99a..8003ce214 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -14,19 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -import os, json -from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union + +import json +import os from dataclasses import dataclass, field -from pathlib import Path from enum import IntEnum -import numpy as np +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple + import torch -from huggingface_hub import hf_hub_download, ModelHubMixin -from safetensors.torch import save_file, load_file +from huggingface_hub import ModelHubMixin, hf_hub_download +from safetensors.torch import load_file, save_file class TransitionIndex(IntEnum): """Explicit indices for EnvTransition tuple components.""" + OBSERVATION = 0 ACTION = 1 REWARD = 2 @@ -38,29 +41,28 @@ class TransitionIndex(IntEnum): # (observation, action, reward, done, truncated, info, complementary_data) EnvTransition = Tuple[ - Any| None, # observation - Any| None, # action - float| None, # reward - bool| None, # done - bool| None, # truncated - Dict[str, Any]| None, # info - Dict[str, Any]| None, # complementary_data + Any | None, # observation + Any | None, # action + float | None, # reward + bool | None, # done + bool | None, # truncated + Dict[str, Any] | None, # info + Dict[str, Any] | None, # complementary_data ] - class PipelineStep(Protocol): """Structural typing interface for a single pipeline step. - + A step is any callable accepting a full `EnvTransition` tuple and returning a (possibly modified) tuple of the same structure. Implementers are encouraged—but not required—to expose the optional helper methods listed below. When present, these hooks let `RobotPipeline` automatically serialise the step's configuration and learnable state using a safe-to-share JSON + SafeTensors format. - + Optional helper protocol: - * ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable + * ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable configuration and state. YOU decide what to save here. This is where all non-tensor state goes (e.g., name, counter, threshold, window_size). The config dict will be passed to your class constructor when loading. @@ -70,7 +72,7 @@ class PipelineStep(Protocol): * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict containing torch tensors only. * ``reset()`` – Clear internal buffers at episode boundaries. - + Example separation: - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} @@ -78,11 +80,11 @@ class PipelineStep(Protocol): def __call__(self, transition: EnvTransition) -> EnvTransition: ... - def get_config(self) -> Dict[str, Any]: ... + def get_config(self) -> dict[str, Any]: ... - def state_dict(self) -> Dict[str, torch.Tensor]: ... + def state_dict(self) -> dict[str, torch.Tensor]: ... - def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ... + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... def reset(self) -> None: ... @@ -120,24 +122,25 @@ class RobotPipeline(ModelHubMixin): pipe.push_to_hub("my-org/cartpole_pipe") loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe") """ + steps: Sequence[PipelineStep] = field(default_factory=list) name: str = "RobotPipeline" - seed: Optional[int] = None + seed: int | None = None # Pipeline-level hooks # A hook can optionally return a modified transition. If it returns # ``None`` the current value is left untouched. - before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field( + before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field( default_factory=list, repr=False ) - after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field( + after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field( default_factory=list, repr=False ) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) def __call__(self, transition: EnvTransition) -> EnvTransition: """Run *transition* through every step, firing hooks on the way.""" - + # Basic validation with helpful error message if not isinstance(transition, tuple) or len(transition) != 7: raise ValueError( @@ -168,23 +171,23 @@ class RobotPipeline(ModelHubMixin): yield transition _CFG_NAME = "pipeline.json" - + def _save_pretrained(self, destination_path: str, **kwargs): """Internal save method for ModelHubMixin compatibility.""" self.save_pretrained(destination_path) - + def save_pretrained(self, destination_path: str, **kwargs): """Serialize the pipeline definition and parameters to *destination_path*.""" os.makedirs(destination_path, exist_ok=True) - config: Dict[str, Any] = { + config: dict[str, Any] = { "name": self.name, "seed": self.seed, "steps": [], } for step_index, pipeline_step in enumerate(self.steps): - step_entry: Dict[str, Any] = { + step_entry: dict[str, Any] = { "class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}", } @@ -204,20 +207,20 @@ class RobotPipeline(ModelHubMixin): json.dump(config, file_pointer, indent=2) @classmethod - def from_pretrained(cls, source: str) -> "RobotPipeline": + def from_pretrained(cls, source: str) -> RobotPipeline: """Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier).""" if Path(source).is_dir(): # Local path - use it directly base_path = Path(source) with open(base_path / cls._CFG_NAME) as file_pointer: - config: Dict[str, Any] = json.load(file_pointer) + config: dict[str, Any] = json.load(file_pointer) else: # Hugging Face Hub - download all required files # First download the config file config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model") with open(config_path) as file_pointer: - config: Dict[str, Any] = json.load(file_pointer) - + config: dict[str, Any] = json.load(file_pointer) + # Store downloaded files in the same directory as the config base_path = Path(config_path).parent @@ -234,7 +237,7 @@ class RobotPipeline(ModelHubMixin): else: # Hugging Face Hub - download the state file state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model") - + step_instance.load_state_dict(load_file(state_path)) steps.append(step_instance) @@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin): return RobotPipeline(self.steps[idx], self.name, self.seed) return self.steps[idx] - def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]): + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): """Attach fn to be executed before every pipeline step.""" self.before_step_hooks.append(fn) - def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]): + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): """Attach fn to be executed after every pipeline step.""" self.after_step_hooks.append(fn) @@ -274,26 +277,26 @@ class RobotPipeline(ModelHubMixin): for fn in self.reset_hooks: fn() - def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]: + def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]: """Profile the execution time of each step for performance optimization.""" import time - + profile_results = {} - + for idx, pipeline_step in enumerate(self.steps): step_name = f"step_{idx}_{pipeline_step.__class__.__name__}" - + # Warm up for _ in range(5): _ = pipeline_step(transition) - + # Time the step start_time = time.perf_counter() for _ in range(num_runs): transition = pipeline_step(transition) end_time = time.perf_counter() - + avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds profile_results[step_name] = avg_time - - return profile_results \ No newline at end of file + + return profile_results diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index e60e2eb43..5b604707c 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -69,12 +69,11 @@ from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types - -from lerobot.processor.pipeline import RobotPipeline, TransitionIndex -from lerobot.processor.observation_processor import ObservationProcessor from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters +from lerobot.processor.observation_processor import ObservationProcessor +from lerobot.processor.pipeline import RobotPipeline, TransitionIndex from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -130,7 +129,7 @@ def rollout( observation, info = env.reset(seed=seeds) if render_callback is not None: render_callback(env) - + # Create observation processing pipeline obs_pipeline = RobotPipeline([ObservationProcessor()]) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index b14f7b6a7..f2a1a14e1 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -22,9 +22,8 @@ from gymnasium.utils.env_checker import check_env import lerobot from lerobot.envs.factory import make_env, make_env_config - -from lerobot.processor.pipeline import RobotPipeline, TransitionIndex from lerobot.processor.observation_processor import ObservationProcessor +from lerobot.processor.pipeline import RobotPipeline, TransitionIndex from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] @@ -50,7 +49,7 @@ def test_factory(env_name): cfg = make_env_config(env_name) env = make_env(cfg, n_envs=1) obs, _ = env.reset() - + # Process observation using pipeline obs_pipeline = RobotPipeline([ObservationProcessor()]) transition = (obs, None, None, None, None, None, None) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index c48256214..9993ef6fa 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -30,8 +30,6 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config -from lerobot.processor.pipeline import RobotPipeline, TransitionIndex -from lerobot.processor.observation_processor import ObservationProcessor from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.factory import ( @@ -41,6 +39,8 @@ from lerobot.policies.factory import ( ) from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor.observation_processor import ObservationProcessor +from lerobot.processor.pipeline import RobotPipeline, TransitionIndex from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index d49d3874d..7b3d17109 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -20,447 +20,447 @@ import torch from lerobot.processor.observation_processor import ( ImageProcessor, - StateProcessor, ObservationProcessor, + StateProcessor, ) -from lerobot.processor.pipeline import EnvTransition def test_process_single_image(): """Test processing a single image.""" processor = ImageProcessor() - + # Create a mock image (H, W, C) format, uint8 image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - + observation = {"pixels": image} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Check that the image was processed correctly assert "observation.image" in processed_obs processed_img = processed_obs["observation.image"] - + # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width assert processed_img.shape == (1, 3, 64, 64) - + # Check dtype and range assert processed_img.dtype == torch.float32 assert processed_img.min() >= 0.0 assert processed_img.max() <= 1.0 + def test_process_image_dict(): """Test processing multiple images in a dictionary.""" processor = ImageProcessor() - + # Create mock images image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) - - observation = { - "pixels": { - "camera1": image1, - "camera2": image2 - } - } + + observation = {"pixels": {"camera1": image1, "camera2": image2}} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Check that both images were processed assert "observation.images.camera1" in processed_obs assert "observation.images.camera2" in processed_obs - + # Check shapes assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + def test_process_batched_image(): """Test processing already batched images.""" processor = ImageProcessor() - + # Create a batched image (B, H, W, C) image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) - + observation = {"pixels": image} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Check that batch dimension is preserved assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + def test_invalid_image_format(): """Test error handling for invalid image formats.""" processor = ImageProcessor() - + # Test wrong channel order (channels first) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) observation = {"pixels": image} transition = (observation, None, None, None, None, None, None) - + with pytest.raises(ValueError, match="Expected channel-last images"): processor(transition) + def test_invalid_image_dtype(): """Test error handling for invalid image dtype.""" processor = ImageProcessor() - + # Test wrong dtype image = np.random.rand(64, 64, 3).astype(np.float32) observation = {"pixels": image} transition = (observation, None, None, None, None, None, None) - + with pytest.raises(ValueError, match="Expected torch.uint8 images"): processor(transition) + def test_no_pixels_in_observation(): """Test processor when no pixels are in observation.""" processor = ImageProcessor() - + observation = {"other_data": np.array([1, 2, 3])} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Should preserve other data unchanged assert "other_data" in processed_obs np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3])) + def test_none_observation(): """Test processor with None observation.""" processor = ImageProcessor() - + transition = (None, None, None, None, None, None, None) result = processor(transition) - + assert result == transition + def test_serialization_methods(): """Test serialization methods.""" processor = ImageProcessor() - + # Test get_config config = processor.get_config() assert isinstance(config, dict) - + # Test state_dict state = processor.state_dict() assert isinstance(state, dict) - + # Test load_state_dict (should not raise) processor.load_state_dict(state) - + # Test reset (should not raise) processor.reset() def test_process_environment_state(): - """Test processing environment_state.""" - processor = StateProcessor() - - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - observation = {"environment_state": env_state} - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - # Check that environment_state was renamed and processed - assert "observation.environment_state" in processed_obs - assert "environment_state" not in processed_obs - - processed_state = processed_obs["observation.environment_state"] - assert processed_state.shape == (1, 3) # Batch dimension added - assert processed_state.dtype == torch.float32 - torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) - + """Test processing environment_state.""" + processor = StateProcessor() + + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + observation = {"environment_state": env_state} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + # Check that environment_state was renamed and processed + assert "observation.environment_state" in processed_obs + assert "environment_state" not in processed_obs + + processed_state = processed_obs["observation.environment_state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) + + def test_process_agent_pos(): - """Test processing agent_pos.""" - processor = StateProcessor() - - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - observation = {"agent_pos": agent_pos} - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - # Check that agent_pos was renamed and processed - assert "observation.state" in processed_obs - assert "agent_pos" not in processed_obs - - processed_state = processed_obs["observation.state"] - assert processed_state.shape == (1, 3) # Batch dimension added - assert processed_state.dtype == torch.float32 - torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) - + """Test processing agent_pos.""" + processor = StateProcessor() + + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + # Check that agent_pos was renamed and processed + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + processed_state = processed_obs["observation.state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) + + def test_process_batched_states(): - """Test processing already batched states.""" - processor = StateProcessor() - - env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) - - observation = { - "environment_state": env_state, - "agent_pos": agent_pos - } - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - # Check that batch dimensions are preserved - assert processed_obs["observation.environment_state"].shape == (2, 2) - assert processed_obs["observation.state"].shape == (2, 2) - + """Test processing already batched states.""" + processor = StateProcessor() + + env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + # Check that batch dimensions are preserved + assert processed_obs["observation.environment_state"].shape == (2, 2) + assert processed_obs["observation.state"].shape == (2, 2) + + def test_process_both_states(): """Test processing both environment_state and agent_pos.""" processor = StateProcessor() - + env_state = np.array([1.0, 2.0], dtype=np.float32) agent_pos = np.array([0.5, -0.5], dtype=np.float32) - - observation = { - "environment_state": env_state, - "agent_pos": agent_pos, - "other_data": "keep_me" - } + + observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Check that both states were processed assert "observation.environment_state" in processed_obs assert "observation.state" in processed_obs - + # Check that original keys were removed assert "environment_state" not in processed_obs assert "agent_pos" not in processed_obs - + # Check that other data was preserved assert processed_obs["other_data"] == "keep_me" + def test_no_states_in_observation(): """Test processor when no states are in observation.""" processor = StateProcessor() - + observation = {"other_data": np.array([1, 2, 3])} transition = (observation, None, None, None, None, None, None) - + result = processor(transition) processed_obs = result[0] - + # Should preserve data unchanged assert processed_obs == observation + def test_none_observation(): """Test processor with None observation.""" processor = StateProcessor() - + transition = (None, None, None, None, None, None, None) result = processor(transition) - + assert result == transition + def test_serialization_methods(): """Test serialization methods.""" processor = StateProcessor() - + # Test get_config config = processor.get_config() assert isinstance(config, dict) - + # Test state_dict state = processor.state_dict() assert isinstance(state, dict) - + # Test load_state_dict (should not raise) processor.load_state_dict(state) - + # Test reset (should not raise) processor.reset() def test_complete_observation_processing(): - """Test processing a complete observation with both images and states.""" - processor = ObservationProcessor() - - # Create mock data - image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - - observation = { - "pixels": image, - "environment_state": env_state, - "agent_pos": agent_pos, - "other_data": "preserve_me" - } - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - # Check that image was processed - assert "observation.image" in processed_obs - assert processed_obs["observation.image"].shape == (1, 3, 32, 32) - - # Check that states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs - - # Check that original keys were removed - assert "pixels" not in processed_obs - assert "environment_state" not in processed_obs - assert "agent_pos" not in processed_obs - - # Check that other data was preserved - assert processed_obs["other_data"] == "preserve_me" - + """Test processing a complete observation with both images and states.""" + processor = ObservationProcessor() + + # Create mock data + image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = { + "pixels": image, + "environment_state": env_state, + "agent_pos": agent_pos, + "other_data": "preserve_me", + } + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + # Check that image was processed + assert "observation.image" in processed_obs + assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + + # Check that states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "pixels" not in processed_obs + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "preserve_me" + + def test_image_only_processing(): - """Test processing observation with only images.""" - processor = ObservationProcessor() - - image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - assert "observation.image" in processed_obs - assert len(processed_obs) == 1 - + """Test processing observation with only images.""" + processor = ObservationProcessor() + + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + observation = {"pixels": image} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + assert "observation.image" in processed_obs + assert len(processed_obs) == 1 + + def test_state_only_processing(): - """Test processing observation with only states.""" - processor = ObservationProcessor() - - agent_pos = np.array([1.0, 2.0], dtype=np.float32) - observation = {"agent_pos": agent_pos} - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - assert "observation.state" in processed_obs - assert "agent_pos" not in processed_obs - + """Test processing observation with only states.""" + processor = ObservationProcessor() + + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + def test_empty_observation(): - """Test processing empty observation.""" - processor = ObservationProcessor() - - observation = {} - transition = (observation, None, None, None, None, None, None) - - result = processor(transition) - processed_obs = result[0] - - assert processed_obs == {} - + """Test processing empty observation.""" + processor = ObservationProcessor() + + observation = {} + transition = (observation, None, None, None, None, None, None) + + result = processor(transition) + processed_obs = result[0] + + assert processed_obs == {} + + def test_none_observation(): - """Test processing None observation.""" - processor = ObservationProcessor() - - transition = (None, None, None, None, None, None, None) - result = processor(transition) - - assert result == transition - + """Test processing None observation.""" + processor = ObservationProcessor() + + transition = (None, None, None, None, None, None, None) + result = processor(transition) + + assert result == transition + + def test_serialization_methods(): - """Test serialization methods.""" - processor = ObservationProcessor() - - # Test get_config - config = processor.get_config() - assert isinstance(config, dict) - assert "image_processor" in config - assert "state_processor" in config - - # Test state_dict - state = processor.state_dict() - assert isinstance(state, dict) - - # Test load_state_dict (should not raise) - processor.load_state_dict(state) - - # Test reset (should not raise) - processor.reset() - + """Test serialization methods.""" + processor = ObservationProcessor() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + assert "image_processor" in config + assert "state_processor" in config + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + + # Test load_state_dict (should not raise) + processor.load_state_dict(state) + + # Test reset (should not raise) + processor.reset() + + def test_custom_sub_processors(): - """Test ObservationProcessor with custom sub-processors.""" - image_proc = ImageProcessor() - state_proc = StateProcessor() - processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc) - - # Should use the provided processors - assert processor.image_processor is image_proc - assert processor.state_processor is state_proc + """Test ObservationProcessor with custom sub-processors.""" + image_proc = ImageProcessor() + state_proc = StateProcessor() + processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc) + + # Should use the provided processors + assert processor.image_processor is image_proc + assert processor.state_processor is state_proc def test_equivalent_to_original_function(): - """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" - # Import the original function for comparison - from lerobot.envs.utils import preprocess_observation - - processor = ObservationProcessor() - - # Create test data similar to what the original function expects - image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - - observation = { - "pixels": image, - "environment_state": env_state, - "agent_pos": agent_pos - } - - # Process with original function - original_result = preprocess_observation(observation) - - # Process with new processor - transition = (observation, None, None, None, None, None, None) - processor_result = processor(transition)[0] - - # Compare results - assert set(original_result.keys()) == set(processor_result.keys()) - - for key in original_result: - torch.testing.assert_close(original_result[key], processor_result[key]) - + """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" + # Import the original function for comparison + from lerobot.envs.utils import preprocess_observation + + processor = ObservationProcessor() + + # Create test data similar to what the original function expects + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = (observation, None, None, None, None, None, None) + processor_result = processor(transition)[0] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + def test_equivalent_with_image_dict(): - """Test equivalence with dictionary of images.""" - from lerobot.envs.utils import preprocess_observation - - processor = ObservationProcessor() - - # Create test data with multiple cameras - image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) - image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) - agent_pos = np.array([1.0, 2.0], dtype=np.float32) - - observation = { - "pixels": {"cam1": image1, "cam2": image2}, - "agent_pos": agent_pos - } - - # Process with original function - original_result = preprocess_observation(observation) - - # Process with new processor - transition = (observation, None, None, None, None, None, None) - processor_result = processor(transition)[0] - - # Compare results - assert set(original_result.keys()) == set(processor_result.keys()) - - for key in original_result: - torch.testing.assert_close(original_result[key], processor_result[key]) \ No newline at end of file + """Test equivalence with dictionary of images.""" + from lerobot.envs.utils import preprocess_observation + + processor = ObservationProcessor() + + # Create test data with multiple cameras + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + + observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = (observation, None, None, None, None, None, None) + processor_result = processor(transition)[0] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 1f19b355a..936893847 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -16,87 +16,86 @@ import json import tempfile +from dataclasses import dataclass from pathlib import Path from typing import Any -from dataclasses import dataclass -import numpy as np import pytest import torch -from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep +from lerobot.processor.pipeline import EnvTransition, RobotPipeline @dataclass class MockStep: """Mock pipeline step for testing - demonstrates best practices. - + This example shows the proper separation: - JSON-serializable attributes (name, counter) go in get_config() - Only torch tensors go in state_dict() - + Note: The counter is part of the configuration, so it will be restored when the step is recreated from config during loading. """ - + name: str = "mock_step" counter: int = 0 - + def __call__(self, transition: EnvTransition) -> EnvTransition: """Add a counter to the complementary_data.""" obs, action, reward, done, truncated, info, comp_data = transition - + if comp_data is None: comp_data = {} else: comp_data = dict(comp_data) # Make a copy - + comp_data[f"{self.name}_counter"] = self.counter self.counter += 1 - + return (obs, action, reward, done, truncated, info, comp_data) - + def get_config(self) -> dict[str, Any]: # Return all JSON-serializable attributes that should be persisted # These will be passed to __init__ when loading return {"name": self.name, "counter": self.counter} - + def state_dict(self) -> dict[str, torch.Tensor]: # Only return torch tensors (empty in this case since we have no tensor state) return {} - + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: # No tensor state to load pass - + def reset(self) -> None: self.counter = 0 -@dataclass +@dataclass class MockStepWithoutOptionalMethods: """Mock step that only implements the required __call__ method.""" - + multiplier: float = 2.0 - + def __call__(self, transition: EnvTransition) -> EnvTransition: """Multiply reward by multiplier.""" obs, action, reward, done, truncated, info, comp_data = transition - + if reward is not None: reward = reward * self.multiplier - + return (obs, action, reward, done, truncated, info, comp_data) @dataclass class MockStepWithTensorState: """Mock step demonstrating mixed JSON attributes and tensor state.""" - + name: str = "tensor_step" learning_rate: float = 0.01 window_size: int = 10 - + def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10): self.name = name self.learning_rate = learning_rate @@ -104,19 +103,19 @@ class MockStepWithTensorState: # Tensor state self.running_mean = torch.zeros(window_size) self.running_count = torch.tensor(0) - + def __call__(self, transition: EnvTransition) -> EnvTransition: """Update running statistics.""" obs, action, reward, done, truncated, info, comp_data = transition - + if reward is not None: # Update running mean idx = self.running_count % self.window_size self.running_mean[idx] = reward self.running_count += 1 - + return transition - + def get_config(self) -> dict[str, Any]: # Only JSON-serializable attributes return { @@ -124,18 +123,18 @@ class MockStepWithTensorState: "learning_rate": self.learning_rate, "window_size": self.window_size, } - + def state_dict(self) -> dict[str, torch.Tensor]: # Only tensor state return { "running_mean": self.running_mean, "running_count": self.running_count, } - + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: self.running_mean = state["running_mean"] self.running_count = state["running_count"] - + def reset(self) -> None: self.running_mean.zero_() self.running_count.zero_() @@ -144,265 +143,275 @@ class MockStepWithTensorState: def test_empty_pipeline(): """Test pipeline with no steps.""" pipeline = RobotPipeline() - + transition = (None, None, 0.0, False, False, {}, {}) result = pipeline(transition) - + assert result == transition assert len(pipeline) == 0 + def test_single_step_pipeline(): """Test pipeline with a single step.""" step = MockStep("test_step") pipeline = RobotPipeline([step]) - + transition = (None, None, 0.0, False, False, {}, {}) result = pipeline(transition) - + assert len(pipeline) == 1 assert result[6]["test_step_counter"] == 0 # complementary_data - + # Call again to test counter increment result = pipeline(transition) assert result[6]["test_step_counter"] == 1 + def test_multiple_steps_pipeline(): """Test pipeline with multiple steps.""" step1 = MockStep("step1") - step2 = MockStep("step2") + step2 = MockStep("step2") pipeline = RobotPipeline([step1, step2]) - + transition = (None, None, 0.0, False, False, {}, {}) result = pipeline(transition) - + assert len(pipeline) == 2 assert result[6]["step1_counter"] == 0 assert result[6]["step2_counter"] == 0 + def test_invalid_transition_format(): """Test pipeline with invalid transition format.""" pipeline = RobotPipeline([MockStep()]) - + # Test with wrong number of elements with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"): pipeline((None, None, 0.0)) # Only 3 elements - + # Test with wrong type with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"): pipeline("not a tuple") + def test_step_through(): """Test step_through method.""" step1 = MockStep("step1") step2 = MockStep("step2") pipeline = RobotPipeline([step1, step2]) - + transition = (None, None, 0.0, False, False, {}, {}) - + results = list(pipeline.step_through(transition)) - + assert len(results) == 3 # Original + 2 steps assert results[0] == transition # Original assert "step1_counter" in results[1][6] # After step1 assert "step2_counter" in results[2][6] # After step2 + def test_indexing(): """Test pipeline indexing.""" step1 = MockStep("step1") step2 = MockStep("step2") pipeline = RobotPipeline([step1, step2]) - + # Test integer indexing assert pipeline[0] is step1 assert pipeline[1] is step2 - + # Test slice indexing sub_pipeline = pipeline[0:1] assert isinstance(sub_pipeline, RobotPipeline) assert len(sub_pipeline) == 1 assert sub_pipeline[0] is step1 + def test_hooks(): """Test before/after step hooks.""" step = MockStep("test_step") pipeline = RobotPipeline([step]) - + before_calls = [] after_calls = [] - + def before_hook(idx: int, transition: EnvTransition): before_calls.append(idx) return transition - + def after_hook(idx: int, transition: EnvTransition): after_calls.append(idx) return transition - + pipeline.register_before_step_hook(before_hook) pipeline.register_after_step_hook(after_hook) - + transition = (None, None, 0.0, False, False, {}, {}) pipeline(transition) - + assert before_calls == [0] assert after_calls == [0] + def test_hook_modification(): """Test that hooks can modify transitions.""" step = MockStep("test_step") pipeline = RobotPipeline([step]) - + def modify_reward_hook(idx: int, transition: EnvTransition): obs, action, reward, done, truncated, info, comp_data = transition return (obs, action, 42.0, done, truncated, info, comp_data) - + pipeline.register_before_step_hook(modify_reward_hook) - + transition = (None, None, 0.0, False, False, {}, {}) result = pipeline(transition) - + assert result[2] == 42.0 # reward modified by hook + def test_reset(): """Test pipeline reset functionality.""" step = MockStep("test_step") pipeline = RobotPipeline([step]) - + reset_called = [] - + def reset_hook(): reset_called.append(True) - + pipeline.register_reset_hook(reset_hook) - + # Make some calls to increment counter transition = (None, None, 0.0, False, False, {}, {}) pipeline(transition) pipeline(transition) - + assert step.counter == 2 - + # Reset should reset step and call hook pipeline.reset() - + assert step.counter == 0 assert len(reset_called) == 1 + def test_profile_steps(): """Test step profiling functionality.""" step1 = MockStep("step1") step2 = MockStep("step2") pipeline = RobotPipeline([step1, step2]) - + transition = (None, None, 0.0, False, False, {}, {}) - + profile_results = pipeline.profile_steps(transition, num_runs=10) - + assert len(profile_results) == 2 assert "step_0_MockStep" in profile_results assert "step_1_MockStep" in profile_results assert all(isinstance(time, float) and time >= 0 for time in profile_results.values()) + def test_save_and_load_pretrained(): """Test saving and loading pipeline. - + This test demonstrates that JSON-serializable attributes (like counter) are saved in the config and restored when the step is recreated. """ step1 = MockStep("step1") step2 = MockStep("step2") - + # Increment counters to have some state step1.counter = 5 step2.counter = 10 - + pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42) - + with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline pipeline.save_pretrained(tmp_dir) - + # Check files were created config_path = Path(tmp_dir) / "pipeline.json" assert config_path.exists() - + # Check config content with open(config_path) as f: config = json.load(f) - + assert config["name"] == "TestPipeline" assert config["seed"] == 42 assert len(config["steps"]) == 2 - + # Verify counters are saved in config, not in separate state files assert config["steps"][0]["config"]["counter"] == 5 assert config["steps"][1]["config"]["counter"] == 10 - + # Load pipeline loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir) - + assert loaded_pipeline.name == "TestPipeline" assert loaded_pipeline.seed == 42 assert len(loaded_pipeline) == 2 - + # Check that counter was restored from config assert loaded_pipeline.steps[0].counter == 5 assert loaded_pipeline.steps[1].counter == 10 + def test_step_without_optional_methods(): """Test pipeline with steps that don't implement optional methods.""" step = MockStepWithoutOptionalMethods(multiplier=3.0) pipeline = RobotPipeline([step]) - + transition = (None, None, 2.0, False, False, {}, {}) result = pipeline(transition) - + assert result[2] == 6.0 # 2.0 * 3.0 - + # Reset should work even if step doesn't implement reset pipeline.reset() - + # Save/load should work even without optional methods with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir) assert len(loaded_pipeline) == 1 + def test_mixed_json_and_tensor_state(): """Test step with both JSON attributes and tensor state.""" step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) pipeline = RobotPipeline([step]) - + # Process some transitions with rewards for i in range(10): transition = (None, None, float(i), False, False, {}, {}) pipeline(transition) - + # Check state assert step.running_count.item() == 10 assert step.learning_rate == 0.05 - + # Save and load with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) - + # Check that both config and state files were created - config_path = Path(tmp_dir) / "pipeline.json" + config_path = Path(tmp_dir) / "pipeline.json" state_path = Path(tmp_dir) / "step_0.safetensors" assert config_path.exists() assert state_path.exists() - + # Load and verify loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir) loaded_step = loaded_pipeline.steps[0] - + # Check JSON attributes were restored assert loaded_step.name == "stats" assert loaded_step.learning_rate == 0.05 assert loaded_step.window_size == 5 - + # Check tensor state was restored assert loaded_step.running_count.item() == 10 assert torch.allclose(loaded_step.running_mean, step.running_mean) - - \ No newline at end of file