mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Add libero (#1950)
* add libero * backup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add factory * Add LIBERO as a submodule * add changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add multitask * remove photos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug remove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix video paths and train.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix renaming issues with cams * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update .gitignore * final refactor/fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add safethread support * ad * Add dep (#4) * Add 'libero' dependencies to pyproject.toml * Add Git dependencies for egl_probe and LIBERO * Update libero-requirements.txt * add future dep * update bash * quick fix * remove step1 * cleanup (#5) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup 2 * improve install * Delete libero-requirements.txt * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add docs for eval * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove brkpt * fix docs * update docs/script * update doc * skip test warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * hotfix: flip actions * add train * new things * More things * add new changes * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unces * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * new changes * iterate on review * remove files * factor * update installation * doc title * make it reproducible * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove files * update tests * update doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove sh files * add gym --------- Signed-off-by: Jade Choghari <chogharijade@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jade Choghari (jchoghar) <chogharijade@gmai.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -19,6 +19,8 @@
|
|||||||
title: Train RL in Simulation
|
title: Train RL in Simulation
|
||||||
- local: async
|
- local: async
|
||||||
title: Use Async Inference
|
title: Use Async Inference
|
||||||
|
- local: libero
|
||||||
|
title: Using Libero
|
||||||
- local: porting_datasets_v3
|
- local: porting_datasets_v3
|
||||||
title: Porting Large Datasets
|
title: Porting Large Datasets
|
||||||
title: "Tutorials"
|
title: "Tutorials"
|
||||||
|
|||||||
@@ -0,0 +1,126 @@
|
|||||||
|
# LIBERO
|
||||||
|
|
||||||
|
**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots won’t just be pretrained once in a factory, they’ll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and it’s a key step toward building robots that become truly personalized helpers.
|
||||||
|
|
||||||
|
- 📄 [LIBERO paper](https://arxiv.org/abs/2306.03310)
|
||||||
|
- 💻 [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO)
|
||||||
|
|
||||||
|
To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other’s work.
|
||||||
|
|
||||||
|
LIBERO includes **five task suites**:
|
||||||
|
|
||||||
|
- **LIBERO-Spatial (`libero_spatial`)** – tasks that require reasoning about spatial relations.
|
||||||
|
- **LIBERO-Object (`libero_object`)** – tasks centered on manipulating different objects.
|
||||||
|
- **LIBERO-Goal (`libero_goal`)** – goal-conditioned tasks where the robot must adapt to changing targets.
|
||||||
|
- **LIBERO-90 (`libero_90`)** – 90 short-horizon tasks from the LIBERO-100 collection.
|
||||||
|
- **LIBERO-Long (`libero_10`)** – 10 long-horizon tasks from the LIBERO-100 collection.
|
||||||
|
|
||||||
|
Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Evaluating with LIBERO
|
||||||
|
|
||||||
|
At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model.
|
||||||
|
|
||||||
|
LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag.
|
||||||
|
|
||||||
|
To Install LIBERO, after following LeRobot official instructions, just do:
|
||||||
|
`pip install -e ".[libero]"`
|
||||||
|
|
||||||
|
### Single-suite evaluation
|
||||||
|
|
||||||
|
Evaluate a policy on one LIBERO suite:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/scripts/eval.py \
|
||||||
|
--policy.path="your-policy-id" \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--eval.batch_size=2 \
|
||||||
|
--eval.n_episodes=3
|
||||||
|
```
|
||||||
|
|
||||||
|
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
|
||||||
|
- `--eval.batch_size` controls how many environments run in parallel.
|
||||||
|
- `--eval.n_episodes` sets how many episodes to run in total.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Multi-suite evaluation
|
||||||
|
|
||||||
|
Benchmark a policy across multiple suites at once:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/scripts/eval.py \
|
||||||
|
--policy.path="your-policy-id" \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object,libero_spatial \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=2
|
||||||
|
```
|
||||||
|
|
||||||
|
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||||
|
|
||||||
|
### Policy inputs and outputs
|
||||||
|
|
||||||
|
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||||
|
|
||||||
|
- **Observations**
|
||||||
|
- `observation.state` – proprioceptive features (agent state).
|
||||||
|
- `observation.images.image` – main camera view (`agentview_image`).
|
||||||
|
- `observation.images.image2` – wrist camera view (`robot0_eye_in_hand_image`).
|
||||||
|
|
||||||
|
⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation.
|
||||||
|
If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer.
|
||||||
|
This will be fixed with the upcoming Pipeline PR.
|
||||||
|
|
||||||
|
- **Actions**
|
||||||
|
- Continuous control values in a `Box(-1, 1, shape=(7,))` space.
|
||||||
|
|
||||||
|
We also provide a notebook for quick testing:
|
||||||
|
Training with LIBERO
|
||||||
|
|
||||||
|
## Training with LIBERO
|
||||||
|
|
||||||
|
When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
|
||||||
|
|
||||||
|
The environment expects:
|
||||||
|
|
||||||
|
- `observation.state` → 8-dim agent state
|
||||||
|
- `observation.images.image` → main camera (`agentview_image`)
|
||||||
|
- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
|
||||||
|
|
||||||
|
⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code.
|
||||||
|
To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation:
|
||||||
|
👉 [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
|
||||||
|
|
||||||
|
For reference, here is the **original dataset** published by Physical Intelligence:
|
||||||
|
👉 [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example training command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/scripts/train.py \
|
||||||
|
--policy.type=smolvla \
|
||||||
|
--policy.repo_id=${HF_USER}/libero-test \
|
||||||
|
--dataset.repo_id=jadechoghari/smol-libero3 \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_10 \
|
||||||
|
--output_dir=./outputs/ \
|
||||||
|
--steps=100000 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=1 \
|
||||||
|
--eval_freq=1000 \
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Note on rendering
|
||||||
|
|
||||||
|
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
|
||||||
|
|
||||||
|
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
|
||||||
+17
-2
@@ -134,7 +134,21 @@ video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
|||||||
aloha = ["gym-aloha>=0.1.1"]
|
aloha = ["gym-aloha>=0.1.1"]
|
||||||
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||||
xarm = ["gym-xarm>=0.1.1"]
|
xarm = ["gym-xarm>=0.1.1"]
|
||||||
|
libero = [
|
||||||
|
"hydra-core>=1.2,<1.4",
|
||||||
|
"easydict>=1.9",
|
||||||
|
"lerobot[transformers-dep]",
|
||||||
|
"robomimic==0.2.0",
|
||||||
|
"thop>=0.1.0.post2206102148",
|
||||||
|
"robosuite==1.4.0",
|
||||||
|
"bddl==1.0.1",
|
||||||
|
"matplotlib>=3.5.3",
|
||||||
|
"cloudpickle>=2.0.0",
|
||||||
|
"gym>=0.25,<0.27",
|
||||||
|
"future>=0.18.3",
|
||||||
|
"egl_probe @ git+https://github.com/jadechoghari/egl_probe.git#egg=egl_probe",
|
||||||
|
"libero @ git+https://github.com/jadechoghari/LIBERO.git@main#egg=libero",
|
||||||
|
]
|
||||||
# All
|
# All
|
||||||
all = [
|
all = [
|
||||||
"lerobot[dynamixel]",
|
"lerobot[dynamixel]",
|
||||||
@@ -153,7 +167,8 @@ all = [
|
|||||||
"lerobot[video_benchmark]",
|
"lerobot[video_benchmark]",
|
||||||
"lerobot[aloha]",
|
"lerobot[aloha]",
|
||||||
"lerobot[pusht]",
|
"lerobot[pusht]",
|
||||||
"lerobot[xarm]"
|
"lerobot[xarm]",
|
||||||
|
"lerobot[libero]"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
fps: int = 30
|
fps: int = 30
|
||||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
features_map: dict[str, str] = field(default_factory=dict)
|
features_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
max_parallel_tasks: int = 5
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@@ -271,3 +272,55 @@ class HILEnvConfig(EnvConfig):
|
|||||||
"use_gamepad": self.use_gamepad,
|
"use_gamepad": self.use_gamepad,
|
||||||
"gripper_penalty": self.gripper_penalty,
|
"gripper_penalty": self.gripper_penalty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("libero")
|
||||||
|
@dataclass
|
||||||
|
class LiberoEnv(EnvConfig):
|
||||||
|
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||||
|
fps: int = 30
|
||||||
|
episode_length: int = 520
|
||||||
|
obs_type: str = "pixels_agent_pos"
|
||||||
|
render_mode: str = "rgb_array"
|
||||||
|
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||||
|
init_states: bool = True
|
||||||
|
camera_name_mapping: dict[str, str] | None = (None,)
|
||||||
|
features: dict[str, PolicyFeature] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features_map: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": ACTION,
|
||||||
|
"agent_pos": OBS_STATE,
|
||||||
|
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||||
|
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.obs_type == "pixels":
|
||||||
|
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||||
|
)
|
||||||
|
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||||
|
)
|
||||||
|
elif self.obs_type == "pixels_agent_pos":
|
||||||
|
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||||
|
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||||
|
)
|
||||||
|
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL, shape=(360, 360, 3)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self) -> dict:
|
||||||
|
return {
|
||||||
|
"obs_type": self.obs_type,
|
||||||
|
"render_mode": self.render_mode,
|
||||||
|
}
|
||||||
|
|||||||
+36
-13
@@ -17,7 +17,7 @@ import importlib
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, LiberoEnv, PushtEnv, XarmEnv
|
||||||
|
|
||||||
|
|
||||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||||
@@ -29,11 +29,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
return XarmEnv(**kwargs)
|
return XarmEnv(**kwargs)
|
||||||
elif env_type == "hil":
|
elif env_type == "hil":
|
||||||
return HILEnvConfig(**kwargs)
|
return HILEnvConfig(**kwargs)
|
||||||
|
elif env_type == "libero":
|
||||||
|
return LiberoEnv(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
def make_env(
|
||||||
|
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
|
||||||
|
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
"""Makes a gym vector environment according to the config.
|
"""Makes a gym vector environment according to the config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -47,25 +51,44 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
|||||||
ModuleNotFoundError: If the requested env package is not installed
|
ModuleNotFoundError: If the requested env package is not installed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||||
|
A mapping from suite name to indexed vectorized environments.
|
||||||
|
- For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id.
|
||||||
|
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if n_envs < 1:
|
if n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs` must be at least 1")
|
||||||
|
|
||||||
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
if "libero" in cfg.type:
|
||||||
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
|
||||||
|
return create_libero_envs(
|
||||||
|
task=cfg.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
camera_name=cfg.camera_name,
|
||||||
|
init_states=cfg.init_states,
|
||||||
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
)
|
||||||
|
|
||||||
package_name = f"gym_{cfg.type}"
|
package_name = f"gym_{cfg.type}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module(package_name)
|
importlib.import_module(package_name)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
raise ModuleNotFoundError(
|
||||||
raise e
|
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
|
||||||
|
) from e
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.task}"
|
gym_handle = f"{package_name}/{cfg.task}"
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
def _make_one():
|
||||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
||||||
env = env_cls(
|
|
||||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
|
||||||
)
|
|
||||||
|
|
||||||
return env
|
vec = env_cls([_make_one for _ in range(n_envs)])
|
||||||
|
|
||||||
|
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||||
|
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||||
|
return {suite_name: {0: vec}}
|
||||||
|
|||||||
@@ -0,0 +1,399 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium import spaces
|
||||||
|
from libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||||
|
"""Normalize camera_name into a non-empty list of strings."""
|
||||||
|
if isinstance(camera_name, str):
|
||||||
|
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||||
|
elif isinstance(camera_name, (list, tuple)):
|
||||||
|
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||||
|
else:
|
||||||
|
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||||
|
if not cams:
|
||||||
|
raise ValueError("camera_name resolved to an empty list.")
|
||||||
|
return cams
|
||||||
|
|
||||||
|
|
||||||
|
def _get_suite(name: str) -> Any:
|
||||||
|
"""Instantiate a LIBERO suite by name with clear validation."""
|
||||||
|
bench = benchmark.get_benchmark_dict()
|
||||||
|
if name not in bench:
|
||||||
|
raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}")
|
||||||
|
suite = bench[name]()
|
||||||
|
if not getattr(suite, "tasks", None):
|
||||||
|
raise ValueError(f"Suite '{name}' has no tasks.")
|
||||||
|
return suite
|
||||||
|
|
||||||
|
|
||||||
|
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
|
||||||
|
"""Validate/normalize task ids. If None → all tasks."""
|
||||||
|
if task_ids is None:
|
||||||
|
return list(range(total_tasks))
|
||||||
|
ids = sorted({int(t) for t in task_ids})
|
||||||
|
for t in ids:
|
||||||
|
if t < 0 or t >= total_tasks:
|
||||||
|
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||||
|
|
||||||
|
Converts quaternion to axis-angle format.
|
||||||
|
Returns a unit vector direction scaled by its angle in radians.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quat (np.array): (x,y,z,w) vec4 float angles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||||
|
"""
|
||||||
|
# clip quaternion
|
||||||
|
if quat[3] > 1.0:
|
||||||
|
quat[3] = 1.0
|
||||||
|
elif quat[3] < -1.0:
|
||||||
|
quat[3] = -1.0
|
||||||
|
|
||||||
|
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||||
|
if math.isclose(den, 0.0):
|
||||||
|
# This is (close to) a zero degree rotation, immediately return
|
||||||
|
return np.zeros(3)
|
||||||
|
|
||||||
|
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
||||||
|
init_states_path = (
|
||||||
|
Path(get_libero_path("init_states"))
|
||||||
|
/ task_suite.tasks[i].problem_folder
|
||||||
|
/ task_suite.tasks[i].init_states_file
|
||||||
|
)
|
||||||
|
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
||||||
|
return init_states
|
||||||
|
|
||||||
|
|
||||||
|
def get_libero_dummy_action():
|
||||||
|
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
|
||||||
|
return [0, 0, 0, 0, 0, 0, -1]
|
||||||
|
|
||||||
|
|
||||||
|
OBS_STATE_DIM = 8
|
||||||
|
ACTION_DIM = 7
|
||||||
|
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||||
|
"libero_spatial": 280, # longest training demo has 193 steps
|
||||||
|
"libero_object": 280, # longest training demo has 254 steps
|
||||||
|
"libero_goal": 300, # longest training demo has 270 steps
|
||||||
|
"libero_10": 520, # longest training demo has 505 steps
|
||||||
|
"libero_90": 400, # longest training demo has 373 steps
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LiberoEnv(gym.Env):
|
||||||
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_suite: Any,
|
||||||
|
task_id: int,
|
||||||
|
task_suite_name: str,
|
||||||
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
|
obs_type: str = "pixels",
|
||||||
|
render_mode: str = "rgb_array",
|
||||||
|
observation_width: int = 256,
|
||||||
|
observation_height: int = 256,
|
||||||
|
visualization_width: int = 640,
|
||||||
|
visualization_height: int = 480,
|
||||||
|
init_states: bool = True,
|
||||||
|
episode_index: int = 0,
|
||||||
|
camera_name_mapping: dict[str, str] | None = None,
|
||||||
|
num_steps_wait: int = 10,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.task_id = task_id
|
||||||
|
self.obs_type = obs_type
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.observation_width = observation_width
|
||||||
|
self.observation_height = observation_height
|
||||||
|
self.visualization_width = visualization_width
|
||||||
|
self.visualization_height = visualization_height
|
||||||
|
self.init_states = init_states
|
||||||
|
self.camera_name = camera_name.split(
|
||||||
|
","
|
||||||
|
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
|
||||||
|
|
||||||
|
# Map raw camera names to "image1" and "image2".
|
||||||
|
# The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
|
||||||
|
# following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
|
||||||
|
# This ensures the policy consistently receives observations in the
|
||||||
|
# expected format regardless of the original camera naming.
|
||||||
|
if camera_name_mapping is None:
|
||||||
|
camera_name_mapping = {
|
||||||
|
"agentview_image": "image",
|
||||||
|
"robot0_eye_in_hand_image": "image2",
|
||||||
|
}
|
||||||
|
self.camera_name_mapping = camera_name_mapping
|
||||||
|
self.num_steps_wait = num_steps_wait
|
||||||
|
self.episode_index = episode_index
|
||||||
|
# Load once and keep
|
||||||
|
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||||
|
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||||
|
|
||||||
|
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||||
|
default_steps = 500
|
||||||
|
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
|
|
||||||
|
images = {}
|
||||||
|
for cam in self.camera_name:
|
||||||
|
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.obs_type == "state":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The 'state' observation type is not supported in LiberoEnv. "
|
||||||
|
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.obs_type == "pixels":
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Dict(images),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif self.obs_type == "pixels_agent_pos":
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Dict(images),
|
||||||
|
"agent_pos": spaces.Box(
|
||||||
|
low=-1000.0,
|
||||||
|
high=1000.0,
|
||||||
|
shape=(OBS_STATE_DIM,),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
raw_obs = self._env.env._get_observations()
|
||||||
|
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
|
||||||
|
task = task_suite.get_task(task_id)
|
||||||
|
self.task = task.name
|
||||||
|
self.task_description = task.language
|
||||||
|
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
||||||
|
|
||||||
|
env_args = {
|
||||||
|
"bddl_file_name": task_bddl_file,
|
||||||
|
"camera_heights": self.observation_height,
|
||||||
|
"camera_widths": self.observation_width,
|
||||||
|
}
|
||||||
|
env = OffScreenRenderEnv(**env_args)
|
||||||
|
env.reset()
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
images = {}
|
||||||
|
for camera_name in self.camera_name:
|
||||||
|
image = raw_obs[camera_name]
|
||||||
|
image = image[::-1, ::-1] # rotate 180 degrees
|
||||||
|
images[self.camera_name_mapping[camera_name]] = image
|
||||||
|
state = np.concatenate(
|
||||||
|
(
|
||||||
|
raw_obs["robot0_eef_pos"],
|
||||||
|
quat2axisangle(raw_obs["robot0_eef_quat"]),
|
||||||
|
raw_obs["robot0_gripper_qpos"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
agent_pos = state
|
||||||
|
if self.obs_type == "pixels":
|
||||||
|
return {"pixels": images.copy()}
|
||||||
|
if self.obs_type == "pixels_agent_pos":
|
||||||
|
return {
|
||||||
|
"pixels": images.copy(),
|
||||||
|
"agent_pos": agent_pos,
|
||||||
|
}
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
|
||||||
|
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self, seed=None, **kwargs):
|
||||||
|
super().reset(seed=seed)
|
||||||
|
self._env.seed(seed)
|
||||||
|
if self.init_states and self._init_states is not None:
|
||||||
|
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||||
|
raw_obs = self._env.reset()
|
||||||
|
|
||||||
|
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||||
|
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||||
|
# Increasing this value can improve determinism and reproducibility across resets.
|
||||||
|
for _ in range(self.num_steps_wait):
|
||||||
|
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||||
|
observation = self._format_raw_obs(raw_obs)
|
||||||
|
info = {"is_success": False}
|
||||||
|
return observation, info
|
||||||
|
|
||||||
|
def step(self, action: np.ndarray) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
|
if action.ndim != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||||
|
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||||
|
)
|
||||||
|
raw_obs, reward, done, info = self._env.step(action)
|
||||||
|
|
||||||
|
is_success = self._env.check_success()
|
||||||
|
terminated = done or is_success
|
||||||
|
info["is_success"] = done # is_success
|
||||||
|
|
||||||
|
observation = self._format_raw_obs(raw_obs)
|
||||||
|
if done:
|
||||||
|
self.reset()
|
||||||
|
info.update(
|
||||||
|
{
|
||||||
|
"task": self.task,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"done": done,
|
||||||
|
"is_success": is_success,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
truncated = False
|
||||||
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_env_fns(
|
||||||
|
*,
|
||||||
|
suite,
|
||||||
|
suite_name: str,
|
||||||
|
task_id: int,
|
||||||
|
n_envs: int,
|
||||||
|
camera_names: list[str],
|
||||||
|
init_states: bool,
|
||||||
|
gym_kwargs: Mapping[str, Any],
|
||||||
|
) -> list[Callable[[], LiberoEnv]]:
|
||||||
|
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||||
|
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
|
||||||
|
|
||||||
|
def _make_env(episode_index: int, **kwargs) -> LiberoEnv:
|
||||||
|
local_kwargs = dict(kwargs)
|
||||||
|
return LiberoEnv(
|
||||||
|
task_suite=suite,
|
||||||
|
task_id=task_id,
|
||||||
|
task_suite_name=suite_name,
|
||||||
|
camera_name=joined_cams,
|
||||||
|
init_states=init_states,
|
||||||
|
episode_index=episode_index,
|
||||||
|
**local_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
fns: list[Callable[[], LiberoEnv]] = []
|
||||||
|
for episode_index in range(n_envs):
|
||||||
|
fns.append(partial(_make_env, episode_index, **gym_kwargs))
|
||||||
|
return fns
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Main API ----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def create_libero_envs(
|
||||||
|
task: str,
|
||||||
|
n_envs: int,
|
||||||
|
gym_kwargs: dict[str, Any] | None = None,
|
||||||
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
|
init_states: bool = True,
|
||||||
|
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||||
|
) -> dict[str, dict[int, Any]]:
|
||||||
|
"""
|
||||||
|
Create vectorized LIBERO environments with a consistent return shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
|
||||||
|
Notes:
|
||||||
|
- n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
|
||||||
|
- `task` can be a single suite or a comma-separated list of suites.
|
||||||
|
- You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite.
|
||||||
|
"""
|
||||||
|
if env_cls is None or not callable(env_cls):
|
||||||
|
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
|
||||||
|
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||||
|
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||||
|
|
||||||
|
gym_kwargs = dict(gym_kwargs or {})
|
||||||
|
task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks
|
||||||
|
|
||||||
|
camera_names = _parse_camera_names(camera_name)
|
||||||
|
suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
|
||||||
|
if not suite_names:
|
||||||
|
raise ValueError("`task` must contain at least one LIBERO suite name.")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}"
|
||||||
|
)
|
||||||
|
if task_ids_filter is not None:
|
||||||
|
print(f"Restricting to task_ids={task_ids_filter}")
|
||||||
|
|
||||||
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
|
|
||||||
|
for suite_name in suite_names:
|
||||||
|
suite = _get_suite(suite_name)
|
||||||
|
total = len(suite.tasks)
|
||||||
|
selected = _select_task_ids(total, task_ids_filter)
|
||||||
|
|
||||||
|
if not selected:
|
||||||
|
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||||
|
|
||||||
|
for tid in selected:
|
||||||
|
fns = _make_env_fns(
|
||||||
|
suite=suite,
|
||||||
|
suite_name=suite_name,
|
||||||
|
task_id=tid,
|
||||||
|
n_envs=n_envs,
|
||||||
|
camera_names=camera_names,
|
||||||
|
init_states=init_states,
|
||||||
|
gym_kwargs=gym_kwargs,
|
||||||
|
)
|
||||||
|
out[suite_name][tid] = env_cls(fns)
|
||||||
|
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||||
|
|
||||||
|
# return plain dicts for predictability
|
||||||
|
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||||
@@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from functools import singledispatch
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
@@ -97,7 +99,6 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
|||||||
|
|
||||||
policy_key = env_cfg.features_map[key]
|
policy_key = env_cfg.features_map[key]
|
||||||
policy_features[policy_key] = feature
|
policy_features[policy_key] = feature
|
||||||
|
|
||||||
return policy_features
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
@@ -134,3 +135,41 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
|
|||||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||||
observation["task"] = ["" for _ in range(num_envs)]
|
observation["task"] = ["" for _ in range(num_envs)]
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
def _close_single_env(env: Any) -> None:
|
||||||
|
try:
|
||||||
|
env.close()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Exception while closing env {env}: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def close_envs(obj: Any) -> None:
|
||||||
|
"""Default: raise if the type is not recognized."""
|
||||||
|
raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
@close_envs.register
|
||||||
|
def _(env: Mapping) -> None:
|
||||||
|
for v in env.values():
|
||||||
|
if isinstance(v, Mapping):
|
||||||
|
close_envs(v)
|
||||||
|
elif hasattr(v, "close"):
|
||||||
|
_close_single_env(v)
|
||||||
|
|
||||||
|
|
||||||
|
@close_envs.register
|
||||||
|
def _(envs: Sequence) -> None:
|
||||||
|
if isinstance(envs, (str, bytes)):
|
||||||
|
return
|
||||||
|
for v in envs:
|
||||||
|
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
|
||||||
|
close_envs(v)
|
||||||
|
elif hasattr(v, "close"):
|
||||||
|
_close_single_env(v)
|
||||||
|
|
||||||
|
|
||||||
|
@close_envs.register
|
||||||
|
def _(env: gym.Env) -> None:
|
||||||
|
_close_single_env(env)
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ def make_policy(
|
|||||||
"by default without stats from a dataset."
|
"by default without stats from a dataset."
|
||||||
)
|
)
|
||||||
features = env_to_policy_features(env_cfg)
|
features = env_to_policy_features(env_cfg)
|
||||||
|
|
||||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
@@ -169,10 +168,7 @@ def make_policy(
|
|||||||
else:
|
else:
|
||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
policy.to(cfg.device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|||||||
+234
-19
@@ -46,16 +46,19 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
|||||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import concurrent.futures as cf
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
@@ -68,7 +71,12 @@ from tqdm import trange
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.eval import EvalPipelineConfig
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
from lerobot.envs.utils import (
|
||||||
|
add_envs_task,
|
||||||
|
check_env_attributes_and_types,
|
||||||
|
close_envs,
|
||||||
|
preprocess_observation,
|
||||||
|
)
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
@@ -145,7 +153,7 @@ def rollout(
|
|||||||
leave=False,
|
leave=False,
|
||||||
)
|
)
|
||||||
check_env_attributes_and_types(env)
|
check_env_attributes_and_types(env)
|
||||||
while not np.all(done):
|
while not np.all(done) and step < max_steps:
|
||||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||||
observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
if return_observations:
|
if return_observations:
|
||||||
@@ -158,10 +166,8 @@ def rollout(
|
|||||||
# Infer "task" from attributes of environments.
|
# Infer "task" from attributes of environments.
|
||||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||||
observation = add_envs_task(env, observation)
|
observation = add_envs_task(env, observation)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
action = action.to("cpu").numpy()
|
action = action.to("cpu").numpy()
|
||||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||||
@@ -179,7 +185,12 @@ def rollout(
|
|||||||
successes = [False] * env.num_envs
|
successes = [False] * env.num_envs
|
||||||
|
|
||||||
# Keep track of which environments are done so far.
|
# Keep track of which environments are done so far.
|
||||||
|
# Mark the episode as done if we reach the maximum step limit.
|
||||||
|
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||||
|
# and allows logging/saving (e.g., videos) to be triggered consistently.
|
||||||
done = terminated | truncated | done
|
done = terminated | truncated | done
|
||||||
|
if step + 1 == max_steps:
|
||||||
|
done = np.ones_like(done, dtype=bool)
|
||||||
|
|
||||||
all_actions.append(torch.from_numpy(action))
|
all_actions.append(torch.from_numpy(action))
|
||||||
all_rewards.append(torch.from_numpy(reward))
|
all_rewards.append(torch.from_numpy(reward))
|
||||||
@@ -402,7 +413,6 @@ def eval_policy(
|
|||||||
"eval_ep_s": (time.time() - start) / n_episodes,
|
"eval_ep_s": (time.time() - start) / n_episodes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if return_episode_data:
|
if return_episode_data:
|
||||||
info["episodes"] = episode_data
|
info["episodes"] = episode_data
|
||||||
|
|
||||||
@@ -463,7 +473,6 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
@@ -471,40 +480,246 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||||
|
|
||||||
logging.info("Making environment.")
|
logging.info("Making environment.")
|
||||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
logging.info("Making policy.")
|
logging.info("Making policy.")
|
||||||
|
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy.eval()
|
|
||||||
|
|
||||||
|
policy.eval()
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy_all(
|
||||||
env,
|
envs,
|
||||||
policy,
|
policy,
|
||||||
cfg.eval.n_episodes,
|
cfg.eval.n_episodes,
|
||||||
max_episodes_rendered=10,
|
max_episodes_rendered=10,
|
||||||
videos_dir=Path(cfg.output_dir) / "videos",
|
videos_dir=Path(cfg.output_dir) / "videos",
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
|
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
print(info["aggregated"])
|
print("Overall Aggregated Metrics:")
|
||||||
|
print(info["overall"]["aggregated"])
|
||||||
|
|
||||||
|
# Print per-suite stats
|
||||||
|
for task_group, task_group_info in info.items():
|
||||||
|
if task_group == "overall":
|
||||||
|
continue # Skip the overall stats since we already printed it
|
||||||
|
print(f"\nAggregated Metrics for {task_group}:")
|
||||||
|
print(task_group_info["aggregated"])
|
||||||
|
# Close all vec envs
|
||||||
|
close_envs(envs)
|
||||||
|
|
||||||
# Save info
|
# Save info
|
||||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||||
json.dump(info, f, indent=2)
|
json.dump(info, f, indent=2)
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
logging.info("End of eval")
|
logging.info("End of eval")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
# ---- typed payload returned by one task eval ----
|
||||||
init_logging()
|
class TaskMetrics(TypedDict):
|
||||||
eval_main()
|
sum_rewards: list[float]
|
||||||
|
max_rewards: list[float]
|
||||||
|
successes: list[bool]
|
||||||
|
video_paths: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
||||||
|
|
||||||
|
|
||||||
|
def eval_policy_all(
|
||||||
|
envs: dict[str, dict[int, gym.vector.VectorEnv]],
|
||||||
|
policy: PreTrainedPolicy,
|
||||||
|
n_episodes: int,
|
||||||
|
max_episodes_rendered: int = 0,
|
||||||
|
videos_dir: Path | None = None,
|
||||||
|
return_episode_data: bool = False,
|
||||||
|
start_seed: int | None = None,
|
||||||
|
max_parallel_tasks: int = 5,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Evaluate a policy over a dict-of-dicts of vectorized envs:
|
||||||
|
envs[suite_name][task_id] -> gym.vector.VectorEnv
|
||||||
|
Returns a dict with per-suite aggregates and an 'overall' block.
|
||||||
|
"""
|
||||||
|
global_start = time.time()
|
||||||
|
|
||||||
|
# inner: evaluate a single (suite, task)
|
||||||
|
def eval_one(
|
||||||
|
task_group: str,
|
||||||
|
task_id: int,
|
||||||
|
env: gym.vector.VectorEnv,
|
||||||
|
*,
|
||||||
|
policy: PreTrainedPolicy,
|
||||||
|
n_episodes: int,
|
||||||
|
max_episodes_rendered: int,
|
||||||
|
videos_dir: Path | None,
|
||||||
|
return_episode_data: bool,
|
||||||
|
start_seed: int | None,
|
||||||
|
) -> TaskMetrics:
|
||||||
|
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||||
|
if verbose:
|
||||||
|
print(f"Evaluating: task_group={task_group}, task_id={task_id} ...")
|
||||||
|
|
||||||
|
task_videos_dir = None
|
||||||
|
if videos_dir is not None:
|
||||||
|
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||||
|
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
task_result = eval_policy(
|
||||||
|
env=env,
|
||||||
|
policy=policy,
|
||||||
|
n_episodes=n_episodes,
|
||||||
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
|
videos_dir=task_videos_dir,
|
||||||
|
return_episode_data=return_episode_data,
|
||||||
|
start_seed=start_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_episode = task_result["per_episode"]
|
||||||
|
return TaskMetrics(
|
||||||
|
sum_rewards=[ep["sum_reward"] for ep in per_episode],
|
||||||
|
max_rewards=[ep["max_reward"] for ep in per_episode],
|
||||||
|
successes=[ep["success"] for ep in per_episode],
|
||||||
|
video_paths=task_result.get("video_paths", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _eval_monotask(
|
||||||
|
envs, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed
|
||||||
|
):
|
||||||
|
for task_group, tasks in envs.items():
|
||||||
|
for task_id, vec in tasks.items():
|
||||||
|
yield (
|
||||||
|
task_group,
|
||||||
|
task_id,
|
||||||
|
eval_one(
|
||||||
|
task_group,
|
||||||
|
task_id,
|
||||||
|
vec,
|
||||||
|
policy=policy,
|
||||||
|
n_episodes=n_episodes,
|
||||||
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
|
videos_dir=videos_dir,
|
||||||
|
return_episode_data=return_episode_data,
|
||||||
|
start_seed=start_seed,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _eval_parallel(
|
||||||
|
envs,
|
||||||
|
policy,
|
||||||
|
n_episodes,
|
||||||
|
max_episodes_rendered,
|
||||||
|
videos_dir,
|
||||||
|
return_episode_data,
|
||||||
|
start_seed,
|
||||||
|
max_parallel_tasks,
|
||||||
|
):
|
||||||
|
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
|
fut2key: dict[cf.Future, tuple[str, int]] = {}
|
||||||
|
for task_group, tasks in envs.items():
|
||||||
|
for task_id, vec in tasks.items():
|
||||||
|
fut = executor.submit(
|
||||||
|
eval_one,
|
||||||
|
task_group,
|
||||||
|
task_id,
|
||||||
|
vec,
|
||||||
|
policy=policy,
|
||||||
|
n_episodes=n_episodes,
|
||||||
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
|
videos_dir=videos_dir,
|
||||||
|
return_episode_data=return_episode_data,
|
||||||
|
start_seed=start_seed,
|
||||||
|
)
|
||||||
|
fut2key[fut] = (task_group, task_id)
|
||||||
|
for fut in cf.as_completed(fut2key):
|
||||||
|
task_group, task_id = fut2key[fut]
|
||||||
|
yield task_group, task_id, fut.result()
|
||||||
|
|
||||||
|
# result producer: sequential or threaded, same consumer
|
||||||
|
def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]:
|
||||||
|
"""
|
||||||
|
Yield evaluation results for each (task_group, task_id).
|
||||||
|
|
||||||
|
Depending on `max_parallel_tasks`, runs sequentially or in parallel,
|
||||||
|
but always returns a generator of tuples:
|
||||||
|
(task_group, task_id, TaskMetrics).
|
||||||
|
"""
|
||||||
|
if max_parallel_tasks == 1:
|
||||||
|
yield from _eval_monotask(
|
||||||
|
envs, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield from _eval_parallel(
|
||||||
|
envs,
|
||||||
|
policy,
|
||||||
|
n_episodes,
|
||||||
|
max_episodes_rendered,
|
||||||
|
videos_dir,
|
||||||
|
return_episode_data,
|
||||||
|
start_seed,
|
||||||
|
max_parallel_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# single accumulator path on the main thread
|
||||||
|
group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
|
||||||
|
overall: dict[str, list] = {k: [] for k in ACC_KEYS}
|
||||||
|
|
||||||
|
for task_group, _task_id, metrics in iter_task_results():
|
||||||
|
acc = group_acc[task_group]
|
||||||
|
for k in ACC_KEYS:
|
||||||
|
acc[k].extend(metrics[k])
|
||||||
|
overall[k].extend(metrics[k])
|
||||||
|
|
||||||
|
# build outputs
|
||||||
|
results: dict[str, dict] = {}
|
||||||
|
for task_group, data in group_acc.items():
|
||||||
|
suite_rewards = data["sum_rewards"]
|
||||||
|
suite_max = data["max_rewards"]
|
||||||
|
suite_succ = data["successes"]
|
||||||
|
suite_vids = data["video_paths"]
|
||||||
|
|
||||||
|
suite_eval_s = time.time() - global_start
|
||||||
|
suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards))
|
||||||
|
|
||||||
|
results[task_group] = {
|
||||||
|
"aggregated": {
|
||||||
|
"avg_sum_reward": float(np.nanmean(suite_rewards)) if suite_rewards else float("nan"),
|
||||||
|
"avg_max_reward": float(np.nanmean(suite_max)) if suite_max else float("nan"),
|
||||||
|
"pc_success": float(np.nanmean(suite_succ) * 100) if suite_succ else float("nan"),
|
||||||
|
"eval_s": suite_eval_s,
|
||||||
|
"eval_ep_s": suite_eval_ep_s,
|
||||||
|
},
|
||||||
|
"video_paths": suite_vids,
|
||||||
|
"episodes": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
global_eval_s = time.time() - global_start
|
||||||
|
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
|
||||||
|
results["overall"] = {
|
||||||
|
"aggregated": {
|
||||||
|
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"]))
|
||||||
|
if overall["sum_rewards"]
|
||||||
|
else float("nan"),
|
||||||
|
"avg_max_reward": float(np.nanmean(overall["max_rewards"]))
|
||||||
|
if overall["max_rewards"]
|
||||||
|
else float("nan"),
|
||||||
|
"pc_success": float(np.nanmean(overall["successes"]) * 100)
|
||||||
|
if overall["successes"]
|
||||||
|
else float("nan"),
|
||||||
|
"eval_s": global_eval_s,
|
||||||
|
"eval_ep_s": global_eval_ep_s,
|
||||||
|
},
|
||||||
|
"video_paths": overall["video_paths"],
|
||||||
|
"episodes": None,
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
init_logging()
|
||||||
|
eval_main()
|
||||||
|
|||||||
@@ -30,11 +30,12 @@ from lerobot.datasets.factory import make_dataset
|
|||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
|
from lerobot.envs.utils import close_envs
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy_all
|
||||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.train_utils import (
|
from lerobot.utils.train_utils import (
|
||||||
@@ -126,7 +127,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
@@ -140,7 +140,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||||
@@ -188,7 +187,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
train_metrics = {
|
train_metrics = {
|
||||||
"loss": AverageMeter("loss", ":.3f"),
|
"loss": AverageMeter("loss", ":.3f"),
|
||||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||||
@@ -206,7 +204,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if isinstance(batch[key], torch.Tensor):
|
if isinstance(batch[key], torch.Tensor):
|
||||||
if batch[key].dtype != torch.bool:
|
if batch[key].dtype != torch.bool:
|
||||||
@@ -257,15 +254,27 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
torch.no_grad(),
|
torch.no_grad(),
|
||||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy_all(
|
||||||
eval_env,
|
eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy,
|
policy,
|
||||||
cfg.eval.n_episodes,
|
cfg.eval.n_episodes,
|
||||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
|
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# overall metrics (suite-agnostic)
|
||||||
|
aggregated = eval_info["overall"]["aggregated"]
|
||||||
|
|
||||||
|
# optional: per-suite logging
|
||||||
|
for suite, suite_info in eval_info.items():
|
||||||
|
if suite == "overall":
|
||||||
|
continue
|
||||||
|
logging.info("Suite %s aggregated: %s", suite, suite_info["aggregated"])
|
||||||
|
|
||||||
|
# meters/tracker
|
||||||
eval_metrics = {
|
eval_metrics = {
|
||||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||||
"pc_success": AverageMeter("success", ":.1f"),
|
"pc_success": AverageMeter("success", ":.1f"),
|
||||||
@@ -274,17 +283,16 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
eval_tracker = MetricsTracker(
|
eval_tracker = MetricsTracker(
|
||||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||||
)
|
)
|
||||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
|
||||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
|
||||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
|
||||||
logging.info(eval_tracker)
|
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
eval_env.close()
|
close_envs(eval_env)
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
|
|||||||
@@ -46,7 +46,10 @@ def test_env(env_name, env_task, obs_type):
|
|||||||
@require_env
|
@require_env
|
||||||
def test_factory(env_name):
|
def test_factory(env_name):
|
||||||
cfg = make_env_config(env_name)
|
cfg = make_env_config(env_name)
|
||||||
env = make_env(cfg, n_envs=1)
|
envs = make_env(cfg, n_envs=1)
|
||||||
|
suite_name = next(iter(envs))
|
||||||
|
task_id = next(iter(envs[suite_name]))
|
||||||
|
env = envs[suite_name][task_id]
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
obs = preprocess_observation(obs)
|
obs = preprocess_observation(obs)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
assert isinstance(policy, PreTrainedPolicy)
|
assert isinstance(policy, PreTrainedPolicy)
|
||||||
|
|
||||||
# Check that we run select_actions and get the appropriate output.
|
# Check that we run select_actions and get the appropriate output.
|
||||||
env = make_env(train_cfg.env, n_envs=2)
|
envs = make_env(train_cfg.env, n_envs=2)
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -187,6 +187,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
|
|
||||||
# reset the policy and environment
|
# reset the policy and environment
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
# For testing purposes, we only need a single environment instance.
|
||||||
|
# So here we unwrap the first suite_name and first task_id to grab
|
||||||
|
# the actual env object (SyncVectorEnv) that exposes `.reset()`.
|
||||||
|
suite_name = next(iter(envs))
|
||||||
|
task_id = next(iter(envs[suite_name]))
|
||||||
|
env = envs[suite_name][task_id]
|
||||||
|
|
||||||
observation, _ = env.reset(seed=train_cfg.seed)
|
observation, _ = env.reset(seed=train_cfg.seed)
|
||||||
|
|
||||||
# apply transform to normalize the observations
|
# apply transform to normalize the observations
|
||||||
|
|||||||
Reference in New Issue
Block a user