From 1e0d667a22a54bd05d8cc622bb4e6cf7d74730b5 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 9 Jul 2025 18:20:43 +0200 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine --- src/lerobot/envs/utils.py | 3 +-- src/lerobot/processor/__init__.py | 2 ++ src/lerobot/processor/pipeline.py | 12 ++++++------ src/lerobot/scripts/eval.py | 3 +-- tests/envs/test_envs.py | 3 +-- tests/policies/test_policies.py | 3 +-- tests/processor/test_observation_processor.py | 2 +- tests/processor/test_pipeline.py | 2 +- tests/processor/test_rename_processor.py | 3 +-- 9 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index c5aaa7001..a65023d32 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -36,8 +36,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ - from lerobot.processor.observation_processor import VanillaObservationProcessor - from lerobot.processor.pipeline import RobotProcessor, TransitionIndex + from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor # Create processor with observation processor processor = RobotProcessor([VanillaObservationProcessor()]) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index c5c4af9fa..5dd2e0125 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -32,6 +32,7 @@ from .pipeline import ( ProcessorStepRegistry, RewardProcessor, RobotProcessor, + TransitionIndex, TruncatedProcessor, ) from .rename_processor import RenameProcessor @@ -53,6 +54,7 @@ __all__ = [ "RewardProcessor", "RobotProcessor", "StateProcessor", + "TransitionIndex", "TruncatedProcessor", "VanillaObservationProcessor", ] diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 038711339..62fa732aa 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -21,7 +21,7 @@ import os from dataclasses import dataclass, field from enum import IntEnum from pathlib import Path -from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple +from typing import Any, Callable, Iterable, Protocol, Sequence import torch from huggingface_hub import ModelHubMixin, hf_hub_download @@ -41,14 +41,14 @@ class TransitionIndex(IntEnum): # (observation, action, reward, done, truncated, info, complementary_data) -EnvTransition = Tuple[ +EnvTransition = tuple[ dict[str, Any] | None, # observation Any | torch.Tensor | None, # action float | torch.Tensor | None, # reward bool | torch.Tensor | None, # done bool | torch.Tensor | None, # truncated - Dict[str, Any] | None, # info - Dict[str, Any] | None, # complementary_data + dict[str, Any] | None, # info + dict[str, Any] | None, # complementary_data ] @@ -135,11 +135,11 @@ class ProcessorStep(Protocol): a safe-to-share JSON + SafeTensors format. Optional helper protocol: - * ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable + * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable configuration and state. YOU decide what to save here. This is where all non-tensor state goes (e.g., name, counter, threshold, window_size). The config dict will be passed to your class constructor when loading. - * ``state_dict() -> Dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. + * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. This is exclusively for torch.Tensor objects (e.g., learned weights, running statistics as tensors). Never put simple Python types here. * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 7ea4a8995..c8e1a80cc 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -72,8 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters -from lerobot.processor.observation_processor import VanillaObservationProcessor -from lerobot.processor.pipeline import RobotProcessor, TransitionIndex +from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 9add876c2..15ce1f933 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -22,8 +22,7 @@ from gymnasium.utils.env_checker import check_env import lerobot from lerobot.envs.factory import make_env, make_env_config -from lerobot.processor.observation_processor import VanillaObservationProcessor -from lerobot.processor.pipeline import RobotProcessor, TransitionIndex +from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index c95e99d34..44751a829 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -39,8 +39,7 @@ from lerobot.policies.factory import ( ) from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor.observation_processor import VanillaObservationProcessor -from lerobot.processor.pipeline import RobotProcessor, TransitionIndex +from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 1f8cd92f0..5e06fd7fa 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -18,7 +18,7 @@ import numpy as np import pytest import torch -from lerobot.processor.observation_processor import ( +from lerobot.processor import ( ImageProcessor, StateProcessor, VanillaObservationProcessor, diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index b5952b412..60e727132 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -25,7 +25,7 @@ import pytest import torch import torch.nn as nn -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor +from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor @dataclass diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 1b7b28425..c29343b52 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -20,8 +20,7 @@ from pathlib import Path import numpy as np import torch -from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex -from lerobot.processor.rename_processor import RenameProcessor +from lerobot.processor import ProcessorStepRegistry, RobotProcessor, TransitionIndex, RenameProcessor def test_basic_renaming():