Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Adil Zouitine
2025-07-09 18:20:43 +02:00
parent 33969a0337
commit 1e0d667a22
9 changed files with 15 additions and 18 deletions
+1 -2
View File
@@ -36,8 +36,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns: Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
""" """
from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
# Create processor with observation processor # Create processor with observation processor
processor = RobotProcessor([VanillaObservationProcessor()]) processor = RobotProcessor([VanillaObservationProcessor()])
+2
View File
@@ -32,6 +32,7 @@ from .pipeline import (
ProcessorStepRegistry, ProcessorStepRegistry,
RewardProcessor, RewardProcessor,
RobotProcessor, RobotProcessor,
TransitionIndex,
TruncatedProcessor, TruncatedProcessor,
) )
from .rename_processor import RenameProcessor from .rename_processor import RenameProcessor
@@ -53,6 +54,7 @@ __all__ = [
"RewardProcessor", "RewardProcessor",
"RobotProcessor", "RobotProcessor",
"StateProcessor", "StateProcessor",
"TransitionIndex",
"TruncatedProcessor", "TruncatedProcessor",
"VanillaObservationProcessor", "VanillaObservationProcessor",
] ]
+6 -6
View File
@@ -21,7 +21,7 @@ import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple from typing import Any, Callable, Iterable, Protocol, Sequence
import torch import torch
from huggingface_hub import ModelHubMixin, hf_hub_download from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -41,14 +41,14 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data) # (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[ EnvTransition = tuple[
dict[str, Any] | None, # observation dict[str, Any] | None, # observation
Any | torch.Tensor | None, # action Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated bool | torch.Tensor | None, # truncated
Dict[str, Any] | None, # info dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data dict[str, Any] | None, # complementary_data
] ]
@@ -135,11 +135,11 @@ class ProcessorStep(Protocol):
a safe-to-share JSON + SafeTensors format. a safe-to-share JSON + SafeTensors format.
Optional helper protocol: Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable * ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size). non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading. The config dict will be passed to your class constructor when loading.
* ``state_dict() -> Dict[str, torch.Tensor]`` PyTorch tensor state ONLY. * ``state_dict() -> dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
This is exclusively for torch.Tensor objects (e.g., learned weights, This is exclusively for torch.Tensor objects (e.g., learned weights,
running statistics as tensors). Never put simple Python types here. running statistics as tensors). Never put simple Python types here.
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict * ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
+1 -2
View File
@@ -72,8 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
+1 -2
View File
@@ -22,8 +22,7 @@ from gymnasium.utils.env_checker import check_env
import lerobot import lerobot
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from tests.utils import require_env from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
+1 -2
View File
@@ -39,8 +39,7 @@ from lerobot.policies.factory import (
) )
from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.utils.random_utils import seeded_context from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -18,7 +18,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from lerobot.processor.observation_processor import ( from lerobot.processor import (
ImageProcessor, ImageProcessor,
StateProcessor, StateProcessor,
VanillaObservationProcessor, VanillaObservationProcessor,
+1 -1
View File
@@ -25,7 +25,7 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
@dataclass @dataclass
+1 -2
View File
@@ -20,8 +20,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex from lerobot.processor import ProcessorStepRegistry, RobotProcessor, TransitionIndex, RenameProcessor
from lerobot.processor.rename_processor import RenameProcessor
def test_basic_renaming(): def test_basic_renaming():