[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-02 15:31:15 +00:00
committed by Adil Zouitine
parent f6c7287ae7
commit 769f531603
9 changed files with 485 additions and 475 deletions
+1 -6
View File
@@ -16,10 +16,8 @@
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -38,8 +36,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
# Create pipeline with observation processor
pipeline = RobotPipeline([ObservationProcessor()])
@@ -52,9 +50,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return processed_transition[TransitionIndex.OBSERVATION]
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to externalize normalization from policies)
+2 -2
View File
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
from .observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
StateProcessor,
)
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
__all__ = [
"RobotPipeline",
+10 -5
View File
@@ -15,14 +15,15 @@
# limitations under the License.
from __future__ import annotations
from typing import Any
from dataclasses import dataclass, field
from typing import Any
import einops
import numpy as np
import torch
import einops
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
@dataclass
@@ -212,8 +213,12 @@ class ObservationProcessor:
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary."""
image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")}
state_state = {k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")}
image_state = {
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
}
state_state = {
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
}
self.image_processor.load_state_dict(image_state)
self.state_processor.load_state_dict(state_state)
+31 -28
View File
@@ -14,19 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os, json
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from enum import IntEnum
import numpy as np
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
import torch
from huggingface_hub import hf_hub_download, ModelHubMixin
from safetensors.torch import save_file, load_file
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
@@ -38,17 +41,16 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
Any| None, # observation
Any| None, # action
float| None, # reward
bool| None, # done
bool| None, # truncated
Dict[str, Any]| None, # info
Dict[str, Any]| None, # complementary_data
Any | None, # observation
Any | None, # action
float | None, # reward
bool | None, # done
bool | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
]
class PipelineStep(Protocol):
"""Structural typing interface for a single pipeline step.
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
def get_config(self) -> Dict[str, Any]: ...
def get_config(self) -> dict[str, Any]: ...
def state_dict(self) -> Dict[str, torch.Tensor]: ...
def state_dict(self) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
def reset(self) -> None: ...
@@ -120,17 +122,18 @@ class RobotPipeline(ModelHubMixin):
pipe.push_to_hub("my-org/cartpole_pipe")
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
"""
steps: Sequence[PipelineStep] = field(default_factory=list)
name: str = "RobotPipeline"
seed: Optional[int] = None
seed: int | None = None
# Pipeline-level hooks
# A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched.
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
@@ -177,14 +180,14 @@ class RobotPipeline(ModelHubMixin):
"""Serialize the pipeline definition and parameters to *destination_path*."""
os.makedirs(destination_path, exist_ok=True)
config: Dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"seed": self.seed,
"steps": [],
}
for step_index, pipeline_step in enumerate(self.steps):
step_entry: Dict[str, Any] = {
step_entry: dict[str, Any] = {
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
}
@@ -204,19 +207,19 @@ class RobotPipeline(ModelHubMixin):
json.dump(config, file_pointer, indent=2)
@classmethod
def from_pretrained(cls, source: str) -> "RobotPipeline":
def from_pretrained(cls, source: str) -> RobotPipeline:
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
with open(base_path / cls._CFG_NAME) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
# First download the config file
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
with open(config_path) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
return RobotPipeline(self.steps[idx], self.name, self.seed)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Attach fn to be executed before every pipeline step."""
self.before_step_hooks.append(fn)
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Attach fn to be executed after every pipeline step."""
self.after_step_hooks.append(fn)
@@ -274,7 +277,7 @@ class RobotPipeline(ModelHubMixin):
for fn in self.reset_hooks:
fn()
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]:
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import time
+2 -3
View File
@@ -69,12 +69,11 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
+1 -2
View File
@@ -22,9 +22,8 @@ from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
+2 -2
View File
@@ -30,8 +30,6 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import (
@@ -41,6 +39,8 @@ from lerobot.policies.factory import (
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
+175 -175
View File
@@ -20,10 +20,9 @@ import torch
from lerobot.processor.observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
StateProcessor,
)
from lerobot.processor.pipeline import EnvTransition
def test_process_single_image():
@@ -51,6 +50,7 @@ def test_process_single_image():
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = ImageProcessor()
@@ -59,12 +59,7 @@ def test_process_image_dict():
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {
"pixels": {
"camera1": image1,
"camera2": image2
}
}
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
@@ -78,6 +73,7 @@ def test_process_image_dict():
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
"""Test processing already batched images."""
processor = ImageProcessor()
@@ -94,6 +90,7 @@ def test_process_batched_image():
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = ImageProcessor()
@@ -106,6 +103,7 @@ def test_invalid_image_format():
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = ImageProcessor()
@@ -118,6 +116,7 @@ def test_invalid_image_dtype():
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = ImageProcessor()
@@ -132,6 +131,7 @@ def test_no_pixels_in_observation():
assert "other_data" in processed_obs
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
@@ -141,6 +141,7 @@ def test_none_observation():
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ImageProcessor()
@@ -161,64 +162,64 @@ def test_serialization_methods():
def test_process_environment_state():
"""Test processing environment_state."""
processor = StateProcessor()
"""Test processing environment_state."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = StateProcessor()
"""Test processing agent_pos."""
processor = StateProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
def test_process_batched_states():
"""Test processing already batched states."""
processor = StateProcessor()
"""Test processing already batched states."""
processor = StateProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos
}
transition = (observation, None, None, None, None, None, None)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
@@ -227,11 +228,7 @@ def test_process_both_states():
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "keep_me"
}
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
@@ -248,6 +245,7 @@ def test_process_both_states():
# Check that other data was preserved
assert processed_obs["other_data"] == "keep_me"
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = StateProcessor()
@@ -261,6 +259,7 @@ def test_no_states_in_observation():
# Should preserve data unchanged
assert processed_obs == observation
def test_none_observation():
"""Test processor with None observation."""
processor = StateProcessor()
@@ -270,6 +269,7 @@ def test_none_observation():
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = StateProcessor()
@@ -290,177 +290,177 @@ def test_serialization_methods():
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me"
}
transition = (observation, None, None, None, None, None, None)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
def test_image_only_processing():
"""Test processing observation with only images."""
processor = ObservationProcessor()
"""Test processing observation with only images."""
processor = ObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
def test_state_only_processing():
"""Test processing observation with only states."""
processor = ObservationProcessor()
"""Test processing observation with only states."""
processor = ObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
def test_empty_observation():
"""Test processing empty observation."""
processor = ObservationProcessor()
"""Test processing empty observation."""
processor = ObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
observation = {}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
result = processor(transition)
processed_obs = result[0]
assert processed_obs == {}
assert processed_obs == {}
def test_none_observation():
"""Test processing None observation."""
processor = ObservationProcessor()
"""Test processing None observation."""
processor = ObservationProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ObservationProcessor()
"""Test serialization methods."""
processor = ObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
assert "image_processor" in config
assert "state_processor" in config
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
assert "image_processor" in config
assert "state_processor" in config
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
# Test reset (should not raise)
processor.reset()
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
processor = ObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos
}
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
processor = ObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {
"pixels": {"cam1": image1, "cam2": image2},
"agent_pos": agent_pos
}
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
+14 -5
View File
@@ -16,15 +16,14 @@
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from dataclasses import dataclass
import numpy as np
import pytest
import torch
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
@dataclass
@@ -151,6 +150,7 @@ def test_empty_pipeline():
assert result == transition
assert len(pipeline) == 0
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
@@ -166,6 +166,7 @@ def test_single_step_pipeline():
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
@@ -179,6 +180,7 @@ def test_multiple_steps_pipeline():
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotPipeline([MockStep()])
@@ -191,6 +193,7 @@ def test_invalid_transition_format():
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
def test_step_through():
"""Test step_through method."""
step1 = MockStep("step1")
@@ -206,6 +209,7 @@ def test_step_through():
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
@@ -222,6 +226,7 @@ def test_indexing():
assert len(sub_pipeline) == 1
assert sub_pipeline[0] is step1
def test_hooks():
"""Test before/after step hooks."""
step = MockStep("test_step")
@@ -247,6 +252,7 @@ def test_hooks():
assert before_calls == [0]
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
@@ -263,6 +269,7 @@ def test_hook_modification():
assert result[2] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
@@ -288,6 +295,7 @@ def test_reset():
assert step.counter == 0
assert len(reset_called) == 1
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
@@ -303,6 +311,7 @@ def test_profile_steps():
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained():
"""Test saving and loading pipeline.
@@ -349,6 +358,7 @@ def test_save_and_load_pretrained():
assert loaded_pipeline.steps[0].counter == 5
assert loaded_pipeline.steps[1].counter == 10
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
@@ -368,6 +378,7 @@ def test_step_without_optional_methods():
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
def test_mixed_json_and_tensor_state():
"""Test step with both JSON attributes and tensor state."""
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
@@ -404,5 +415,3 @@ def test_mixed_json_and_tensor_state():
# Check tensor state was restored
assert loaded_step.running_count.item() == 10
assert torch.allclose(loaded_step.running_mean, step.running_mean)