From 4fe5c3ab7029a836813557eda4edcd460f44d749 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 16 Sep 2025 12:05:32 +0200 Subject: [PATCH] Add libero (#1950) * add libero * backup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add factory * Add LIBERO as a submodule * add changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add multitask * remove photos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug remove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix video paths and train.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix renaming issues with cams * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update .gitignore * final refactor/fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add safethread support * ad * Add dep (#4) * Add 'libero' dependencies to pyproject.toml * Add Git dependencies for egl_probe and LIBERO * Update libero-requirements.txt * add future dep * update bash * quick fix * remove step1 * cleanup (#5) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup 2 * improve install * Delete libero-requirements.txt * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add docs for eval * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove brkpt * fix docs * update docs/script * update doc * skip test warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * hotfix: flip actions * add train * new things * More things * add new changes * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unces * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * new changes * iterate on review * remove files * factor * update installation * doc title * make it reproducible * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove files * update tests * update doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove sh files * add gym --------- Signed-off-by: Jade Choghari Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jade Choghari (jchoghar) Co-authored-by: Adil Zouitine --- docs/source/_toctree.yml | 2 + docs/source/libero.mdx | 126 ++++++++++ pyproject.toml | 19 +- src/lerobot/envs/configs.py | 53 +++++ src/lerobot/envs/factory.py | 49 ++-- src/lerobot/envs/libero.py | 399 ++++++++++++++++++++++++++++++++ src/lerobot/envs/utils.py | 41 +++- src/lerobot/policies/factory.py | 4 - src/lerobot/scripts/eval.py | 253 ++++++++++++++++++-- src/lerobot/scripts/train.py | 32 ++- tests/envs/test_envs.py | 5 +- tests/policies/test_policies.py | 9 +- 12 files changed, 939 insertions(+), 53 deletions(-) create mode 100644 docs/source/libero.mdx create mode 100644 src/lerobot/envs/libero.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5f5a509c7..2df5fc5bd 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,6 +19,8 @@ title: Train RL in Simulation - local: async title: Use Async Inference + - local: libero + title: Using Libero - local: porting_datasets_v3 title: Porting Large Datasets title: "Tutorials" diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx new file mode 100644 index 000000000..488c02ce0 --- /dev/null +++ b/docs/source/libero.mdx @@ -0,0 +1,126 @@ +# LIBERO + +**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots wonโ€™t just be pretrained once in a factory, theyโ€™ll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and itโ€™s a key step toward building robots that become truly personalized helpers. + +- ๐Ÿ“„ [LIBERO paper](https://arxiv.org/abs/2306.03310) +- ๐Ÿ’ป [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) + +To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each otherโ€™s work. + +LIBERO includes **five task suites**: + +- **LIBERO-Spatial (`libero_spatial`)** โ€“ tasks that require reasoning about spatial relations. +- **LIBERO-Object (`libero_object`)** โ€“ tasks centered on manipulating different objects. +- **LIBERO-Goal (`libero_goal`)** โ€“ goal-conditioned tasks where the robot must adapt to changing targets. +- **LIBERO-90 (`libero_90`)** โ€“ 90 short-horizon tasks from the LIBERO-100 collection. +- **LIBERO-Long (`libero_10`)** โ€“ 10 long-horizon tasks from the LIBERO-100 collection. + +Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms. + +![An overview of the LIBERO benchmark](https://libero-project.github.io/assets/img/libero/fig1.png) + +## Evaluating with LIBERO + +At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model. + +LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag. + +To Install LIBERO, after following LeRobot official instructions, just do: +`pip install -e ".[libero]"` + +### Single-suite evaluation + +Evaluate a policy on one LIBERO suite: + +```bash +python src/lerobot/scripts/eval.py \ + --policy.path="your-policy-id" \ + --env.type=libero \ + --env.task=libero_object \ + --eval.batch_size=2 \ + --eval.n_episodes=3 +``` + +- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.). +- `--eval.batch_size` controls how many environments run in parallel. +- `--eval.n_episodes` sets how many episodes to run in total. + +--- + +### Multi-suite evaluation + +Benchmark a policy across multiple suites at once: + +```bash +python src/lerobot/scripts/eval.py \ + --policy.path="your-policy-id" \ + --env.type=libero \ + --env.task=libero_object,libero_spatial \ + --eval.batch_size=1 \ + --eval.n_episodes=2 +``` + +- Pass a comma-separated list to `--env.task` for multi-suite evaluation. + +### Policy inputs and outputs + +When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**: + +- **Observations** + - `observation.state` โ€“ proprioceptive features (agent state). + - `observation.images.image` โ€“ main camera view (`agentview_image`). + - `observation.images.image2` โ€“ wrist camera view (`robot0_eye_in_hand_image`). + + โš ๏ธ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation. + If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer. + This will be fixed with the upcoming Pipeline PR. + +- **Actions** + - Continuous control values in a `Box(-1, 1, shape=(7,))` space. + +We also provide a notebook for quick testing: +Training with LIBERO + +## Training with LIBERO + +When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention. + +The environment expects: + +- `observation.state` โ†’ 8-dim agent state +- `observation.images.image` โ†’ main camera (`agentview_image`) +- `observation.images.image2` โ†’ wrist camera (`robot0_eye_in_hand_image`) + +โš ๏ธ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code. +To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation: +๐Ÿ‘‰ [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero) + +For reference, here is the **original dataset** published by Physical Intelligence: +๐Ÿ‘‰ [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero) + +--- + +### Example training command + +```bash +python src/lerobot/scripts/train.py \ + --policy.type=smolvla \ + --policy.repo_id=${HF_USER}/libero-test \ + --dataset.repo_id=jadechoghari/smol-libero3 \ + --env.type=libero \ + --env.task=libero_10 \ + --output_dir=./outputs/ \ + --steps=100000 \ + --batch_size=4 \ + --eval.batch_size=1 \ + --eval.n_episodes=1 \ + --eval_freq=1000 \ +``` + +--- + +### Note on rendering + +LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation: + +- `export MUJOCO_GL=egl` โ†’ for headless servers (e.g. HPC, cloud) diff --git a/pyproject.toml b/pyproject.toml index 7241a78f9..a1d41b47d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,7 +134,21 @@ video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] aloha = ["gym-aloha>=0.1.1"] pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead xarm = ["gym-xarm>=0.1.1"] - +libero = [ + "hydra-core>=1.2,<1.4", + "easydict>=1.9", + "lerobot[transformers-dep]", + "robomimic==0.2.0", + "thop>=0.1.0.post2206102148", + "robosuite==1.4.0", + "bddl==1.0.1", + "matplotlib>=3.5.3", + "cloudpickle>=2.0.0", + "gym>=0.25,<0.27", + "future>=0.18.3", + "egl_probe @ git+https://github.com/jadechoghari/egl_probe.git#egg=egl_probe", + "libero @ git+https://github.com/jadechoghari/LIBERO.git@main#egg=libero", +] # All all = [ "lerobot[dynamixel]", @@ -153,7 +167,8 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]" + "lerobot[xarm]", + "lerobot[libero]" ] [project.scripts] diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 35797c6ed..1c4ede961 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -30,6 +30,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): fps: int = 30 features: dict[str, PolicyFeature] = field(default_factory=dict) features_map: dict[str, str] = field(default_factory=dict) + max_parallel_tasks: int = 5 @property def type(self) -> str: @@ -271,3 +272,55 @@ class HILEnvConfig(EnvConfig): "use_gamepad": self.use_gamepad, "gripper_penalty": self.gripper_penalty, } + + +@EnvConfig.register_subclass("libero") +@dataclass +class LiberoEnv(EnvConfig): + task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. + fps: int = 30 + episode_length: int = 520 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + camera_name: str = "agentview_image,robot0_eye_in_hand_image" + init_states: bool = True + camera_name_mapping: dict[str, str] | None = (None,) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_STATE, + "pixels/agentview_image": f"{OBS_IMAGES}.image", + "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["pixels/agentview_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) + self.features["pixels/agentview_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index dc6d96d61..e4031b9a5 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -29,11 +29,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return XarmEnv(**kwargs) elif env_type == "hil": return HILEnvConfig(**kwargs) + elif env_type == "libero": + return LiberoEnv(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") -def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None: +def make_env( + cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False +) -> dict[str, dict[int, gym.vector.VectorEnv]]: """Makes a gym vector environment according to the config. Args: @@ -47,25 +51,44 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g ModuleNotFoundError: If the requested env package is not installed Returns: - gym.vector.VectorEnv: The parallelized gym.env instance. + dict[str, dict[int, gym.vector.VectorEnv]]: + A mapping from suite name to indexed vectorized environments. + - For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id. + - For single-task environments: a single suite entry (cfg.type) with task_id=0. + """ if n_envs < 1: - raise ValueError("`n_envs must be at least 1") + raise ValueError("`n_envs` must be at least 1") + + env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv + + if "libero" in cfg.type: + from lerobot.envs.libero import create_libero_envs + + return create_libero_envs( + task=cfg.task, + n_envs=n_envs, + camera_name=cfg.camera_name, + init_states=cfg.init_states, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) package_name = f"gym_{cfg.type}" - try: importlib.import_module(package_name) except ModuleNotFoundError as e: - print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") - raise e + raise ModuleNotFoundError( + f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"' + ) from e gym_handle = f"{package_name}/{cfg.task}" - # batched version of the env that returns an observation of shape (b, c) - env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv - env = env_cls( - [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] - ) + def _make_one(): + return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {})) - return env + vec = env_cls([_make_one for _ in range(n_envs)]) + + # normalize to {suite: {task_id: vec_env}} for consistency + suite_name = cfg.type # e.g., "pusht", "aloha" + return {suite_name: {0: vec}} diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py new file mode 100644 index 000000000..7ec9b34a2 --- /dev/null +++ b/src/lerobot/envs/libero.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import os +from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from pathlib import Path +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium import spaces +from libero.libero import benchmark, get_libero_path +from libero.libero.envs import OffScreenRenderEnv + + +def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: + """Normalize camera_name into a non-empty list of strings.""" + if isinstance(camera_name, str): + cams = [c.strip() for c in camera_name.split(",") if c.strip()] + elif isinstance(camera_name, (list, tuple)): + cams = [str(c).strip() for c in camera_name if str(c).strip()] + else: + raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") + if not cams: + raise ValueError("camera_name resolved to an empty list.") + return cams + + +def _get_suite(name: str) -> Any: + """Instantiate a LIBERO suite by name with clear validation.""" + bench = benchmark.get_benchmark_dict() + if name not in bench: + raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}") + suite = bench[name]() + if not getattr(suite, "tasks", None): + raise ValueError(f"Suite '{name}' has no tasks.") + return suite + + +def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]: + """Validate/normalize task ids. If None โ†’ all tasks.""" + if task_ids is None: + return list(range(total_tasks)) + ids = sorted({int(t) for t in task_ids}) + for t in ids: + if t < 0 or t >= total_tasks: + raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].") + return ids + + +def quat2axisangle(quat: np.ndarray) -> np.ndarray: + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + + Converts quaternion to axis-angle format. + Returns a unit vector direction scaled by its angle in radians. + + Args: + quat (np.array): (x,y,z,w) vec4 float angles + + Returns: + np.array: (ax,ay,az) axis-angle exponential coordinates + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) / den + + +def get_task_init_states(task_suite: Any, i: int) -> np.ndarray: + init_states_path = ( + Path(get_libero_path("init_states")) + / task_suite.tasks[i].problem_folder + / task_suite.tasks[i].init_states_file + ) + init_states = torch.load(init_states_path, weights_only=False) # nosec B614 + return init_states + + +def get_libero_dummy_action(): + """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" + return [0, 0, 0, 0, 0, 0, -1] + + +OBS_STATE_DIM = 8 +ACTION_DIM = 7 +TASK_SUITE_MAX_STEPS: dict[str, int] = { + "libero_spatial": 280, # longest training demo has 193 steps + "libero_object": 280, # longest training demo has 254 steps + "libero_goal": 300, # longest training demo has 270 steps + "libero_10": 520, # longest training demo has 505 steps + "libero_90": 400, # longest training demo has 373 steps +} + + +class LiberoEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task_suite: Any, + task_id: int, + task_suite_name: str, + camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", + obs_type: str = "pixels", + render_mode: str = "rgb_array", + observation_width: int = 256, + observation_height: int = 256, + visualization_width: int = 640, + visualization_height: int = 480, + init_states: bool = True, + episode_index: int = 0, + camera_name_mapping: dict[str, str] | None = None, + num_steps_wait: int = 10, + ): + super().__init__() + self.task_id = task_id + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.init_states = init_states + self.camera_name = camera_name.split( + "," + ) # agentview_image (main) or robot0_eye_in_hand_image (wrist) + + # Map raw camera names to "image1" and "image2". + # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`, + # following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`). + # This ensures the policy consistently receives observations in the + # expected format regardless of the original camera naming. + if camera_name_mapping is None: + camera_name_mapping = { + "agentview_image": "image", + "robot0_eye_in_hand_image": "image2", + } + self.camera_name_mapping = camera_name_mapping + self.num_steps_wait = num_steps_wait + self.episode_index = episode_index + # Load once and keep + self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None + self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + + self._env = self._make_envs_task(task_suite, self.task_id) + default_steps = 500 + self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) + + images = {} + for cam in self.camera_name: + images[self.camera_name_mapping[cam]] = spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + + if self.obs_type == "state": + raise NotImplementedError( + "The 'state' observation type is not supported in LiberoEnv. " + "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')." + ) + + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_STATE_DIM,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32) + + def render(self): + raw_obs = self._env.env._get_observations() + image = self._format_raw_obs(raw_obs)["pixels"]["image"] + return image + + def _make_envs_task(self, task_suite: Any, task_id: int = 0): + task = task_suite.get_task(task_id) + self.task = task.name + self.task_description = task.language + task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) + + env_args = { + "bddl_file_name": task_bddl_file, + "camera_heights": self.observation_height, + "camera_widths": self.observation_width, + } + env = OffScreenRenderEnv(**env_args) + env.reset() + return env + + def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]: + images = {} + for camera_name in self.camera_name: + image = raw_obs[camera_name] + image = image[::-1, ::-1] # rotate 180 degrees + images[self.camera_name_mapping[camera_name]] = image + state = np.concatenate( + ( + raw_obs["robot0_eef_pos"], + quat2axisangle(raw_obs["robot0_eef_quat"]), + raw_obs["robot0_gripper_qpos"], + ) + ) + agent_pos = state + if self.obs_type == "pixels": + return {"pixels": images.copy()} + if self.obs_type == "pixels_agent_pos": + return { + "pixels": images.copy(), + "agent_pos": agent_pos, + } + raise NotImplementedError( + f"The observation type '{self.obs_type}' is not supported in LiberoEnv. " + "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')." + ) + + def reset(self, seed=None, **kwargs): + super().reset(seed=seed) + self._env.seed(seed) + if self.init_states and self._init_states is not None: + self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.reset() + + # After reset, objects may be unstable (slightly floating, intersecting, etc.). + # Step the simulator with a no-op action for a few frames so everything settles. + # Increasing this value can improve determinism and reproducibility across resets. + for _ in range(self.num_steps_wait): + raw_obs, _, _, _ = self._env.step(get_libero_dummy_action()) + observation = self._format_raw_obs(raw_obs) + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], dict[str, Any]]: + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + raw_obs, reward, done, info = self._env.step(action) + + is_success = self._env.check_success() + terminated = done or is_success + info["is_success"] = done # is_success + + observation = self._format_raw_obs(raw_obs) + if done: + self.reset() + info.update( + { + "task": self.task, + "task_id": self.task_id, + "done": done, + "is_success": is_success, + } + ) + truncated = False + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +def _make_env_fns( + *, + suite, + suite_name: str, + task_id: int, + n_envs: int, + camera_names: list[str], + init_states: bool, + gym_kwargs: Mapping[str, Any], +) -> list[Callable[[], LiberoEnv]]: + """Build n_envs factory callables for a single (suite, task_id).""" + joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string + + def _make_env(episode_index: int, **kwargs) -> LiberoEnv: + local_kwargs = dict(kwargs) + return LiberoEnv( + task_suite=suite, + task_id=task_id, + task_suite_name=suite_name, + camera_name=joined_cams, + init_states=init_states, + episode_index=episode_index, + **local_kwargs, + ) + + fns: list[Callable[[], LiberoEnv]] = [] + for episode_index in range(n_envs): + fns.append(partial(_make_env, episode_index, **gym_kwargs)) + return fns + + +# ---- Main API ---------------------------------------------------------------- + + +def create_libero_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", + init_states: bool = True, + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized LIBERO environments with a consistent return shape. + + Returns: + dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single suite or a comma-separated list of suites. + - You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite. + """ + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks + + camera_names = _parse_camera_names(camera_name) + suite_names = [s.strip() for s in str(task).split(",") if s.strip()] + if not suite_names: + raise ValueError("`task` must contain at least one LIBERO suite name.") + + print( + f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}" + ) + if task_ids_filter is not None: + print(f"Restricting to task_ids={task_ids_filter}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for suite_name in suite_names: + suite = _get_suite(suite_name) + total = len(suite.tasks) + selected = _select_task_ids(total, task_ids_filter) + + if not selected: + raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") + + for tid in selected: + fns = _make_env_fns( + suite=suite, + suite_name=suite_name, + task_id=tid, + n_envs=n_envs, + camera_names=camera_names, + init_states=init_states, + gym_kwargs=gym_kwargs, + ) + out[suite_name][tid] = env_cls(fns) + print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") + + # return plain dicts for predictability + return {suite: dict(task_map) for suite, task_map in out.items()} diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 00676a011..063fda645 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings +from collections.abc import Mapping, Sequence +from functools import singledispatch from typing import Any import einops @@ -97,7 +99,6 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: policy_key = env_cfg.features_map[key] policy_features[policy_key] = feature - return policy_features @@ -134,3 +135,41 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic num_envs = observation[list(observation.keys())[0]].shape[0] observation["task"] = ["" for _ in range(num_envs)] return observation + + +def _close_single_env(env: Any) -> None: + try: + env.close() + except Exception as exc: + print(f"Exception while closing env {env}: {exc}") + + +@singledispatch +def close_envs(obj: Any) -> None: + """Default: raise if the type is not recognized.""" + raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}") + + +@close_envs.register +def _(env: Mapping) -> None: + for v in env.values(): + if isinstance(v, Mapping): + close_envs(v) + elif hasattr(v, "close"): + _close_single_env(v) + + +@close_envs.register +def _(envs: Sequence) -> None: + if isinstance(envs, (str, bytes)): + return + for v in envs: + if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)): + close_envs(v) + elif hasattr(v, "close"): + _close_single_env(v) + + +@close_envs.register +def _(env: gym.Env) -> None: + _close_single_env(env) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ef56bdb61..c3ae9cd54 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -156,7 +156,6 @@ def make_policy( "by default without stats from a dataset." ) features = env_to_policy_features(env_cfg) - cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg @@ -169,10 +168,7 @@ def make_policy( else: # Make a fresh policy. policy = policy_cls(**kwargs) - policy.to(cfg.device) assert isinstance(policy, nn.Module) - # policy = torch.compile(policy, mode="reduce-overhead") - return policy diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 13d30c686..0edd292d5 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -46,16 +46,19 @@ Note that in both examples, the repo/folder should contain at least `config.json You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py """ +import concurrent.futures as cf import json import logging import threading import time -from collections.abc import Callable +from collections import defaultdict +from collections.abc import Callable, Iterator from contextlib import nullcontext from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat +from typing import TypedDict import einops import gymnasium as gym @@ -68,7 +71,12 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env -from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation +from lerobot.envs.utils import ( + add_envs_task, + check_env_attributes_and_types, + close_envs, + preprocess_observation, +) from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters @@ -145,7 +153,7 @@ def rollout( leave=False, ) check_env_attributes_and_types(env) - while not np.all(done): + while not np.all(done) and step < max_steps: # Numpy array to tensor and changing dictionary keys to LeRobot policy format. observation = preprocess_observation(observation) if return_observations: @@ -158,10 +166,8 @@ def rollout( # Infer "task" from attributes of environments. # TODO: works with SyncVectorEnv but not AsyncVectorEnv observation = add_envs_task(env, observation) - with torch.inference_mode(): action = policy.select_action(observation) - # Convert to CPU / numpy. action = action.to("cpu").numpy() assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" @@ -179,7 +185,12 @@ def rollout( successes = [False] * env.num_envs # Keep track of which environments are done so far. + # Mark the episode as done if we reach the maximum step limit. + # This ensures that the rollout always terminates cleanly at `max_steps`, + # and allows logging/saving (e.g., videos) to be triggered consistently. done = terminated | truncated | done + if step + 1 == max_steps: + done = np.ones_like(done, dtype=bool) all_actions.append(torch.from_numpy(action)) all_rewards.append(torch.from_numpy(reward)) @@ -402,7 +413,6 @@ def eval_policy( "eval_ep_s": (time.time() - start) / n_episodes, }, } - if return_episode_data: info["episodes"] = episode_data @@ -463,7 +473,6 @@ def eval_main(cfg: EvalPipelineConfig): # Check device is available device = get_safe_torch_device(cfg.policy.device, log=True) - torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) @@ -471,40 +480,246 @@ def eval_main(cfg: EvalPipelineConfig): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") - env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Making policy.") - policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) - policy.eval() + policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): - info = eval_policy( - env, + info = eval_policy_all( + envs, policy, cfg.eval.n_episodes, max_episodes_rendered=10, videos_dir=Path(cfg.output_dir) / "videos", start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + verbose=False, ) - print(info["aggregated"]) + print("Overall Aggregated Metrics:") + print(info["overall"]["aggregated"]) + + # Print per-suite stats + for task_group, task_group_info in info.items(): + if task_group == "overall": + continue # Skip the overall stats since we already printed it + print(f"\nAggregated Metrics for {task_group}:") + print(task_group_info["aggregated"]) + # Close all vec envs + close_envs(envs) # Save info with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: json.dump(info, f, indent=2) - env.close() - logging.info("End of eval") -def main(): - init_logging() - eval_main() +# ---- typed payload returned by one task eval ---- +class TaskMetrics(TypedDict): + sum_rewards: list[float] + max_rewards: list[float] + successes: list[bool] + video_paths: list[str] + + +ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths") + + +def eval_policy_all( + envs: dict[str, dict[int, gym.vector.VectorEnv]], + policy: PreTrainedPolicy, + n_episodes: int, + max_episodes_rendered: int = 0, + videos_dir: Path | None = None, + return_episode_data: bool = False, + start_seed: int | None = None, + max_parallel_tasks: int = 5, + verbose: bool = True, +) -> dict: + """ + Evaluate a policy over a dict-of-dicts of vectorized envs: + envs[suite_name][task_id] -> gym.vector.VectorEnv + Returns a dict with per-suite aggregates and an 'overall' block. + """ + global_start = time.time() + + # inner: evaluate a single (suite, task) + def eval_one( + task_group: str, + task_id: int, + env: gym.vector.VectorEnv, + *, + policy: PreTrainedPolicy, + n_episodes: int, + max_episodes_rendered: int, + videos_dir: Path | None, + return_episode_data: bool, + start_seed: int | None, + ) -> TaskMetrics: + """Evaluates one task_id of one suite using the provided vec env.""" + if verbose: + print(f"Evaluating: task_group={task_group}, task_id={task_id} ...") + + task_videos_dir = None + if videos_dir is not None: + task_videos_dir = videos_dir / f"{task_group}_{task_id}" + task_videos_dir.mkdir(parents=True, exist_ok=True) + + task_result = eval_policy( + env=env, + policy=policy, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=task_videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ) + + per_episode = task_result["per_episode"] + return TaskMetrics( + sum_rewards=[ep["sum_reward"] for ep in per_episode], + max_rewards=[ep["max_reward"] for ep in per_episode], + successes=[ep["success"] for ep in per_episode], + video_paths=task_result.get("video_paths", []), + ) + + def _eval_monotask( + envs, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed + ): + for task_group, tasks in envs.items(): + for task_id, vec in tasks.items(): + yield ( + task_group, + task_id, + eval_one( + task_group, + task_id, + vec, + policy=policy, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ), + ) + + def _eval_parallel( + envs, + policy, + n_episodes, + max_episodes_rendered, + videos_dir, + return_episode_data, + start_seed, + max_parallel_tasks, + ): + with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: + fut2key: dict[cf.Future, tuple[str, int]] = {} + for task_group, tasks in envs.items(): + for task_id, vec in tasks.items(): + fut = executor.submit( + eval_one, + task_group, + task_id, + vec, + policy=policy, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ) + fut2key[fut] = (task_group, task_id) + for fut in cf.as_completed(fut2key): + task_group, task_id = fut2key[fut] + yield task_group, task_id, fut.result() + + # result producer: sequential or threaded, same consumer + def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]: + """ + Yield evaluation results for each (task_group, task_id). + + Depending on `max_parallel_tasks`, runs sequentially or in parallel, + but always returns a generator of tuples: + (task_group, task_id, TaskMetrics). + """ + if max_parallel_tasks == 1: + yield from _eval_monotask( + envs, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed + ) + else: + yield from _eval_parallel( + envs, + policy, + n_episodes, + max_episodes_rendered, + videos_dir, + return_episode_data, + start_seed, + max_parallel_tasks, + ) + + # single accumulator path on the main thread + group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS}) + overall: dict[str, list] = {k: [] for k in ACC_KEYS} + + for task_group, _task_id, metrics in iter_task_results(): + acc = group_acc[task_group] + for k in ACC_KEYS: + acc[k].extend(metrics[k]) + overall[k].extend(metrics[k]) + + # build outputs + results: dict[str, dict] = {} + for task_group, data in group_acc.items(): + suite_rewards = data["sum_rewards"] + suite_max = data["max_rewards"] + suite_succ = data["successes"] + suite_vids = data["video_paths"] + + suite_eval_s = time.time() - global_start + suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards)) + + results[task_group] = { + "aggregated": { + "avg_sum_reward": float(np.nanmean(suite_rewards)) if suite_rewards else float("nan"), + "avg_max_reward": float(np.nanmean(suite_max)) if suite_max else float("nan"), + "pc_success": float(np.nanmean(suite_succ) * 100) if suite_succ else float("nan"), + "eval_s": suite_eval_s, + "eval_ep_s": suite_eval_ep_s, + }, + "video_paths": suite_vids, + "episodes": None, + } + + global_eval_s = time.time() - global_start + global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"])) + results["overall"] = { + "aggregated": { + "avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) + if overall["sum_rewards"] + else float("nan"), + "avg_max_reward": float(np.nanmean(overall["max_rewards"])) + if overall["max_rewards"] + else float("nan"), + "pc_success": float(np.nanmean(overall["successes"]) * 100) + if overall["successes"] + else float("nan"), + "eval_s": global_eval_s, + "eval_ep_s": global_eval_ep_s, + }, + "video_paths": overall["video_paths"], + "episodes": None, + } + return results if __name__ == "__main__": - main() + init_logging() + eval_main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 398bea90e..81eb04f63 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -30,11 +30,12 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env +from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters -from lerobot.scripts.eval import eval_policy +from lerobot.scripts.eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( @@ -126,7 +127,6 @@ def train(cfg: TrainPipelineConfig): logging.info("Creating dataset") dataset = make_dataset(cfg) - # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. @@ -140,7 +140,6 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) - logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) @@ -188,7 +187,6 @@ def train(cfg: TrainPipelineConfig): dl_iter = cycle(dataloader) policy.train() - train_metrics = { "loss": AverageMeter("loss", ":.3f"), "grad_norm": AverageMeter("grdn", ":.3f"), @@ -206,7 +204,6 @@ def train(cfg: TrainPipelineConfig): start_time = time.perf_counter() batch = next(dl_iter) train_tracker.dataloading_s = time.perf_counter() - start_time - for key in batch: if isinstance(batch[key], torch.Tensor): if batch[key].dtype != torch.bool: @@ -257,15 +254,27 @@ def train(cfg: TrainPipelineConfig): torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): - eval_info = eval_policy( - eval_env, + eval_info = eval_policy_all( + eval_env, # dict[suite][task_id] -> vec_env policy, cfg.eval.n_episodes, videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", max_episodes_rendered=4, start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + verbose=False, ) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"]["aggregated"] + + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + if suite == "overall": + continue + logging.info("Suite %s aggregated: %s", suite, suite_info["aggregated"]) + + # meters/tracker eval_metrics = { "avg_sum_reward": AverageMeter("โˆ‘rwrd", ":.3f"), "pc_success": AverageMeter("success", ":.1f"), @@ -274,17 +283,16 @@ def train(cfg: TrainPipelineConfig): eval_tracker = MetricsTracker( cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step ) - eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") - eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") - eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") - logging.info(eval_tracker) + eval_tracker.eval_s = aggregated.get("eval_s", 0.0) + eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan")) + eval_tracker.pc_success = aggregated.get("pc_success", float("nan")) if wandb_logger: wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} wandb_logger.log_dict(wandb_log_dict, step, mode="eval") wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") if eval_env: - eval_env.close() + close_envs(eval_env) logging.info("End of training") if cfg.policy.push_to_hub: diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 140e9dfb9..51ea564e5 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -46,7 +46,10 @@ def test_env(env_name, env_task, obs_type): @require_env def test_factory(env_name): cfg = make_env_config(env_name) - env = make_env(cfg, n_envs=1) + envs = make_env(cfg, n_envs=1) + suite_name = next(iter(envs)) + task_id = next(iter(envs[suite_name])) + env = envs[suite_name][task_id] obs, _ = env.reset() obs = preprocess_observation(obs) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ef2d4ecd8..fd3f70c23 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -158,7 +158,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): assert isinstance(policy, PreTrainedPolicy) # Check that we run select_actions and get the appropriate output. - env = make_env(train_cfg.env, n_envs=2) + envs = make_env(train_cfg.env, n_envs=2) dataloader = torch.utils.data.DataLoader( dataset, @@ -187,6 +187,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # reset the policy and environment policy.reset() + # For testing purposes, we only need a single environment instance. + # So here we unwrap the first suite_name and first task_id to grab + # the actual env object (SyncVectorEnv) that exposes `.reset()`. + suite_name = next(iter(envs)) + task_id = next(iter(envs[suite_name])) + env = envs[suite_name][task_id] + observation, _ = env.reset(seed=train_cfg.seed) # apply transform to normalize the observations