mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Apply suggestions from code review
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -36,8 +36,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||||
"""
|
"""
|
||||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
|
||||||
|
|
||||||
# Create processor with observation processor
|
# Create processor with observation processor
|
||||||
processor = RobotProcessor([VanillaObservationProcessor()])
|
processor = RobotProcessor([VanillaObservationProcessor()])
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from .pipeline import (
|
|||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RewardProcessor,
|
RewardProcessor,
|
||||||
RobotProcessor,
|
RobotProcessor,
|
||||||
|
TransitionIndex,
|
||||||
TruncatedProcessor,
|
TruncatedProcessor,
|
||||||
)
|
)
|
||||||
from .rename_processor import RenameProcessor
|
from .rename_processor import RenameProcessor
|
||||||
@@ -53,6 +54,7 @@ __all__ = [
|
|||||||
"RewardProcessor",
|
"RewardProcessor",
|
||||||
"RobotProcessor",
|
"RobotProcessor",
|
||||||
"StateProcessor",
|
"StateProcessor",
|
||||||
|
"TransitionIndex",
|
||||||
"TruncatedProcessor",
|
"TruncatedProcessor",
|
||||||
"VanillaObservationProcessor",
|
"VanillaObservationProcessor",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import os
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
from typing import Any, Callable, Iterable, Protocol, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||||
@@ -41,14 +41,14 @@ class TransitionIndex(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||||
EnvTransition = Tuple[
|
EnvTransition = tuple[
|
||||||
dict[str, Any] | None, # observation
|
dict[str, Any] | None, # observation
|
||||||
Any | torch.Tensor | None, # action
|
Any | torch.Tensor | None, # action
|
||||||
float | torch.Tensor | None, # reward
|
float | torch.Tensor | None, # reward
|
||||||
bool | torch.Tensor | None, # done
|
bool | torch.Tensor | None, # done
|
||||||
bool | torch.Tensor | None, # truncated
|
bool | torch.Tensor | None, # truncated
|
||||||
Dict[str, Any] | None, # info
|
dict[str, Any] | None, # info
|
||||||
Dict[str, Any] | None, # complementary_data
|
dict[str, Any] | None, # complementary_data
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -135,11 +135,11 @@ class ProcessorStep(Protocol):
|
|||||||
a safe-to-share JSON + SafeTensors format.
|
a safe-to-share JSON + SafeTensors format.
|
||||||
|
|
||||||
Optional helper protocol:
|
Optional helper protocol:
|
||||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||||||
configuration and state. YOU decide what to save here. This is where all
|
configuration and state. YOU decide what to save here. This is where all
|
||||||
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
||||||
The config dict will be passed to your class constructor when loading.
|
The config dict will be passed to your class constructor when loading.
|
||||||
* ``state_dict() -> Dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
|
* ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
|
||||||
This is exclusively for torch.Tensor objects (e.g., learned weights,
|
This is exclusively for torch.Tensor objects (e.g., learned weights,
|
||||||
running statistics as tensors). Never put simple Python types here.
|
running statistics as tensors). Never put simple Python types here.
|
||||||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||||||
|
|||||||
@@ -72,8 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
|
|||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
|
||||||
from lerobot.utils.io_utils import write_video
|
from lerobot.utils.io_utils import write_video
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
|
|||||||
@@ -22,8 +22,7 @@ from gymnasium.utils.env_checker import check_env
|
|||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
|
||||||
from tests.utils import require_env
|
from tests.utils import require_env
|
||||||
|
|
||||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||||
|
|||||||
@@ -39,8 +39,7 @@ from lerobot.policies.factory import (
|
|||||||
)
|
)
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
|
||||||
from lerobot.utils.random_utils import seeded_context
|
from lerobot.utils.random_utils import seeded_context
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.processor.observation_processor import (
|
from lerobot.processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
VanillaObservationProcessor,
|
VanillaObservationProcessor,
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.processor.pipeline import ProcessorStepRegistry, RobotProcessor, TransitionIndex
|
from lerobot.processor import ProcessorStepRegistry, RobotProcessor, TransitionIndex, RenameProcessor
|
||||||
from lerobot.processor.rename_processor import RenameProcessor
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_renaming():
|
def test_basic_renaming():
|
||||||
|
|||||||
Reference in New Issue
Block a user