@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
    """Configuration for the RoboMME memory-augmented manipulation benchmark.

    RoboMME (ManiSkill/SAPIEN) ships 16 tasks across 4 suites: Counting,
    Permanence, Reference and Imitation. Environments are built through
    ``BenchmarkEnvBuilder`` from the ``robomme`` package, which is declared as
    the optional ``robomme`` extra (Linux only) in ``pyproject.toml``.
    """

    # RoboMME task name, forwarded to the environment factory as ``task``.
    task: str = "PickXtimes"
    fps: int = 10
    # Forwarded to the environment wrapper as its ``max_steps`` budget.
    episode_length: int = 300
    # Action parameterisation understood by RoboMME (e.g. "joint_angle").
    action_space: str = "joint_angle"
    # Dataset split the benchmark episodes are drawn from.
    dataset_split: str = "test"
    # Episode indices to evaluate; the factory falls back to ``[0]`` when None.
    task_ids: list[int] | None = None
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
            "front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            "wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            ACTION: ACTION,
            "front_rgb": f"{OBS_IMAGES}.front",
            "wrist_rgb": f"{OBS_IMAGES}.wrist",
            OBS_STATE: OBS_STATE,
        }
    )

    def __post_init__(self) -> None:
        """Keep the advertised action shape consistent with ``action_space``.

        The RoboMME Gymnasium wrapper uses an 8-dimensional action for
        ``"joint_angle"`` and a 7-dimensional one for every other action
        space. The default ``features`` hard-code 8, so for other action
        spaces the default is narrowed to 7. A shape the caller customised
        (anything other than the default ``(8,)``) is left untouched.
        """
        # Chain to a parent hook if the base class defines one.
        parent_hook = getattr(super(), "__post_init__", None)
        if callable(parent_hook):
            parent_hook()

        if self.action_space == "joint_angle":
            return
        action_feature = self.features.get(ACTION)
        if action_feature is not None and tuple(action_feature.shape) == (8,):
            self.features[ACTION] = PolicyFeature(type=action_feature.type, shape=(7,))

    @property
    def gym_kwargs(self) -> dict:
        """Return the keyword arguments describing how the env is built."""
        return {
            "action_space": self.action_space,
            "dataset": self.dataset_split,
        }
"""RoboMME environment wrapper for LeRobot evaluation.

Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.

RoboMME tasks:
    Counting:   BinFill, PickXtimes, SwingXtimes, StopCube
    Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
    Reference:  PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
    Imitation:  MoveCube, InsertPeg, PatternLock, RouteStick

Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""

from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

ROBOMME_TASKS = [
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]


class RoboMMEGymEnv(gym.Env):
    """Thin Gymnasium wrapper around a single RoboMME episode env.

    A fresh RoboMME environment is built for the configured episode on every
    ``reset()``; the previous one is closed first so simulators do not pile up
    across episodes. Observations are reduced to the latest front/wrist RGB
    frame plus an 8-dimensional state (7 joint values + 1 gripper value).

    Args:
        task: RoboMME task name, passed to the builder as ``env_id``.
        action_space_type: Action parameterisation passed to the builder.
        dataset: Dataset split passed to the builder.
        episode_idx: Index of the episode that ``reset()`` instantiates.
        max_steps: Step budget passed to the builder and to each episode env.
    """

    metadata = {"render_modes": ["rgb_array"]}

    def __init__(
        self,
        task: str = "PickXtimes",
        action_space_type: str = "joint_angle",
        dataset: str = "test",
        episode_idx: int = 0,
        max_steps: int = 300,
    ):
        super().__init__()
        # Imported here so this module stays importable without ``robomme``.
        from robomme.env_record_wrapper import BenchmarkEnvBuilder

        self._task = task
        self._action_space_type = action_space_type
        self._dataset = dataset
        self._episode_idx = episode_idx
        self._max_steps = max_steps

        self._builder = BenchmarkEnvBuilder(
            env_id=task,
            dataset=dataset,
            action_space=action_space_type,
            gui_render=False,
            max_steps=max_steps,
        )
        self._env = None  # created lazily by reset()
        self._last_frame: np.ndarray | None = None  # served by render()

        action_dim = 8 if action_space_type == "joint_angle" else 7
        # NOTE(review): the [-1, 1] bounds are only declared here, nothing in
        # this wrapper clips to them — confirm they match RoboMME's action
        # ranges.
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
        self.observation_space = spaces.Dict({
            "front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
        })

    def reset(self, *, seed=None, options=None):
        """Build a fresh env for the configured episode and return ``(obs, info)``."""
        super().reset(seed=seed)
        # Release the previous episode's simulator before building a new one;
        # otherwise every reset leaks an environment.
        self._close_episode_env()
        self._env = self._builder.make_env_for_episode(
            episode_idx=self._episode_idx, max_steps=self._max_steps,
        )
        obs, info = self._env.reset()
        return self._convert_obs(obs), self._convert_info(info)

    def step(self, action):
        """Advance the episode by one action.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` with plain Python
            scalars; ``info["is_success"]`` is True when the RoboMME status
            is ``"success"``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._env is None:
            raise RuntimeError("RoboMMEGymEnv.step() called before reset().")

        obs, reward, terminated, truncated, info = self._env.step(action)

        step_info = self._convert_info(info)
        step_info["is_success"] = step_info["status"] == "success"

        return (
            self._convert_obs(obs),
            float(reward),
            self._as_bool(terminated),
            self._as_bool(truncated),
            step_info,
        )

    def render(self):
        """Return the most recent front-camera frame, or ``None`` before ``reset()``."""
        return self._last_frame

    def close(self):
        """Close the current episode env, if any."""
        self._close_episode_env()
        super().close()

    def _close_episode_env(self) -> None:
        """Close and forget the wrapped episode env (no-op when there is none)."""
        env, self._env = self._env, None
        if env is None:
            return
        close_fn = getattr(env, "close", None)
        if callable(close_fn):
            close_fn()

    @staticmethod
    def _as_bool(flag: Any) -> bool:
        """Convert a Python/array-like flag (anything with ``.item()``) to ``bool``."""
        return bool(flag.item()) if hasattr(flag, "item") else bool(flag)

    @staticmethod
    def _latest(value: Any) -> Any:
        """Return the last element when ``value`` is a list, else ``value`` itself."""
        return value[-1] if isinstance(value, list) else value

    def _convert_obs(self, obs: dict) -> dict:
        """Reduce a RoboMME observation to the latest frames and an 8-D state."""
        front_rgb = np.asarray(self._latest(obs["front_rgb_list"]), dtype=np.uint8)
        wrist_rgb = np.asarray(self._latest(obs["wrist_rgb_list"]), dtype=np.uint8)
        joint = np.asarray(self._latest(obs["joint_state_list"]), dtype=np.float32).flatten()[:7]
        gripper = np.asarray(self._latest(obs["gripper_state_list"]), dtype=np.float32).flatten()[:1]

        self._last_frame = front_rgb

        return {
            "front_rgb": front_rgb,
            "wrist_rgb": wrist_rgb,
            "state": np.concatenate([joint, gripper]),
        }

    def _convert_info(self, info: dict) -> dict:
        """Keep only the ``status`` and ``task_goal`` entries, with defaults."""
        return {
            "status": info.get("status", "ongoing"),
            "task_goal": info.get("task_goal", ""),
        }


def create_robomme_envs(
    task: str,
    n_envs: int = 1,
    action_space_type: str = "joint_angle",
    dataset: str = "test",
    episode_length: int = 300,
    task_ids: list[int] | None = None,
    env_cls: Callable[..., gym.vector.VectorEnv] | None = None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
    """Create vectorized RoboMME environments for evaluation.

    One vector env is created per entry of ``task_ids``; each entry is used as
    the RoboMME episode index, and all ``n_envs`` sub-environments of a vector
    env run that same episode.

    Args:
        task: RoboMME task name.
        n_envs: Number of sub-environments per vector env (must be >= 1).
        action_space_type: Action parameterisation passed to each env.
        dataset: Dataset split passed to each env.
        episode_length: Step budget passed to each env as ``max_steps``.
        task_ids: Episode indices to evaluate; defaults to ``[0]``.
        env_cls: Vector env constructor; defaults to ``gym.vector.SyncVectorEnv``.

    Returns:
        ``{"robomme": {task_id: VectorEnv}}``, matching lerobot's expected format.

    Raises:
        ValueError: If ``n_envs`` is smaller than 1.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}.")

    if env_cls is None:
        env_cls = gym.vector.SyncVectorEnv

    episode_ids = [0] if task_ids is None else task_ids

    envs_by_task: dict[int, gym.vector.VectorEnv] = {}
    for task_id in episode_ids:
        # partial binds the episode index eagerly, so there is no late-binding
        # closure to worry about.
        make_one = partial(
            RoboMMEGymEnv,
            task=task,
            action_space_type=action_space_type,
            dataset=dataset,
            episode_idx=task_id,
            max_steps=episode_length,
        )
        envs_by_task[task_id] = env_cls(
            [make_one] * n_envs,
            autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
        )

    return {"robomme": envs_by_task}