add libero

This commit is contained in:
Jade Choghari
2025-08-05 23:55:08 -04:00
parent 06bebd97b3
commit 21a961ecbb
6 changed files with 626 additions and 59 deletions
+59
View File
@@ -0,0 +1,59 @@
#!/bin/bash
# config
REPO_ID=physical-intelligence/libero
TASK=libero_10
OUTPUT_DIR=./outputs/train_run/smolvla2_libero
# clean previous run
rm -rf $OUTPUT_DIR
# training params
STEPS=100000
BATCH_SIZE=4
EVAL_FREQ=2000
SAVE_FREQ=10000
NUM_WORKERS=0
# model params
POLICY=smolvla
USE_AMP=false
OPTIMIZER_LR=1e-4
PEFT_METHOD=lora
LOAD_VLM_WEIGHTS=true
VLM_REPO_ID=None
MAX_ACTION_DIM=32
MAX_STATE_DIM=32
# dataset/image params
USE_IMAGENET_STATS=false
ENABLE_IMG_TRANSFORM=true
MAX_NUM_IMAGES=2
MAX_IMAGE_DIM=1024
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
# launch
MUJOCO_GL=egl python src/lerobot/scripts/train.py \
--policy.type=$POLICY \
--dataset.repo_id=$REPO_ID \
--env.type=libero \
--env.task=$TASK \
--output_dir=$OUTPUT_DIR \
--steps=$STEPS \
--batch_size=$BATCH_SIZE \
--eval_freq=$EVAL_FREQ \
--save_freq=$SAVE_FREQ \
--num_workers=$NUM_WORKERS \
--policy.max_action_dim=$MAX_ACTION_DIM \
--policy.max_state_dim=$MAX_STATE_DIM \
--policy.use_amp=$USE_AMP \
--policy.optimizer_lr=$OPTIMIZER_LR \
--policy.peft_method=$PEFT_METHOD \
--policy.load_vlm_weights=$LOAD_VLM_WEIGHTS \
--policy.repo_id=$VLM_REPO_ID \
--dataset.use_imagenet_stats=$USE_IMAGENET_STATS \
--dataset.image_transforms.enable=$ENABLE_IMG_TRANSFORM \
--dataset.max_num_images=$MAX_NUM_IMAGES \
--dataset.max_image_dim=$MAX_IMAGE_DIM \
# --policy.exclude_image_keys=wrist_image \
# --policy.use_env_state=false
+1
View File
@@ -20,6 +20,7 @@ from huggingface_hub.constants import HF_HOME
OBS_ENV_STATE = "observation.environment_state" OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state" OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image" OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGES = "observation.images" OBS_IMAGES = "observation.images"
ACTION = "action" ACTION = "action"
REWARD = "next.reward" REWARD = "next.reward"
+81 -27
View File
@@ -14,12 +14,12 @@
import abc import abc
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any, Optional
import draccus import draccus
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGES, OBS_STATE
from lerobot.robots import RobotConfig from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig from lerobot.teleoperators.config import TeleoperatorConfig
@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int = 30 fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict) features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict) features_map: dict[str, str] = field(default_factory=dict)
multitask_eval: bool = False
max_parallel_tasks: int = 5
@property @property
def type(self) -> str: def type(self) -> str:
@@ -44,7 +46,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
@EnvConfig.register_subclass("aloha") @EnvConfig.register_subclass("aloha")
@dataclass @dataclass
class AlohaEnv(EnvConfig): class AlohaEnv(EnvConfig):
task: str | None = "AlohaInsertion-v0" task: str = "AlohaInsertion-v0"
fps: int = 50 fps: int = 50
episode_length: int = 400 episode_length: int = 400
obs_type: str = "pixels_agent_pos" obs_type: str = "pixels_agent_pos"
@@ -82,7 +84,7 @@ class AlohaEnv(EnvConfig):
@EnvConfig.register_subclass("pusht") @EnvConfig.register_subclass("pusht")
@dataclass @dataclass
class PushtEnv(EnvConfig): class PushtEnv(EnvConfig):
task: str | None = "PushT-v0" task: str = "PushT-v0"
fps: int = 10 fps: int = 10
episode_length: int = 300 episode_length: int = 300
obs_type: str = "pixels_agent_pos" obs_type: str = "pixels_agent_pos"
@@ -124,7 +126,7 @@ class PushtEnv(EnvConfig):
@EnvConfig.register_subclass("xarm") @EnvConfig.register_subclass("xarm")
@dataclass @dataclass
class XarmEnv(EnvConfig): class XarmEnv(EnvConfig):
task: str | None = "XarmLift-v0" task: str = "XarmLift-v0"
fps: int = 15 fps: int = 15
episode_length: int = 200 episode_length: int = 200
obs_type: str = "pixels_agent_pos" obs_type: str = "pixels_agent_pos"
@@ -179,10 +181,10 @@ class EnvTransformConfig:
add_joint_velocity_to_observation: bool = False add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False add_ee_pose_to_observation: bool = False
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
resize_size: tuple[int, int] | None = None resize_size: Optional[tuple[int, int]] = None
control_time_s: float = 20.0 control_time_s: float = 20.0
fixed_reset_joint_positions: Any | None = None fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0 reset_time_s: float = 5.0
use_gripper: bool = True use_gripper: bool = True
gripper_quantization_threshold: float | None = 0.8 gripper_quantization_threshold: float | None = 0.8
@@ -195,25 +197,24 @@ class EnvTransformConfig:
class HILSerlRobotEnvConfig(EnvConfig): class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment.""" """Configuration for the HILSerlRobotEnv environment."""
robot: RobotConfig | None = None robot: Optional[RobotConfig] = None
teleop: TeleoperatorConfig | None = None teleop: Optional[TeleoperatorConfig] = None
wrapper: EnvTransformConfig | None = None wrapper: Optional[EnvTransformConfig] = None
fps: int = 10 fps: int = 10
name: str = "real_robot" name: str = "real_robot"
mode: str | None = None # Either "record", "replay", None mode: str = None # Either "record", "replay", None
repo_id: str | None = None repo_id: Optional[str] = None
dataset_root: str | None = None dataset_root: Optional[str] = None
task: str | None = "" task: str = ""
num_episodes: int = 10 # only for record mode num_episodes: int = 10 # only for record mode
episode: int = 0 episode: int = 0
device: str = "cuda" device: str = "cuda"
push_to_hub: bool = True push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None pretrained_policy_name_or_path: Optional[str] = None
reward_classifier_pretrained_path: str | None = None reward_classifier_pretrained_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success # For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0 number_of_steps_after_success: int = 0
@property
def gym_kwargs(self) -> dict: def gym_kwargs(self) -> dict:
return {} return {}
@@ -223,8 +224,9 @@ class HILSerlRobotEnvConfig(EnvConfig):
class HILEnvConfig(EnvConfig): class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment.""" """Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube" name: str = "PandaPickCube"
task: str | None = "PandaPickCubeKeyboard-v0" task: str = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True use_viewer: bool = True
gripper_penalty: float = 0.0 gripper_penalty: float = 0.0
use_gamepad: bool = True use_gamepad: bool = True
@@ -248,18 +250,18 @@ class HILEnvConfig(EnvConfig):
} }
) )
################# args from hilserlrobotenv ################# args from hilserlrobotenv
reward_classifier_pretrained_path: str | None = None reward_classifier_pretrained_path: Optional[str] = None
robot_config: RobotConfig | None = None robot_config: Optional[RobotConfig] = None
teleop_config: TeleoperatorConfig | None = None teleop_config: Optional[TeleoperatorConfig] = None
wrapper: EnvTransformConfig | None = None wrapper: Optional[EnvTransformConfig] = None
mode: str | None = None # Either "record", "replay", None mode: str = None # Either "record", "replay", None
repo_id: str | None = None repo_id: Optional[str] = None
dataset_root: str | None = None dataset_root: Optional[str] = None
num_episodes: int = 10 # only for record mode num_episodes: int = 10 # only for record mode
episode: int = 0 episode: int = 0
device: str = "cuda" device: str = "cuda"
push_to_hub: bool = True push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None pretrained_policy_name_or_path: Optional[str] = None
# For the reward classifier, to record more positive examples after a success # For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0 number_of_steps_after_success: int = 0
############################ ############################
@@ -271,3 +273,55 @@ class HILEnvConfig(EnvConfig):
"use_gamepad": self.use_gamepad, "use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty, "gripper_penalty": self.gripper_penalty,
} }
@EnvConfig.register_subclass("libero")
@dataclass
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int = 520
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
init_states: bool = True
multitask_eval: bool = True
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGE}",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGE_2}",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(360, 360, 3)
)
@property
def gym_kwargs(self) -> dict:
return {
# "task": self.task,
"obs_type": self.obs_type,
"render_mode": self.render_mode,
# "max_episode_steps": self.episode_length,
}
+311
View File
@@ -0,0 +1,311 @@
import math
import os
from collections import defaultdict
from itertools import chain
from typing import Any, Callable
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from lerobot.constants import (
OBS_IMAGE,
OBS_IMAGE_2,
)
def create_libero_envs(
task: str,
n_envs: int,
gym_kwargs: dict[str, Any] = None,
camera_name: str = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable = None,
multitask_eval: bool = True,
) -> dict[str, dict[str, Any]]:
"""
Here n_envs is per task and equal to the number of rollouts.
Returns:
dict[str, dict[str, list[LiberoEnv]]]: keys are task_suite and values are list of LiberoEnv envs.
"""
if gym_kwargs is None:
gym_kwargs = {}
if not multitask_eval:
benchmark_dict = benchmark.get_benchmark_dict()
task_suite = benchmark_dict[task]() # can also choose libero_spatial, libero_object, libero_10 etc.
tasks_id = list(range(len(task_suite.tasks)))
episode_indices = [0 for i in range(len(tasks_id))]
if len(tasks_id) == 1:
tasks_id = [tasks_id[0] for _ in range(n_envs)]
episode_indices = list(range(n_envs))
elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0:
n_repeat = n_envs // len(tasks_id)
episode_indices = []
for i in range(len(tasks_id)):
episode_indices.extend(list(range(n_repeat)))
tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id]))
elif n_envs < len(tasks_id):
tasks_id = tasks_id[:n_envs]
episode_indices = list(range(n_envs))[:n_envs]
print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
assert n_envs == len(tasks_id), (
f"len(n_envs) and tasks_id should be the same, got {n_envs} and {len(tasks_id)}"
)
return env_cls(
[
lambda i=i: LiberoEnv(
task_suite=task_suite,
task_id=tasks_id[i],
task_suite_name=task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
)
for i in range(n_envs)
]
)
else:
envs = defaultdict(dict)
benchmark_dict = benchmark.get_benchmark_dict()
task = task.split(",")
for _task in task:
task_suite = benchmark_dict[
_task
]() # can also choose libero_spatial, libero_object, libero_10 etc.
tasks_ids = list(range(len(task_suite.tasks)))
# tasks_ids = [0] # FIXME(mshukor): debug
for tasks_id in tasks_ids:
episode_indices = list(range(n_envs))
print(
f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}"
)
envs_list = [
lambda i=i: LiberoEnv(
task_suite=task_suite,
task_id=tasks_id,
task_suite_name=_task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
)
for i in range(n_envs)
]
envs[_task][tasks_id] = env_cls(envs_list)
return envs
def quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
Converts quaternion to axis-angle format.
Returns a unit vector direction scaled by its angle in radians.
Args:
quat (np.array): (x,y,z,w) vec4 float angles
Returns:
np.array: (ax,ay,az) axis-angle exponential coordinates
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def get_task_init_states(task_suite, i):
init_states_path = os.path.join(
get_libero_path("init_states"),
task_suite.tasks[i].problem_folder,
task_suite.tasks[i].init_states_file,
)
init_states = torch.load(init_states_path, weights_only=False)
return init_states
def get_libero_dummy_action():
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
return [0, 0, 0, 0, 0, 0, -1]
class LiberoEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
def __init__(
self,
task_suite,
task_id,
task_suite_name,
camera_name="agentview_image,robot0_eye_in_hand_image",
obs_type="pixels",
render_mode="rgb_array",
observation_width=256,
observation_height=256,
visualization_width=640,
visualization_height=480,
init_states=True,
episode_index=0,
):
super().__init__()
self.task_id = task_id
self.obs_type = obs_type
self.render_mode = render_mode
self.observation_width = observation_width
self.observation_height = observation_height
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.init_states = init_states
self.camera_name = camera_name.split(
","
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
self.camera_name_mapping = {
"agentview_image": OBS_IMAGE,
"robot0_eye_in_hand_image": OBS_IMAGE_2,
}
self.num_steps_wait = (
10 # Do nothing for the first few timesteps to wait for the simulator drops objects
)
self.episode_index = episode_index
self._env = self._make_envs_task(task_suite, self.task_id)
if task_suite_name == "libero_spatial":
max_steps = 220 # longest training demo has 193 steps
elif task_suite_name == "libero_object":
max_steps = 280 # longest training demo has 254 steps
elif task_suite_name == "libero_goal":
max_steps = 300 # longest training demo has 270 steps
elif task_suite_name == "libero_10":
max_steps = 520 # longest training demo has 505 steps
elif task_suite_name == "libero_90":
max_steps = 400 # longest training demo has 373 steps
self._max_episode_steps = max_steps
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
if self.obs_type == "state":
raise NotImplementedError()
elif self.obs_type == "pixels":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
}
)
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(8,),
dtype=np.float64,
),
}
)
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"][OBS_IMAGE]
return image
def _make_envs_task(self, task_suite, task_id: int = 0):
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env_args = {
"bddl_file_name": task_bddl_file,
"camera_heights": self.observation_height,
"camera_widths": self.observation_width,
}
env = OffScreenRenderEnv(**env_args)
env.reset()
if self.init_states:
init_states = get_task_init_states(
task_suite, task_id
) # for benchmarking purpose, we fix the a set of initial states FIXME(mshukor): should be in the reset()?
init_state_id = self.episode_index # episode index
env.set_init_state(init_states[init_state_id])
return env
def _format_raw_obs(self, raw_obs):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
# images = image if len(images) == 1 else images
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
if self.obs_type == "state":
raise NotImplementedError()
elif self.obs_type == "pixels":
obs = {"pixels": images.copy()}
elif self.obs_type == "pixels_agent_pos":
obs = {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
return obs
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)
self._env.seed(seed)
raw_obs = self._env.reset()
# Do nothing for the first few timesteps to wait for the simulator drops objects
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
def step(self, action):
assert action.ndim == 1
raw_obs, reward, done, info = self._env.step(action)
is_success = self._env.check_success()
terminated = done or is_success
info["is_success"] = done # is_success
observation = self._format_raw_obs(raw_obs)
if done:
self.reset()
print(self.task, self.task_id, done, is_success)
truncated = False
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
+138 -14
View File
@@ -50,12 +50,12 @@ import json
import logging import logging
import threading import threading
import time import time
from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict from dataclasses import asdict
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Callable
import einops import einops
import gymnasium as gym import gymnasium as gym
@@ -456,32 +456,56 @@ def _compile_episode_data(
return data_dict return data_dict
@parser.wrap() @parser.wrap()
def eval_main(cfg: EvalPipelineConfig): def eval(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed) set_global_seed(cfg.seed)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") log_output_dir(cfg.output_dir)
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.") logging.info("Making policy.")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
device=device,
env_cfg=cfg.env, env_cfg=cfg.env,
) )
policy.eval() policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
if cfg.env.multitask_eval:
info = eval_policy_multitask(
env,
policy,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
verbose=False,
)
# Print overall stats
print("Overall Aggregated Metrics:")
print(info["overall"]["aggregated"])
# Print per-suite stats
for task_group, task_group_info in info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
for _task_group, v in env.items():
for _env in v.values():
_env.close()
else:
info = eval_policy( info = eval_policy(
env, env,
policy, policy,
@@ -491,20 +515,120 @@ def eval_main(cfg: EvalPipelineConfig):
start_seed=cfg.seed, start_seed=cfg.seed,
) )
print(info["aggregated"]) print(info["aggregated"])
env.close()
# Save info # Save info
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2) json.dump(info, f, indent=2)
env.close()
logging.info("End of eval") logging.info("End of eval")
def eval_policy_multitask(
envs: dict[str, dict[str, gym.vector.VectorEnv]],
policy,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
max_parallel_tasks: int = 5,
verbose: bool = True,
) -> dict:
global_start = time.time()
results = {}
def main(): overall_rewards, overall_max_rewards, overall_successes = [], [], []
init_logging() overall_video_paths = []
eval_main() overall_episode_data = None
def eval_task(task_group, task_id, env):
"""Evaluates a single task in parallel."""
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
task_result = eval_policy(
env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose
)
per_episode = task_result["per_episode"]
return {
"task_group": task_group,
"task_id": task_id,
"sum_rewards": [ep["sum_reward"] for ep in per_episode],
"max_rewards": [ep["max_reward"] for ep in per_episode],
"successes": [ep["success"] for ep in per_episode],
"video_paths": task_result.get("video_paths", []),
}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
future_to_task = {
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
for task_group, tasks in envs.items()
for task_id, env in tasks.items()
}
task_group_results = {}
for future in concurrent.futures.as_completed(future_to_task):
task_result = future.result()
task_group = task_result["task_group"]
if task_group not in task_group_results:
task_group_results[task_group] = {
"sum_rewards": [],
"max_rewards": [],
"successes": [],
"video_paths": [],
}
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
task_group_results[task_group]["successes"].extend(task_result["successes"])
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
# Process results per task group
for task_group, data in task_group_results.items():
suite_rewards = data["sum_rewards"]
suite_max_rewards = data["max_rewards"]
suite_successes = data["successes"]
suite_video_paths = data["video_paths"]
suite_eval_s = time.time() - global_start
suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards))
results[task_group] = {
"aggregated": {
"avg_sum_reward": float(np.nanmean(suite_rewards)),
"avg_max_reward": float(np.nanmean(suite_max_rewards)),
"pc_success": float(np.nanmean(suite_successes) * 100),
"eval_s": suite_eval_s,
"eval_ep_s": suite_eval_ep_s,
},
"video_paths": suite_video_paths,
"episodes": None, # Modify if episode data is needed
}
overall_rewards.extend(suite_rewards)
overall_max_rewards.extend(suite_max_rewards)
overall_successes.extend(suite_successes)
overall_video_paths.extend(suite_video_paths)
# Global metrics
global_eval_s = time.time() - global_start
global_eval_ep_s = global_eval_s / max(1, len(overall_rewards))
results["overall"] = {
"aggregated": {
"avg_sum_reward": float(np.nanmean(overall_rewards)),
"avg_max_reward": float(np.nanmean(overall_max_rewards)),
"pc_success": float(np.nanmean(overall_successes) * 100),
"eval_s": global_eval_s,
"eval_ep_s": global_eval_ep_s,
},
"video_paths": overall_video_paths,
"episodes": overall_episode_data,
}
return results
if __name__ == "__main__": if __name__ == "__main__":
main() init_logging()
eval_main()
+19 -1
View File
@@ -34,7 +34,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy from lerobot.scripts.eval import eval_policy, eval_policy_multitask
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import ( from lerobot.utils.train_utils import (
@@ -252,6 +252,24 @@ def train(cfg: TrainPipelineConfig):
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
): ):
if cfg.env.multitask_eval:
eval_info = eval_policy_multitask(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated_results = eval_info["overall"]["aggregated"]
# Print per-suite stats
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
else:
eval_info = eval_policy( eval_info = eval_policy(
eval_env, eval_env,
policy, policy,