Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Adil Zouitine
2025-07-09 18:20:43 +02:00
parent 33969a0337
commit 1e0d667a22
9 changed files with 15 additions and 18 deletions
+1 -2
View File
@@ -36,8 +36,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
# Create processor with observation processor
processor = RobotProcessor([VanillaObservationProcessor()])
+2
View File
@@ -32,6 +32,7 @@ from .pipeline import (
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionIndex,
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
@@ -53,6 +54,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TransitionIndex",
"TruncatedProcessor",
"VanillaObservationProcessor",
]
+6 -6
View File
@@ -21,7 +21,7 @@ import os
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
from typing import Any, Callable, Iterable, Protocol, Sequence
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -41,14 +41,14 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
EnvTransition = tuple[
dict[str, Any] | None, # observation
Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
dict[str, Any] | None, # info
dict[str, Any] | None, # complementary_data
]
@@ -135,11 +135,11 @@ class ProcessorStep(Protocol):
a safe-to-share JSON + SafeTensors format.
Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading.
* ``state_dict() -> Dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
* ``state_dict() -> dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
This is exclusively for torch.Tensor objects (e.g., learned weights,
running statistics as tensors). Never put simple Python types here.
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
+1 -2
View File
@@ -72,8 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
+1 -2
View File
@@ -22,8 +22,7 @@ from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
+1 -2
View File
@@ -39,8 +39,7 @@ from lerobot.policies.factory import (
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import torch
from lerobot.processor.observation_processor import (
from lerobot.processor import (
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
+1 -1
View File
@@ -25,7 +25,7 @@ import pytest
import torch
import torch.nn as nn
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
@dataclass
+1 -2
View File
@@ -20,8 +20,7 @@ from pathlib import Path
import numpy as np
import torch
from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex
from lerobot.processor.rename_processor import RenameProcessor
from lerobot.processor import ProcessorStepRegistry, RobotProcessor, TransitionIndex, RenameProcessor
def test_basic_renaming():