mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(envs): add RoboMME memory-augmented manipulation benchmark
- RoboMMEEnv config with 16 tasks across 4 suites (Counting, Permanence, Reference, Imitation) - Gymnasium wrapper around BenchmarkEnvBuilder (robomme.py) - Environment factory wiring for env_type="robomme" - robomme optional dependency in pyproject.toml Made-with: Cursor
This commit is contained in:
@@ -180,6 +180,9 @@ libero_plus = [
|
|||||||
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||||
"lerobot[scipy-dep]",
|
"lerobot[scipy-dep]",
|
||||||
]
|
]
|
||||||
|
robomme = [
|
||||||
|
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
||||||
|
]
|
||||||
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# All
|
# All
|
||||||
|
|||||||
@@ -405,6 +405,46 @@ class RoboCasaEnv(EnvConfig):
|
|||||||
return {"split": self.split}
|
return {"split": self.split}
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("robomme")
|
||||||
|
@dataclass
|
||||||
|
class RoboMMEEnv(EnvConfig):
|
||||||
|
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
|
||||||
|
|
||||||
|
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
|
||||||
|
Uses BenchmarkEnvBuilder from the robomme package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str = "PickXtimes"
|
||||||
|
fps: int = 10
|
||||||
|
episode_length: int = 300
|
||||||
|
action_space: str = "joint_angle"
|
||||||
|
dataset_split: str = "test"
|
||||||
|
task_ids: list[int] | None = None
|
||||||
|
features: dict[str, PolicyFeature] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
|
||||||
|
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||||
|
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||||
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features_map: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
ACTION: ACTION,
|
||||||
|
"front_rgb": f"{OBS_IMAGES}.front",
|
||||||
|
"wrist_rgb": f"{OBS_IMAGES}.wrist",
|
||||||
|
OBS_STATE: OBS_STATE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self) -> dict:
|
||||||
|
return {
|
||||||
|
"action_space": self.action_space,
|
||||||
|
"dataset": self.dataset_split,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("metaworld")
|
@EnvConfig.register_subclass("metaworld")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MetaworldEnv(EnvConfig):
|
class MetaworldEnv(EnvConfig):
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from lerobot.envs.configs import (
|
|||||||
LiberoPlusEnv,
|
LiberoPlusEnv,
|
||||||
PushtEnv,
|
PushtEnv,
|
||||||
RoboCasaEnv,
|
RoboCasaEnv,
|
||||||
|
RoboMMEEnv,
|
||||||
)
|
)
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
@@ -48,6 +49,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
return LiberoPlusEnv(**kwargs)
|
return LiberoPlusEnv(**kwargs)
|
||||||
elif env_type == "robocasa":
|
elif env_type == "robocasa":
|
||||||
return RoboCasaEnv(**kwargs)
|
return RoboCasaEnv(**kwargs)
|
||||||
|
elif env_type == "robomme":
|
||||||
|
return RoboMMEEnv(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||||
|
|
||||||
@@ -212,6 +215,19 @@ def make_env(
|
|||||||
env_cls=env_cls,
|
env_cls=env_cls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif "robomme" in cfg.type:
|
||||||
|
from lerobot.envs.robomme import create_robomme_envs
|
||||||
|
|
||||||
|
return create_robomme_envs(
|
||||||
|
task=cfg.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
action_space_type=cfg.action_space,
|
||||||
|
dataset=cfg.dataset_split,
|
||||||
|
episode_length=cfg.episode_length,
|
||||||
|
task_ids=cfg.task_ids,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
elif "metaworld" in cfg.type:
|
elif "metaworld" in cfg.type:
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""RoboMME environment wrapper for LeRobot evaluation.
|
||||||
|
|
||||||
|
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
|
||||||
|
``VectorEnv`` suitable for ``lerobot_eval``.
|
||||||
|
|
||||||
|
RoboMME tasks:
|
||||||
|
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
|
||||||
|
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
|
||||||
|
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
|
||||||
|
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
|
||||||
|
|
||||||
|
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
ROBOMME_TASKS = [
|
||||||
|
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
|
||||||
|
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
|
||||||
|
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
|
||||||
|
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RoboMMEGymEnv(gym.Env):
|
||||||
|
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
|
||||||
|
|
||||||
|
metadata = {"render_modes": ["rgb_array"]}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task: str = "PickXtimes",
|
||||||
|
action_space_type: str = "joint_angle",
|
||||||
|
dataset: str = "test",
|
||||||
|
episode_idx: int = 0,
|
||||||
|
max_steps: int = 300,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
||||||
|
|
||||||
|
self._task = task
|
||||||
|
self._action_space_type = action_space_type
|
||||||
|
self._dataset = dataset
|
||||||
|
self._episode_idx = episode_idx
|
||||||
|
self._max_steps = max_steps
|
||||||
|
|
||||||
|
self._builder = BenchmarkEnvBuilder(
|
||||||
|
env_id=task,
|
||||||
|
dataset=dataset,
|
||||||
|
action_space=action_space_type,
|
||||||
|
gui_render=False,
|
||||||
|
max_steps=max_steps,
|
||||||
|
)
|
||||||
|
self._env = None
|
||||||
|
|
||||||
|
action_dim = 8 if action_space_type == "joint_angle" else 7
|
||||||
|
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
|
||||||
|
self.observation_space = spaces.Dict({
|
||||||
|
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||||
|
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||||
|
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
|
||||||
|
})
|
||||||
|
|
||||||
|
def reset(self, *, seed=None, options=None):
|
||||||
|
super().reset(seed=seed)
|
||||||
|
self._env = self._builder.make_env_for_episode(
|
||||||
|
episode_idx=self._episode_idx, max_steps=self._max_steps,
|
||||||
|
)
|
||||||
|
obs, info = self._env.reset()
|
||||||
|
return self._convert_obs(obs), self._convert_info(info)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self._env.step(action)
|
||||||
|
|
||||||
|
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
|
||||||
|
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
|
||||||
|
|
||||||
|
status = info.get("status", "ongoing")
|
||||||
|
is_success = status == "success"
|
||||||
|
conv_info = self._convert_info(info)
|
||||||
|
conv_info["is_success"] = is_success
|
||||||
|
|
||||||
|
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
|
||||||
|
|
||||||
|
def _convert_obs(self, obs: dict) -> dict:
|
||||||
|
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
|
||||||
|
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
|
||||||
|
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
|
||||||
|
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
|
||||||
|
|
||||||
|
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
|
||||||
|
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
|
||||||
|
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
|
||||||
|
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
|
||||||
|
state = np.concatenate([joint, gripper])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"front_rgb": front_rgb,
|
||||||
|
"wrist_rgb": wrist_rgb,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_info(self, info: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"status": info.get("status", "ongoing"),
|
||||||
|
"task_goal": info.get("task_goal", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_robomme_envs(
|
||||||
|
task: str,
|
||||||
|
n_envs: int = 1,
|
||||||
|
action_space_type: str = "joint_angle",
|
||||||
|
dataset: str = "test",
|
||||||
|
episode_length: int = 300,
|
||||||
|
task_ids: list[int] | None = None,
|
||||||
|
env_cls=None,
|
||||||
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
|
"""Create vectorized RoboMME environments for evaluation.
|
||||||
|
|
||||||
|
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||||
|
"""
|
||||||
|
if env_cls is None:
|
||||||
|
env_cls = gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
if task_ids is None:
|
||||||
|
task_ids = [0]
|
||||||
|
|
||||||
|
suite_name = "robomme"
|
||||||
|
envs_by_task = {}
|
||||||
|
|
||||||
|
for task_id in task_ids:
|
||||||
|
def _make_one(ep_idx=task_id):
|
||||||
|
return RoboMMEGymEnv(
|
||||||
|
task=task,
|
||||||
|
action_space_type=action_space_type,
|
||||||
|
dataset=dataset,
|
||||||
|
episode_idx=ep_idx,
|
||||||
|
max_steps=episode_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
vec = env_cls(
|
||||||
|
[_make_one for _ in range(n_envs)],
|
||||||
|
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
|
||||||
|
)
|
||||||
|
envs_by_task[task_id] = vec
|
||||||
|
|
||||||
|
return {suite_name: envs_by_task}
|
||||||
Reference in New Issue
Block a user