mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
f6c7287ae7
commit
769f531603
@@ -16,10 +16,8 @@
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -38,8 +36,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
|
||||
# Create pipeline with observation processor
|
||||
pipeline = RobotPipeline([ObservationProcessor()])
|
||||
@@ -52,9 +50,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to externalize normalization from policies)
|
||||
|
||||
@@ -13,12 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
|
||||
from .observation_processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
)
|
||||
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
|
||||
|
||||
__all__ = [
|
||||
"RobotPipeline",
|
||||
|
||||
@@ -15,14 +15,15 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import einops
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -212,8 +213,12 @@ class ObservationProcessor:
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary."""
|
||||
image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")}
|
||||
state_state = {k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")}
|
||||
image_state = {
|
||||
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
|
||||
}
|
||||
state_state = {
|
||||
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
|
||||
}
|
||||
|
||||
self.image_processor.load_state_dict(image_state)
|
||||
self.state_processor.load_state_dict(state_state)
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import os, json
|
||||
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from enum import IntEnum
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, ModelHubMixin
|
||||
from safetensors.torch import save_file, load_file
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class TransitionIndex(IntEnum):
|
||||
"""Explicit indices for EnvTransition tuple components."""
|
||||
|
||||
OBSERVATION = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
@@ -38,17 +41,16 @@ class TransitionIndex(IntEnum):
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = Tuple[
|
||||
Any| None, # observation
|
||||
Any| None, # action
|
||||
float| None, # reward
|
||||
bool| None, # done
|
||||
bool| None, # truncated
|
||||
Dict[str, Any]| None, # info
|
||||
Dict[str, Any]| None, # complementary_data
|
||||
Any | None, # observation
|
||||
Any | None, # action
|
||||
float | None, # reward
|
||||
bool | None, # done
|
||||
bool | None, # truncated
|
||||
Dict[str, Any] | None, # info
|
||||
Dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
|
||||
|
||||
|
||||
class PipelineStep(Protocol):
|
||||
"""Structural typing interface for a single pipeline step.
|
||||
|
||||
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
|
||||
|
||||
def get_config(self) -> Dict[str, Any]: ...
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]: ...
|
||||
def state_dict(self) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
@@ -120,17 +122,18 @@ class RobotPipeline(ModelHubMixin):
|
||||
pipe.push_to_hub("my-org/cartpole_pipe")
|
||||
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
|
||||
"""
|
||||
|
||||
steps: Sequence[PipelineStep] = field(default_factory=list)
|
||||
name: str = "RobotPipeline"
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None
|
||||
|
||||
# Pipeline-level hooks
|
||||
# A hook can optionally return a modified transition. If it returns
|
||||
# ``None`` the current value is left untouched.
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
@@ -177,14 +180,14 @@ class RobotPipeline(ModelHubMixin):
|
||||
"""Serialize the pipeline definition and parameters to *destination_path*."""
|
||||
os.makedirs(destination_path, exist_ok=True)
|
||||
|
||||
config: Dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"seed": self.seed,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
for step_index, pipeline_step in enumerate(self.steps):
|
||||
step_entry: Dict[str, Any] = {
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
|
||||
}
|
||||
|
||||
@@ -204,19 +207,19 @@ class RobotPipeline(ModelHubMixin):
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, source: str) -> "RobotPipeline":
|
||||
def from_pretrained(cls, source: str) -> RobotPipeline:
|
||||
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
with open(base_path / cls._CFG_NAME) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
else:
|
||||
# Hugging Face Hub - download all required files
|
||||
# First download the config file
|
||||
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
|
||||
with open(config_path) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
|
||||
return RobotPipeline(self.steps[idx], self.name, self.seed)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed before every pipeline step."""
|
||||
self.before_step_hooks.append(fn)
|
||||
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed after every pipeline step."""
|
||||
self.after_step_hooks.append(fn)
|
||||
|
||||
@@ -274,7 +277,7 @@ class RobotPipeline(ModelHubMixin):
|
||||
for fn in self.reset_hooks:
|
||||
fn()
|
||||
|
||||
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]:
|
||||
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
|
||||
"""Profile the execution time of each step for performance optimization."""
|
||||
import time
|
||||
|
||||
|
||||
@@ -69,12 +69,11 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
|
||||
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
|
||||
@@ -22,9 +22,8 @@ from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
@@ -30,8 +30,6 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.policies.factory import (
|
||||
@@ -41,6 +39,8 @@ from lerobot.policies.factory import (
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
@@ -20,10 +20,9 @@ import torch
|
||||
|
||||
from lerobot.processor.observation_processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition
|
||||
|
||||
|
||||
def test_process_single_image():
|
||||
@@ -51,6 +50,7 @@ def test_process_single_image():
|
||||
assert processed_img.min() >= 0.0
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = ImageProcessor()
|
||||
@@ -59,12 +59,7 @@ def test_process_image_dict():
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
|
||||
observation = {
|
||||
"pixels": {
|
||||
"camera1": image1,
|
||||
"camera2": image2
|
||||
}
|
||||
}
|
||||
observation = {"pixels": {"camera1": image1, "camera2": image2}}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
@@ -78,6 +73,7 @@ def test_process_image_dict():
|
||||
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
|
||||
|
||||
|
||||
def test_process_batched_image():
|
||||
"""Test processing already batched images."""
|
||||
processor = ImageProcessor()
|
||||
@@ -94,6 +90,7 @@ def test_process_batched_image():
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
|
||||
|
||||
def test_invalid_image_format():
|
||||
"""Test error handling for invalid image formats."""
|
||||
processor = ImageProcessor()
|
||||
@@ -106,6 +103,7 @@ def test_invalid_image_format():
|
||||
with pytest.raises(ValueError, match="Expected channel-last images"):
|
||||
processor(transition)
|
||||
|
||||
|
||||
def test_invalid_image_dtype():
|
||||
"""Test error handling for invalid image dtype."""
|
||||
processor = ImageProcessor()
|
||||
@@ -118,6 +116,7 @@ def test_invalid_image_dtype():
|
||||
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
|
||||
processor(transition)
|
||||
|
||||
|
||||
def test_no_pixels_in_observation():
|
||||
"""Test processor when no pixels are in observation."""
|
||||
processor = ImageProcessor()
|
||||
@@ -132,6 +131,7 @@ def test_no_pixels_in_observation():
|
||||
assert "other_data" in processed_obs
|
||||
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = ImageProcessor()
|
||||
@@ -141,6 +141,7 @@ def test_none_observation():
|
||||
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = ImageProcessor()
|
||||
@@ -161,64 +162,64 @@ def test_serialization_methods():
|
||||
|
||||
|
||||
def test_process_environment_state():
|
||||
"""Test processing environment_state."""
|
||||
processor = StateProcessor()
|
||||
"""Test processing environment_state."""
|
||||
processor = StateProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.environment_state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
|
||||
processed_state = processed_obs["observation.environment_state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
|
||||
def test_process_agent_pos():
|
||||
"""Test processing agent_pos."""
|
||||
processor = StateProcessor()
|
||||
"""Test processing agent_pos."""
|
||||
processor = StateProcessor()
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
||||
|
||||
processed_state = processed_obs["observation.state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
||||
|
||||
def test_process_batched_states():
|
||||
"""Test processing already batched states."""
|
||||
processor = StateProcessor()
|
||||
"""Test processing already batched states."""
|
||||
processor = StateProcessor()
|
||||
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
assert processed_obs["observation.state"].shape == (2, 2)
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
assert processed_obs["observation.state"].shape == (2, 2)
|
||||
|
||||
def test_process_both_states():
|
||||
"""Test processing both environment_state and agent_pos."""
|
||||
@@ -227,11 +228,7 @@ def test_process_both_states():
|
||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "keep_me"
|
||||
}
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
@@ -248,6 +245,7 @@ def test_process_both_states():
|
||||
# Check that other data was preserved
|
||||
assert processed_obs["other_data"] == "keep_me"
|
||||
|
||||
|
||||
def test_no_states_in_observation():
|
||||
"""Test processor when no states are in observation."""
|
||||
processor = StateProcessor()
|
||||
@@ -261,6 +259,7 @@ def test_no_states_in_observation():
|
||||
# Should preserve data unchanged
|
||||
assert processed_obs == observation
|
||||
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = StateProcessor()
|
||||
@@ -270,6 +269,7 @@ def test_none_observation():
|
||||
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = StateProcessor()
|
||||
@@ -290,177 +290,177 @@ def test_serialization_methods():
|
||||
|
||||
|
||||
def test_complete_observation_processing():
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"pixels": image,
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "preserve_me"
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
observation = {
|
||||
"pixels": image,
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
||||
|
||||
# Check that states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
# Check that states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "pixels" not in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
# Check that original keys were removed
|
||||
assert "pixels" not in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
# Check that other data was preserved
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
# Check that other data was preserved
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
def test_image_only_processing():
|
||||
"""Test processing observation with only images."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test processing observation with only images."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
|
||||
def test_state_only_processing():
|
||||
"""Test processing observation with only states."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test processing observation with only states."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processing empty observation."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test processing empty observation."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
observation = {}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
observation = {}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
|
||||
assert processed_obs == {}
|
||||
|
||||
assert processed_obs == {}
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processing None observation."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test processing None observation."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
result = processor(transition)
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
result = processor(transition)
|
||||
|
||||
assert result == transition
|
||||
|
||||
assert result == transition
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = ObservationProcessor()
|
||||
"""Test serialization methods."""
|
||||
processor = ObservationProcessor()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert isinstance(config, dict)
|
||||
assert "image_processor" in config
|
||||
assert "state_processor" in config
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert isinstance(config, dict)
|
||||
assert "image_processor" in config
|
||||
assert "state_processor" in config
|
||||
|
||||
# Test state_dict
|
||||
state = processor.state_dict()
|
||||
assert isinstance(state, dict)
|
||||
# Test state_dict
|
||||
state = processor.state_dict()
|
||||
assert isinstance(state, dict)
|
||||
|
||||
# Test load_state_dict (should not raise)
|
||||
processor.load_state_dict(state)
|
||||
# Test load_state_dict (should not raise)
|
||||
processor.load_state_dict(state)
|
||||
|
||||
# Test reset (should not raise)
|
||||
processor.reset()
|
||||
|
||||
# Test reset (should not raise)
|
||||
processor.reset()
|
||||
|
||||
def test_custom_sub_processors():
|
||||
"""Test ObservationProcessor with custom sub-processors."""
|
||||
image_proc = ImageProcessor()
|
||||
state_proc = StateProcessor()
|
||||
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
"""Test ObservationProcessor with custom sub-processors."""
|
||||
image_proc = ImageProcessor()
|
||||
state_proc = StateProcessor()
|
||||
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
|
||||
# Should use the provided processors
|
||||
assert processor.image_processor is image_proc
|
||||
assert processor.state_processor is state_proc
|
||||
# Should use the provided processors
|
||||
assert processor.image_processor is image_proc
|
||||
assert processor.state_processor is state_proc
|
||||
|
||||
|
||||
def test_equivalent_to_original_function():
|
||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = ObservationProcessor()
|
||||
processor = ObservationProcessor()
|
||||
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"pixels": image,
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos
|
||||
}
|
||||
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
|
||||
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
def test_equivalent_with_image_dict():
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = ObservationProcessor()
|
||||
processor = ObservationProcessor()
|
||||
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"pixels": {"cam1": image1, "cam2": image2},
|
||||
"agent_pos": agent_pos
|
||||
}
|
||||
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
|
||||
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
@@ -16,15 +16,14 @@
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -151,6 +150,7 @@ def test_empty_pipeline():
|
||||
assert result == transition
|
||||
assert len(pipeline) == 0
|
||||
|
||||
|
||||
def test_single_step_pipeline():
|
||||
"""Test pipeline with a single step."""
|
||||
step = MockStep("test_step")
|
||||
@@ -166,6 +166,7 @@ def test_single_step_pipeline():
|
||||
result = pipeline(transition)
|
||||
assert result[6]["test_step_counter"] == 1
|
||||
|
||||
|
||||
def test_multiple_steps_pipeline():
|
||||
"""Test pipeline with multiple steps."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -179,6 +180,7 @@ def test_multiple_steps_pipeline():
|
||||
assert result[6]["step1_counter"] == 0
|
||||
assert result[6]["step2_counter"] == 0
|
||||
|
||||
|
||||
def test_invalid_transition_format():
|
||||
"""Test pipeline with invalid transition format."""
|
||||
pipeline = RobotPipeline([MockStep()])
|
||||
@@ -191,6 +193,7 @@ def test_invalid_transition_format():
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
pipeline("not a tuple")
|
||||
|
||||
|
||||
def test_step_through():
|
||||
"""Test step_through method."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -206,6 +209,7 @@ def test_step_through():
|
||||
assert "step1_counter" in results[1][6] # After step1
|
||||
assert "step2_counter" in results[2][6] # After step2
|
||||
|
||||
|
||||
def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -222,6 +226,7 @@ def test_indexing():
|
||||
assert len(sub_pipeline) == 1
|
||||
assert sub_pipeline[0] is step1
|
||||
|
||||
|
||||
def test_hooks():
|
||||
"""Test before/after step hooks."""
|
||||
step = MockStep("test_step")
|
||||
@@ -247,6 +252,7 @@ def test_hooks():
|
||||
assert before_calls == [0]
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_hook_modification():
|
||||
"""Test that hooks can modify transitions."""
|
||||
step = MockStep("test_step")
|
||||
@@ -263,6 +269,7 @@ def test_hook_modification():
|
||||
|
||||
assert result[2] == 42.0 # reward modified by hook
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
@@ -288,6 +295,7 @@ def test_reset():
|
||||
assert step.counter == 0
|
||||
assert len(reset_called) == 1
|
||||
|
||||
|
||||
def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -303,6 +311,7 @@ def test_profile_steps():
|
||||
assert "step_1_MockStep" in profile_results
|
||||
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading pipeline.
|
||||
|
||||
@@ -349,6 +358,7 @@ def test_save_and_load_pretrained():
|
||||
assert loaded_pipeline.steps[0].counter == 5
|
||||
assert loaded_pipeline.steps[1].counter == 10
|
||||
|
||||
|
||||
def test_step_without_optional_methods():
|
||||
"""Test pipeline with steps that don't implement optional methods."""
|
||||
step = MockStepWithoutOptionalMethods(multiplier=3.0)
|
||||
@@ -368,6 +378,7 @@ def test_step_without_optional_methods():
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
|
||||
def test_mixed_json_and_tensor_state():
|
||||
"""Test step with both JSON attributes and tensor state."""
|
||||
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
|
||||
@@ -404,5 +415,3 @@ def test_mixed_json_and_tensor_state():
|
||||
# Check tensor state was restored
|
||||
assert loaded_step.running_count.item() == 10
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user