mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| db5c26f07d | |||
| 8904768db4 |
@@ -175,6 +175,11 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero_plus = [
|
||||
"lerobot[transformers-dep]",
|
||||
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||
"lerobot[scipy-dep]",
|
||||
]
|
||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
|
||||
# All
|
||||
|
||||
@@ -346,6 +346,65 @@ class LiberoEnv(EnvConfig):
|
||||
return kwargs
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero_plus")
|
||||
@dataclass
|
||||
class LiberoPlusEnv(LiberoEnv):
|
||||
"""Alias config for LIBERO-plus benchmarks.
|
||||
|
||||
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
|
||||
config reuses the existing LIBERO env implementation while making intent explicit
|
||||
in experiment configs (`env.type=libero_plus`).
|
||||
"""
|
||||
|
||||
task: str = "libero_spatial"
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robocasa")
|
||||
@dataclass
|
||||
class RoboCasaEnv(EnvConfig):
|
||||
"""RoboCasa kitchen composite-task environments.
|
||||
|
||||
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
|
||||
action space and a structured pixel + state observation dict.
|
||||
|
||||
Selected benchmark tasks (3 short + 2 long):
|
||||
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
|
||||
Long: PrepareCoffee, RestockPantry
|
||||
"""
|
||||
|
||||
task: str = "PickPlaceCounterToCabinet"
|
||||
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
|
||||
fps: int = 20
|
||||
episode_length: int = 500
|
||||
image_size: int = 128
|
||||
split: str = "target" # "pretrain" or "target"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agentview_left": f"{OBS_IMAGES}.agentview_left",
|
||||
"agentview_right": f"{OBS_IMAGES}.agentview_right",
|
||||
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
|
||||
"robot_state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
self.features[cam] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
|
||||
)
|
||||
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {"split": self.split}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("metaworld")
|
||||
@dataclass
|
||||
class MetaworldEnv(EnvConfig):
|
||||
|
||||
@@ -20,11 +20,20 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.configs import (
|
||||
AlohaEnv,
|
||||
EnvConfig,
|
||||
HubEnvConfig,
|
||||
IsaaclabArenaEnv,
|
||||
LiberoEnv,
|
||||
LiberoPlusEnv,
|
||||
PushtEnv,
|
||||
RoboCasaEnv,
|
||||
)
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
@@ -35,6 +44,10 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "libero":
|
||||
return LiberoEnv(**kwargs)
|
||||
elif env_type == "libero_plus":
|
||||
return LiberoPlusEnv(**kwargs)
|
||||
elif env_type == "robocasa":
|
||||
return RoboCasaEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -70,9 +83,13 @@ def make_env_pre_post_processors(
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
|
||||
preprocessor_steps.append(RoboCasaProcessorStep())
|
||||
|
||||
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
|
||||
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
|
||||
# Parse comma-separated keys (handle None for state-based policies)
|
||||
@@ -181,6 +198,20 @@ def make_env(
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "robocasa" in cfg.type:
|
||||
from lerobot.envs.robocasa import create_robocasa_envs
|
||||
|
||||
tasks = cfg.tasks if cfg.tasks else [cfg.task]
|
||||
return create_robocasa_envs(
|
||||
tasks=tasks,
|
||||
n_envs=n_envs,
|
||||
image_size=cfg.image_size,
|
||||
split=cfg.split,
|
||||
episode_length=cfg.episode_length,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
|
||||
@@ -26,8 +26,14 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
try:
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
except ImportError:
|
||||
# LIBERO-plus may be installed from source with an extra nested package level.
|
||||
from libero.libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
# Action layout (flat 12D, normalized to [-1, 1]):
|
||||
# [0:3] end_effector_position (delta x, y, z)
|
||||
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
|
||||
# [6:7] gripper_close (open=-1, close=+1)
|
||||
# [7:11] base_motion (x, y, theta, torso_height)
|
||||
# [11:12] control_mode (arm=-1, base=+1)
|
||||
ACTION_DIM = 12
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
# Proprioceptive state layout (flat 16D):
|
||||
# [0:2] gripper_qpos
|
||||
# [2:5] base_position
|
||||
# [5:9] base_rotation (quaternion)
|
||||
# [9:12] end_effector_position_relative
|
||||
# [12:16] end_effector_rotation_relative (quaternion)
|
||||
STATE_DIM = 16
|
||||
|
||||
# Obs dict keys from RoboCasaGymEnv.get_observation()
|
||||
_CAM_KEYS = (
|
||||
"video.robot0_agentview_left",
|
||||
"video.robot0_agentview_right",
|
||||
"video.robot0_eye_in_hand",
|
||||
)
|
||||
_STATE_KEYS_ORDERED = (
|
||||
"state.gripper_qpos", # (2,)
|
||||
"state.base_position", # (3,)
|
||||
"state.base_rotation", # (4,)
|
||||
"state.end_effector_position_relative", # (3,)
|
||||
"state.end_effector_rotation_relative", # (4,)
|
||||
)
|
||||
|
||||
# Mapping from video.* key → short image name used in features_map
|
||||
CAM_KEY_TO_NAME = {
|
||||
"video.robot0_agentview_left": "agentview_left",
|
||||
"video.robot0_agentview_right": "agentview_right",
|
||||
"video.robot0_eye_in_hand": "eye_in_hand",
|
||||
}
|
||||
|
||||
|
||||
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
|
||||
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
|
||||
return {
|
||||
"action.end_effector_position": flat[0:3],
|
||||
"action.end_effector_rotation": flat[3:6],
|
||||
"action.gripper_close": flat[6:7],
|
||||
"action.base_motion": flat[7:11],
|
||||
"action.control_mode": flat[11:12],
|
||||
}
|
||||
|
||||
|
||||
class RoboCasaEnv(gym.Env):
|
||||
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
|
||||
and a structured observation dict compatible with LeRobot policies.
|
||||
|
||||
Observations returned by step/reset:
|
||||
{
|
||||
"pixels": {
|
||||
"agentview_left": (H, W, 3) uint8,
|
||||
"agentview_right": (H, W, 3) uint8,
|
||||
"eye_in_hand": (H, W, 3) uint8,
|
||||
},
|
||||
"robot_state": (16,) float32,
|
||||
}
|
||||
|
||||
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
split: str = "target",
|
||||
image_size: int = 128,
|
||||
render_mode: str = "rgb_array",
|
||||
episode_length: int = 500,
|
||||
**gym_kwargs: Any,
|
||||
):
|
||||
super().__init__()
|
||||
# Lazy import — robocasa is optional
|
||||
import robocasa.environments # noqa: F401 — registers all gym envs
|
||||
|
||||
self.task = task
|
||||
self.render_mode = render_mode
|
||||
self.image_size = image_size
|
||||
self._max_episode_steps = episode_length
|
||||
self._step_count = 0
|
||||
|
||||
self._env = gym.make(
|
||||
f"robocasa/{task}",
|
||||
split=split,
|
||||
camera_widths=image_size,
|
||||
camera_heights=image_size,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
# Flat 12D Box action space
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW,
|
||||
high=ACTION_HIGH,
|
||||
shape=(ACTION_DIM,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
images = {
|
||||
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
|
||||
for name in CAM_KEY_TO_NAME.values()
|
||||
}
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"robot_state": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _format_obs(self, raw_obs: dict) -> dict:
|
||||
pixels = {
|
||||
CAM_KEY_TO_NAME[k]: raw_obs[k]
|
||||
for k in _CAM_KEYS
|
||||
if k in raw_obs
|
||||
}
|
||||
state_parts = [
|
||||
np.asarray(raw_obs[k], dtype=np.float32)
|
||||
for k in _STATE_KEYS_ORDERED
|
||||
if k in raw_obs
|
||||
]
|
||||
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
|
||||
return {"pixels": pixels, "robot_state": robot_state}
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
|
||||
super().reset(seed=seed)
|
||||
self._step_count = 0
|
||||
raw_obs, info = self._env.reset(seed=seed)
|
||||
info.setdefault("is_success", False)
|
||||
info["task"] = self.task
|
||||
return self._format_obs(raw_obs), info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(
|
||||
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
|
||||
)
|
||||
action_dict = _flat_to_action_dict(action)
|
||||
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
|
||||
self._step_count += 1
|
||||
|
||||
is_success = bool(info.get("success", False))
|
||||
terminated = terminated or is_success
|
||||
if self._step_count >= self._max_episode_steps:
|
||||
truncated = True
|
||||
|
||||
info.update({"task": self.task, "is_success": is_success})
|
||||
obs = self._format_obs(raw_obs)
|
||||
|
||||
if terminated or truncated:
|
||||
info["final_info"] = {"task": self.task, "is_success": is_success}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray | None:
|
||||
if self.render_mode == "rgb_array":
|
||||
return self._env.render()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self._env.close()
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
image_size: int,
|
||||
split: str,
|
||||
episode_length: int,
|
||||
gym_kwargs: dict[str, Any],
|
||||
) -> list[Callable[[], RoboCasaEnv]]:
|
||||
"""Build n_envs factory callables for a single task."""
|
||||
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
|
||||
return RoboCasaEnv(
|
||||
task=task,
|
||||
split=split,
|
||||
image_size=image_size,
|
||||
episode_length=episode_length,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
return [partial(_make, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robocasa_envs(
|
||||
tasks: str | Sequence[str],
|
||||
n_envs: int,
|
||||
image_size: int = 128,
|
||||
split: str = "target",
|
||||
episode_length: int = 500,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboCasa environments.
|
||||
|
||||
Args:
|
||||
tasks: A single task name or list of task names (without "robocasa/" prefix).
|
||||
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
|
||||
n_envs: Number of parallel envs per task.
|
||||
image_size: Square image resolution for all cameras.
|
||||
split: RoboCasa dataset split — "pretrain" or "target".
|
||||
episode_length: Max steps per episode before truncation.
|
||||
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
|
||||
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
|
||||
|
||||
Returns:
|
||||
dict[task_name][task_id=0] -> vec_env
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
|
||||
else:
|
||||
task_list = [str(t).strip() for t in tasks if str(t).strip()]
|
||||
if not task_list:
|
||||
raise ValueError("`tasks` must contain at least one task name.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
|
||||
for task in task_list:
|
||||
fns = _make_env_fns(
|
||||
task=task,
|
||||
n_envs=n_envs,
|
||||
image_size=image_size,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
gym_kwargs=gym_kwargs,
|
||||
)
|
||||
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
|
||||
print(f" Built vec env | task={task} | n_envs={n_envs}")
|
||||
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robocasa_processor")
|
||||
class RoboCasaProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes RoboCasa observations into LeRobot format.
|
||||
|
||||
The RoboCasaEnv wrapper returns:
|
||||
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
|
||||
- ``observation.robot_state``: (B, 16) float32 proprioception
|
||||
|
||||
This step remaps them to:
|
||||
- ``observation.images.<cam_name>`` (unchanged tensor)
|
||||
- ``observation.state`` (robot_state renamed)
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation: dict) -> dict:
|
||||
processed = {}
|
||||
obs_prefix = OBS_PREFIX # "observation."
|
||||
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
# Already in the right place; pass through
|
||||
processed[key] = value
|
||||
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
|
||||
# Rename robot_state → observation.state
|
||||
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
|
||||
|
||||
return processed
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
return self._process_observation(observation)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for RoboCasa LeRobot integration.
|
||||
|
||||
Requires: robocasa installed + kitchen assets downloaded.
|
||||
Tests are skipped automatically if robocasa is not available.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# Skip entire module if robocasa is not installed or assets are missing
|
||||
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")
|
||||
|
||||
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs
|
||||
|
||||
# The 5 benchmark tasks (3 short + 2 long)
|
||||
BENCHMARK_TASKS = [
|
||||
"PickPlaceCounterToCabinet", # short
|
||||
"PrepareToast", # short
|
||||
"CoffeeSetupMug", # short
|
||||
"PrepareCoffee", # long
|
||||
"RestockPantry", # long
|
||||
]
|
||||
SHORT_TASKS = BENCHMARK_TASKS[:3]
|
||||
LONG_TASKS = BENCHMARK_TASKS[3:]
|
||||
|
||||
IMAGE_SIZE = 64 # small for fast tests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def single_env():
|
||||
"""Shared env instance for lightweight tests."""
|
||||
env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
|
||||
yield env
|
||||
env.close()
|
||||
|
||||
|
||||
class TestRoboCasaEnvSpaces:
|
||||
def test_action_space_is_flat_box(self, single_env):
|
||||
import gymnasium as gym
|
||||
|
||||
assert isinstance(single_env.action_space, gym.spaces.Box)
|
||||
assert single_env.action_space.shape == (ACTION_DIM,)
|
||||
assert single_env.action_space.dtype == np.float32
|
||||
|
||||
def test_action_bounds(self, single_env):
|
||||
assert np.all(single_env.action_space.low == -1.0)
|
||||
assert np.all(single_env.action_space.high == 1.0)
|
||||
|
||||
def test_observation_space_has_pixels_and_state(self, single_env):
|
||||
import gymnasium as gym
|
||||
|
||||
assert isinstance(single_env.observation_space, gym.spaces.Dict)
|
||||
assert "pixels" in single_env.observation_space.spaces
|
||||
assert "robot_state" in single_env.observation_space.spaces
|
||||
|
||||
def test_observation_space_cameras(self, single_env):
|
||||
pixels_space = single_env.observation_space["pixels"]
|
||||
expected_cams = set(CAM_KEY_TO_NAME.values())
|
||||
assert set(pixels_space.spaces.keys()) == expected_cams
|
||||
|
||||
def test_state_dim(self, single_env):
|
||||
state_space = single_env.observation_space["robot_state"]
|
||||
assert state_space.shape == (STATE_DIM,)
|
||||
|
||||
|
||||
class TestRoboCasaEnvReset:
|
||||
def test_reset_returns_obs_and_info(self, single_env):
|
||||
obs, info = single_env.reset()
|
||||
assert isinstance(obs, dict)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
def test_reset_obs_has_pixels(self, single_env):
|
||||
obs, _ = single_env.reset()
|
||||
assert "pixels" in obs
|
||||
for cam_name in CAM_KEY_TO_NAME.values():
|
||||
assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"
|
||||
|
||||
def test_reset_obs_image_shape(self, single_env):
|
||||
obs, _ = single_env.reset()
|
||||
for cam_name, img in obs["pixels"].items():
|
||||
assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
|
||||
assert img.dtype == np.uint8
|
||||
|
||||
def test_reset_obs_state_shape(self, single_env):
|
||||
obs, _ = single_env.reset()
|
||||
assert obs["robot_state"].shape == (STATE_DIM,)
|
||||
assert obs["robot_state"].dtype == np.float32
|
||||
|
||||
def test_reset_info_has_task(self, single_env):
|
||||
_, info = single_env.reset()
|
||||
assert "task" in info
|
||||
assert info["task"] == "PickPlaceCounterToCabinet"
|
||||
|
||||
|
||||
class TestRoboCasaEnvStep:
|
||||
def test_step_10_random_actions(self, single_env):
|
||||
single_env.reset()
|
||||
for _ in range(10):
|
||||
action = single_env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = single_env.step(action)
|
||||
assert obs["robot_state"].shape == (STATE_DIM,)
|
||||
assert isinstance(reward, float)
|
||||
assert isinstance(terminated, bool)
|
||||
assert isinstance(truncated, bool)
|
||||
|
||||
def test_step_bad_action_raises(self, single_env):
|
||||
single_env.reset()
|
||||
with pytest.raises(ValueError, match="Expected 1-D action"):
|
||||
single_env.step(np.zeros((2, ACTION_DIM)))
|
||||
|
||||
def test_step_info_has_is_success(self, single_env):
|
||||
single_env.reset()
|
||||
_, _, _, _, info = single_env.step(single_env.action_space.sample())
|
||||
assert "is_success" in info
|
||||
|
||||
|
||||
class TestRoboCasaConfig:
|
||||
def test_robocasa_env_config(self):
|
||||
from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
|
||||
assert cfg.type == "robocasa"
|
||||
# action feature
|
||||
assert "action" in cfg.features
|
||||
assert cfg.features["action"].shape == (ACTION_DIM,)
|
||||
# camera features
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
assert cam in cfg.features
|
||||
assert cfg.features[cam].type == FeatureType.VISUAL
|
||||
assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
|
||||
# state feature
|
||||
assert "robot_state" in cfg.features
|
||||
assert cfg.features["robot_state"].shape == (STATE_DIM,)
|
||||
|
||||
def test_make_env_config_robocasa(self):
|
||||
from lerobot.envs.factory import make_env_config
|
||||
cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
|
||||
assert cfg.type == "robocasa"
|
||||
|
||||
|
||||
class TestRoboCasaProcessorStep:
|
||||
def test_processor_remaps_keys(self):
|
||||
import torch
|
||||
from lerobot.processor.env_processor import RoboCasaProcessorStep
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
step = RoboCasaProcessorStep()
|
||||
B = 2
|
||||
obs = {
|
||||
f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
f"observation.robot_state": torch.zeros(B, STATE_DIM),
|
||||
}
|
||||
out = step._process_observation(obs)
|
||||
assert OBS_STATE in out
|
||||
assert out[OBS_STATE].dtype == torch.float32
|
||||
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
|
||||
assert f"{OBS_IMAGES}.{cam}" in out
|
||||
Reference in New Issue
Block a user