From 21a961ecbb93be965e55fcfc4c93fa9200f4be75 Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Tue, 5 Aug 2025 23:55:08 -0400
Subject: [PATCH] add libero

---
 examples/5_train_libero.sh   |  59 +++++++
 src/lerobot/constants.py     |   1 +
 src/lerobot/envs/configs.py  | 108 +++++++++---
 src/lerobot/envs/libero.py   | 311 +++++++++++++++++++++++++++++++++++
 src/lerobot/scripts/eval.py  | 170 ++++++++++++++++---
 src/lerobot/scripts/train.py |  36 +++-
 6 files changed, 626 insertions(+), 59 deletions(-)
 create mode 100644 examples/5_train_libero.sh
 create mode 100644 src/lerobot/envs/libero.py

diff --git a/examples/5_train_libero.sh b/examples/5_train_libero.sh
new file mode 100644
index 000000000..0b8633cd6
--- /dev/null
+++ b/examples/5_train_libero.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+# config
+REPO_ID=physical-intelligence/libero
+TASK=libero_10
+OUTPUT_DIR=./outputs/train_run/smolvla2_libero
+
+# clean previous run
+rm -rf "$OUTPUT_DIR"
+
+# training params
+STEPS=100000
+BATCH_SIZE=4
+EVAL_FREQ=2000
+SAVE_FREQ=10000
+NUM_WORKERS=0
+
+# model params
+POLICY=smolvla
+USE_AMP=false
+OPTIMIZER_LR=1e-4
+PEFT_METHOD=lora
+LOAD_VLM_WEIGHTS=true
+VLM_REPO_ID=None
+MAX_ACTION_DIM=32
+MAX_STATE_DIM=32
+
+# dataset/image params
+USE_IMAGENET_STATS=false
+ENABLE_IMG_TRANSFORM=true
+MAX_NUM_IMAGES=2
+MAX_IMAGE_DIM=1024
+
+echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
+# launch
+MUJOCO_GL=egl python src/lerobot/scripts/train.py \
+    --policy.type=$POLICY \
+    --dataset.repo_id=$REPO_ID \
+    --env.type=libero \
+    --env.task=$TASK \
+    --output_dir=$OUTPUT_DIR \
+    --steps=$STEPS \
+    --batch_size=$BATCH_SIZE \
+    --eval_freq=$EVAL_FREQ \
+    --save_freq=$SAVE_FREQ \
+    --num_workers=$NUM_WORKERS \
+    --policy.max_action_dim=$MAX_ACTION_DIM \
+    --policy.max_state_dim=$MAX_STATE_DIM \
+    --policy.use_amp=$USE_AMP \
+    --policy.optimizer_lr=$OPTIMIZER_LR \
+    --policy.peft_method=$PEFT_METHOD \
+    --policy.load_vlm_weights=$LOAD_VLM_WEIGHTS \
+    --policy.repo_id=$VLM_REPO_ID \
+    --dataset.use_imagenet_stats=$USE_IMAGENET_STATS \
+    --dataset.image_transforms.enable=$ENABLE_IMG_TRANSFORM \
+    --dataset.max_num_images=$MAX_NUM_IMAGES \
+    --dataset.max_image_dim=$MAX_IMAGE_DIM
+    # Optional flags, currently disabled (no trailing backslash above, so these stay comments):
+    # --policy.exclude_image_keys=wrist_image
+    # --policy.use_env_state=false

diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py
index 30777239e..bc5b2013c 100644
--- a/src/lerobot/constants.py
+++ b/src/lerobot/constants.py
@@ -20,6 +20,7 @@ from huggingface_hub.constants import HF_HOME
 OBS_ENV_STATE = "observation.environment_state"
 OBS_STATE = "observation.state"
 OBS_IMAGE = "observation.image"
+OBS_IMAGE_2 = "observation.image2"
 OBS_IMAGES = "observation.images"
 ACTION = "action"
 REWARD = "next.reward"
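The new OBS_IMAGE_2 constant gives the wrist camera its own canonical policy key, so two camera
streams can coexist in one observation dict. A minimal sketch of the key-renaming idea behind
features_map, assuming this patch is applied; remap_observation is a hypothetical helper written
for illustration, not part of this patch or of lerobot:

    from lerobot.constants import OBS_IMAGE, OBS_IMAGE_2

    def remap_observation(obs: dict, features_map: dict[str, str]) -> dict:
        # Rename raw env keys to the policy's canonical keys; pass unknown keys through.
        return {features_map.get(k, k): v for k, v in obs.items()}

    features_map = {
        "pixels/agentview_image": OBS_IMAGE,             # main camera
        "pixels/robot0_eye_in_hand_image": OBS_IMAGE_2,  # wrist camera
    }
    obs = {"pixels/agentview_image": "frame_a", "pixels/robot0_eye_in_hand_image": "frame_b"}
    print(remap_observation(obs, features_map))
    # {'observation.image': 'frame_a', 'observation.image2': 'frame_b'}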
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index 35797c6ed..f815ca3b3 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -14,12 +14,12 @@
 import abc
 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Optional
 
 import draccus
 
 from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGES, OBS_STATE
 from lerobot.robots import RobotConfig
 from lerobot.teleoperators.config import TeleoperatorConfig
 
@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     fps: int = 30
     features: dict[str, PolicyFeature] = field(default_factory=dict)
     features_map: dict[str, str] = field(default_factory=dict)
+    multitask_eval: bool = False
+    max_parallel_tasks: int = 5
 
     @property
     def type(self) -> str:
@@ -44,7 +46,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
 @EnvConfig.register_subclass("aloha")
 @dataclass
 class AlohaEnv(EnvConfig):
-    task: str | None = "AlohaInsertion-v0"
+    task: str = "AlohaInsertion-v0"
     fps: int = 50
     episode_length: int = 400
     obs_type: str = "pixels_agent_pos"
@@ -82,7 +84,7 @@ class AlohaEnv(EnvConfig):
 @EnvConfig.register_subclass("pusht")
 @dataclass
 class PushtEnv(EnvConfig):
-    task: str | None = "PushT-v0"
+    task: str = "PushT-v0"
     fps: int = 10
     episode_length: int = 300
     obs_type: str = "pixels_agent_pos"
@@ -124,7 +126,7 @@ class PushtEnv(EnvConfig):
 @EnvConfig.register_subclass("xarm")
 @dataclass
 class XarmEnv(EnvConfig):
-    task: str | None = "XarmLift-v0"
+    task: str = "XarmLift-v0"
     fps: int = 15
     episode_length: int = 200
     obs_type: str = "pixels_agent_pos"
@@ -179,10 +181,10 @@ class EnvTransformConfig:
     add_joint_velocity_to_observation: bool = False
     add_current_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
-    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
-    resize_size: tuple[int, int] | None = None
+    crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
+    resize_size: Optional[tuple[int, int]] = None
     control_time_s: float = 20.0
-    fixed_reset_joint_positions: Any | None = None
+    fixed_reset_joint_positions: Optional[Any] = None
     reset_time_s: float = 5.0
     use_gripper: bool = True
     gripper_quantization_threshold: float | None = 0.8
@@ -195,25 +197,24 @@ class EnvTransformConfig:
 class HILSerlRobotEnvConfig(EnvConfig):
     """Configuration for the HILSerlRobotEnv environment."""
 
-    robot: RobotConfig | None = None
-    teleop: TeleoperatorConfig | None = None
-    wrapper: EnvTransformConfig | None = None
+    robot: Optional[RobotConfig] = None
+    teleop: Optional[TeleoperatorConfig] = None
+    wrapper: Optional[EnvTransformConfig] = None
     fps: int = 10
     name: str = "real_robot"
-    mode: str | None = None  # Either "record", "replay", None
-    repo_id: str | None = None
-    dataset_root: str | None = None
-    task: str | None = ""
+    mode: Optional[str] = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    task: str = ""
     num_episodes: int = 10  # only for record mode
     episode: int = 0
     device: str = "cuda"
     push_to_hub: bool = True
-    pretrained_policy_name_or_path: str | None = None
-    reward_classifier_pretrained_path: str | None = None
+    pretrained_policy_name_or_path: Optional[str] = None
+    reward_classifier_pretrained_path: Optional[str] = None
     # For the reward classifier, to record more positive examples after a success
     number_of_steps_after_success: int = 0
 
     @property
     def gym_kwargs(self) -> dict:
         return {}
@@ -223,8 +224,9 @@
 class HILEnvConfig(EnvConfig):
     """Configuration for the HIL environment."""
 
+    type: str = "hil"
     name: str = "PandaPickCube"
-    task: str | None = "PandaPickCubeKeyboard-v0"
+    task: str = "PandaPickCubeKeyboard-v0"
     use_viewer: bool = True
     gripper_penalty: float = 0.0
     use_gamepad: bool = True
@@ -248,18 +250,18 @@ class HILEnvConfig(EnvConfig):
         }
     )
     ################# args from hilserlrobotenv
-    reward_classifier_pretrained_path: str | None = None
-    robot_config: RobotConfig | None = None
-    teleop_config: TeleoperatorConfig | None = None
-    wrapper: EnvTransformConfig | None = None
-    mode: str | None = None  # Either "record", "replay", None
-    repo_id: str | None = None
-    dataset_root: str | None = None
+    reward_classifier_pretrained_path: Optional[str] = None
+    robot_config: Optional[RobotConfig] = None
+    teleop_config: Optional[TeleoperatorConfig] = None
+    wrapper: Optional[EnvTransformConfig] = None
+    mode: Optional[str] = None  # Either "record", "replay", None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
     num_episodes: int = 10  # only for record mode
     episode: int = 0
     device: str = "cuda"
     push_to_hub: bool = True
-    pretrained_policy_name_or_path: str | None = None
+    pretrained_policy_name_or_path: Optional[str] = None
     # For the reward classifier, to record more positive examples after a success
     number_of_steps_after_success: int = 0
     ############################
@@ -271,3 +273,55 @@ class HILEnvConfig(EnvConfig):
             "use_gamepad": self.use_gamepad,
             "gripper_penalty": self.gripper_penalty,
         }
+
+
+@EnvConfig.register_subclass("libero")
+@dataclass
+class LiberoEnv(EnvConfig):
+    task: str = "libero_10"  # can also choose libero_spatial, libero_object, etc.
+    fps: int = 30
+    episode_length: int = 520
+    obs_type: str = "pixels_agent_pos"
+    render_mode: str = "rgb_array"
+    camera_name: str = "agentview_image,robot0_eye_in_hand_image"
+    init_states: bool = True
+    multitask_eval: bool = True
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "agent_pos": OBS_STATE,
+            "pixels/agentview_image": OBS_IMAGE,
+            "pixels/robot0_eye_in_hand_image": OBS_IMAGE_2,
+        }
+    )
+
+    def __post_init__(self):
+        if self.obs_type == "pixels":
+            self.features["pixels/agentview_image"] = PolicyFeature(
+                type=FeatureType.VISUAL, shape=(360, 360, 3)
+            )
+            self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+                type=FeatureType.VISUAL, shape=(360, 360, 3)
+            )
+        elif self.obs_type == "pixels_agent_pos":
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
+            self.features["pixels/agentview_image"] = PolicyFeature(
+                type=FeatureType.VISUAL, shape=(360, 360, 3)
+            )
+            self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+                type=FeatureType.VISUAL, shape=(360, 360, 3)
+            )
+
+    @property
+    def gym_kwargs(self) -> dict:
+        return {
+            "obs_type": self.obs_type,
+            "render_mode": self.render_mode,
+        }
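A quick sanity check of the new config, assuming this patch is applied; the printed values follow
from the defaults and __post_init__ above (the comma-separated task string matches the multitask
path added below):

    from lerobot.envs.configs import LiberoEnv

    cfg = LiberoEnv(task="libero_object,libero_goal", obs_type="pixels_agent_pos")
    print(cfg.gym_kwargs)
    # {'obs_type': 'pixels_agent_pos', 'render_mode': 'rgb_array'}
    print(sorted(cfg.features))
    # ['action', 'agent_pos', 'pixels/agentview_image', 'pixels/robot0_eye_in_hand_image']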
+ """ + if gym_kwargs is None: + gym_kwargs = {} + + if not multitask_eval: + benchmark_dict = benchmark.get_benchmark_dict() + task_suite = benchmark_dict[task]() # can also choose libero_spatial, libero_object, libero_10 etc. + tasks_id = list(range(len(task_suite.tasks))) + episode_indices = [0 for i in range(len(tasks_id))] + if len(tasks_id) == 1: + tasks_id = [tasks_id[0] for _ in range(n_envs)] + episode_indices = list(range(n_envs)) + elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0: + n_repeat = n_envs // len(tasks_id) + episode_indices = [] + for i in range(len(tasks_id)): + episode_indices.extend(list(range(n_repeat))) + tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id])) + elif n_envs < len(tasks_id): + tasks_id = tasks_id[:n_envs] + episode_indices = list(range(n_envs))[:n_envs] + print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}") + print(f"Creating Libero envs with task ids {tasks_id} from suite {task}") + assert n_envs == len(tasks_id), ( + f"len(n_envs) and tasks_id should be the same, got {n_envs} and {len(tasks_id)}" + ) + return env_cls( + [ + lambda i=i: LiberoEnv( + task_suite=task_suite, + task_id=tasks_id[i], + task_suite_name=task, + camera_name=camera_name, + init_states=init_states, + episode_index=episode_indices[i], + **gym_kwargs, + ) + for i in range(n_envs) + ] + ) + else: + envs = defaultdict(dict) + benchmark_dict = benchmark.get_benchmark_dict() + task = task.split(",") + for _task in task: + task_suite = benchmark_dict[ + _task + ]() # can also choose libero_spatial, libero_object, libero_10 etc. + tasks_ids = list(range(len(task_suite.tasks))) + # tasks_ids = [0] # FIXME(mshukor): debug + for tasks_id in tasks_ids: + episode_indices = list(range(n_envs)) + print( + f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}" + ) + envs_list = [ + lambda i=i: LiberoEnv( + task_suite=task_suite, + task_id=tasks_id, + task_suite_name=_task, + camera_name=camera_name, + init_states=init_states, + episode_index=episode_indices[i], + **gym_kwargs, + ) + for i in range(n_envs) + ] + envs[_task][tasks_id] = env_cls(envs_list) + return envs + + +def quat2axisangle(quat): + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + + Converts quaternion to axis-angle format. + Returns a unit vector direction scaled by its angle in radians. 
+ + Args: + quat (np.array): (x,y,z,w) vec4 float angles + + Returns: + np.array: (ax,ay,az) axis-angle exponential coordinates + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) / den + + +def get_task_init_states(task_suite, i): + init_states_path = os.path.join( + get_libero_path("init_states"), + task_suite.tasks[i].problem_folder, + task_suite.tasks[i].init_states_file, + ) + init_states = torch.load(init_states_path, weights_only=False) + return init_states + + +def get_libero_dummy_action(): + """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" + return [0, 0, 0, 0, 0, 0, -1] + + +class LiberoEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task_suite, + task_id, + task_suite_name, + camera_name="agentview_image,robot0_eye_in_hand_image", + obs_type="pixels", + render_mode="rgb_array", + observation_width=256, + observation_height=256, + visualization_width=640, + visualization_height=480, + init_states=True, + episode_index=0, + ): + super().__init__() + self.task_id = task_id + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.init_states = init_states + self.camera_name = camera_name.split( + "," + ) # agentview_image (main) or robot0_eye_in_hand_image (wrist) + self.camera_name_mapping = { + "agentview_image": OBS_IMAGE, + "robot0_eye_in_hand_image": OBS_IMAGE_2, + } + self.num_steps_wait = ( + 10 # Do nothing for the first few timesteps to wait for the simulator drops objects + ) + self.episode_index = episode_index + + self._env = self._make_envs_task(task_suite, self.task_id) + if task_suite_name == "libero_spatial": + max_steps = 220 # longest training demo has 193 steps + elif task_suite_name == "libero_object": + max_steps = 280 # longest training demo has 254 steps + elif task_suite_name == "libero_goal": + max_steps = 300 # longest training demo has 270 steps + elif task_suite_name == "libero_10": + max_steps = 520 # longest training demo has 505 steps + elif task_suite_name == "libero_90": + max_steps = 400 # longest training demo has 373 steps + self._max_episode_steps = max_steps + + images = {} + for cam in self.camera_name: + images[self.camera_name_mapping[cam]] = spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + + if self.obs_type == "state": + raise NotImplementedError() + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(8,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) + + def render(self): + raw_obs = self._env.env._get_observations() + image = self._format_raw_obs(raw_obs)["pixels"][OBS_IMAGE] + return image + + def _make_envs_task(self, task_suite, task_id: int = 0): + task = task_suite.get_task(task_id) + self.task = 
+
+
+def get_task_init_states(task_suite, i):
+    init_states_path = os.path.join(
+        get_libero_path("init_states"),
+        task_suite.tasks[i].problem_folder,
+        task_suite.tasks[i].init_states_file,
+    )
+    init_states = torch.load(init_states_path, weights_only=False)
+    return init_states
+
+
+def get_libero_dummy_action():
+    """Get a dummy/no-op action, used to roll out the simulation while the robot does nothing."""
+    return [0, 0, 0, 0, 0, 0, -1]
+
+
+class LiberoEnv(gym.Env):
+    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
+
+    def __init__(
+        self,
+        task_suite,
+        task_id,
+        task_suite_name,
+        camera_name="agentview_image,robot0_eye_in_hand_image",
+        obs_type="pixels",
+        render_mode="rgb_array",
+        observation_width=256,
+        observation_height=256,
+        visualization_width=640,
+        visualization_height=480,
+        init_states=True,
+        episode_index=0,
+    ):
+        super().__init__()
+        self.task_id = task_id
+        self.obs_type = obs_type
+        self.render_mode = render_mode
+        self.observation_width = observation_width
+        self.observation_height = observation_height
+        self.visualization_width = visualization_width
+        self.visualization_height = visualization_height
+        self.init_states = init_states
+        # agentview_image (main) or robot0_eye_in_hand_image (wrist)
+        self.camera_name = camera_name.split(",")
+        self.camera_name_mapping = {
+            "agentview_image": OBS_IMAGE,
+            "robot0_eye_in_hand_image": OBS_IMAGE_2,
+        }
+        # Do nothing for the first few timesteps, to wait for the simulator to drop objects into place.
+        self.num_steps_wait = 10
+        self.episode_index = episode_index
+
+        self._env = self._make_envs_task(task_suite, self.task_id)
+        if task_suite_name == "libero_spatial":
+            max_steps = 220  # longest training demo has 193 steps
+        elif task_suite_name == "libero_object":
+            max_steps = 280  # longest training demo has 254 steps
+        elif task_suite_name == "libero_goal":
+            max_steps = 300  # longest training demo has 270 steps
+        elif task_suite_name == "libero_10":
+            max_steps = 520  # longest training demo has 505 steps
+        elif task_suite_name == "libero_90":
+            max_steps = 400  # longest training demo has 373 steps
+        else:
+            raise ValueError(f"Unknown task suite name: {task_suite_name}")
+        self._max_episode_steps = max_steps
+
+        images = {}
+        for cam in self.camera_name:
+            images[self.camera_name_mapping[cam]] = spaces.Box(
+                low=0,
+                high=255,
+                shape=(self.observation_height, self.observation_width, 3),
+                dtype=np.uint8,
+            )
+
+        if self.obs_type == "state":
+            raise NotImplementedError()
+        elif self.obs_type == "pixels":
+            self.observation_space = spaces.Dict(
+                {
+                    "pixels": spaces.Dict(images),
+                }
+            )
+        elif self.obs_type == "pixels_agent_pos":
+            self.observation_space = spaces.Dict(
+                {
+                    "pixels": spaces.Dict(images),
+                    "agent_pos": spaces.Box(
+                        low=-1000.0,
+                        high=1000.0,
+                        shape=(8,),
+                        dtype=np.float64,
+                    ),
+                }
+            )
+
+        self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
+
+    def render(self):
+        raw_obs = self._env.env._get_observations()
+        image = self._format_raw_obs(raw_obs)["pixels"][OBS_IMAGE]
+        return image
+
+    def _make_envs_task(self, task_suite, task_id: int = 0):
+        task = task_suite.get_task(task_id)
+        self.task = task.name
+        self.task_description = task.language
+        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
+
+        env_args = {
+            "bddl_file_name": task_bddl_file,
+            "camera_heights": self.observation_height,
+            "camera_widths": self.observation_width,
+        }
+        env = OffScreenRenderEnv(**env_args)
+        env.reset()
+        if self.init_states:
+            # For benchmarking purposes, we fix a set of initial states.
+            # FIXME(mshukor): should this be in reset()?
+            init_states = get_task_init_states(task_suite, task_id)
+            init_state_id = self.episode_index
+            env.set_init_state(init_states[init_state_id])
+
+        return env
+
+    def _format_raw_obs(self, raw_obs):
+        images = {}
+        for camera_name in self.camera_name:
+            image = raw_obs[camera_name]
+            image = image[::-1, ::-1]  # rotate 180 degrees
+            images[self.camera_name_mapping[camera_name]] = image
+        state = np.concatenate(
+            (
+                raw_obs["robot0_eef_pos"],
+                quat2axisangle(raw_obs["robot0_eef_quat"]),
+                raw_obs["robot0_gripper_qpos"],
+            )
+        )
+        agent_pos = state
+        if self.obs_type == "state":
+            raise NotImplementedError()
+        elif self.obs_type == "pixels":
+            obs = {"pixels": images.copy()}
+        elif self.obs_type == "pixels_agent_pos":
+            obs = {
+                "pixels": images.copy(),
+                "agent_pos": agent_pos,
+            }
+        return obs
+
+    def reset(self, seed=None, **kwargs):
+        super().reset(seed=seed)
+
+        if seed is not None:
+            self._env.seed(seed)
+        raw_obs = self._env.reset()
+        # Do nothing for the first few timesteps, to wait for the simulator to drop objects into place.
+        for _ in range(self.num_steps_wait):
+            raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
+        observation = self._format_raw_obs(raw_obs)
+        info = {"is_success": False}
+        return observation, info
+
+    def step(self, action):
+        assert action.ndim == 1
+        raw_obs, reward, done, info = self._env.step(action)
+
+        is_success = self._env.check_success()
+        terminated = done or is_success
+        info["is_success"] = is_success
+
+        observation = self._format_raw_obs(raw_obs)
+        if done:
+            # Auto-reset so the next step starts a fresh episode; the returned observation
+            # still belongs to the episode that just finished.
+            self.reset()
+        truncated = False
+        return observation, reward, terminated, truncated, info
+
+    def close(self):
+        self._env.close()
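A minimal smoke test for the new environment, assuming LIBERO and its assets are installed and
offscreen rendering works on the machine; the suite name and task id are arbitrary examples:

    import numpy as np
    from libero.libero import benchmark
    from lerobot.envs.libero import LiberoEnv, get_libero_dummy_action

    suite = benchmark.get_benchmark_dict()["libero_object"]()
    env = LiberoEnv(
        task_suite=suite, task_id=0, task_suite_name="libero_object",
        obs_type="pixels_agent_pos", init_states=True, episode_index=0,
    )
    obs, info = env.reset(seed=0)
    for _ in range(5):
        action = np.asarray(get_libero_dummy_action(), dtype=np.float32)  # no-op rollout
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
    env.close()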
diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py
index 6a6c02a24..a56c4c3b5 100644
--- a/src/lerobot/scripts/eval.py
+++ b/src/lerobot/scripts/eval.py
@@ -50,12 +50,13 @@
+import concurrent.futures
 import json
 import logging
 import threading
 import time
 from collections.abc import Callable
 from contextlib import nullcontext
 from copy import deepcopy
 from dataclasses import asdict
 from pathlib import Path
 from pprint import pformat
 
 import einops
 import gymnasium as gym
@@ -456,55 +457,175 @@
     return data_dict
 
 
 @parser.wrap()
 def eval_main(cfg: EvalPipelineConfig):
     logging.info(pformat(asdict(cfg)))
 
     # Check device is available
     device = get_safe_torch_device(cfg.policy.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
 
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
 
     logging.info("Making environment.")
     env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
 
     logging.info("Making policy.")
 
     policy = make_policy(
         cfg=cfg.policy,
         env_cfg=cfg.env,
     )
     policy.eval()
 
     with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
-        info = eval_policy(
-            env,
-            policy,
-            cfg.eval.n_episodes,
-            max_episodes_rendered=10,
-            videos_dir=Path(cfg.output_dir) / "videos",
-            start_seed=cfg.seed,
-        )
-        print(info["aggregated"])
+        if cfg.env.multitask_eval:
+            info = eval_policy_multitask(
+                env,
+                policy,
+                cfg.eval.n_episodes,
+                max_episodes_rendered=10,
+                videos_dir=Path(cfg.output_dir) / "videos",
+                start_seed=cfg.seed,
+                max_parallel_tasks=cfg.env.max_parallel_tasks,
+            )
+            # Print overall stats
+            print("Overall Aggregated Metrics:")
+            print(info["overall"]["aggregated"])
+
+            # Print per-suite stats
+            for task_group, task_group_info in info.items():
+                if task_group == "overall":
+                    continue  # Skip the overall stats since we already printed them
+                print(f"\nAggregated Metrics for {task_group}:")
+                print(task_group_info["aggregated"])
+            for tasks in env.values():
+                for _env in tasks.values():
+                    _env.close()
+        else:
+            info = eval_policy(
+                env,
+                policy,
+                cfg.eval.n_episodes,
+                max_episodes_rendered=10,
+                videos_dir=Path(cfg.output_dir) / "videos",
+                start_seed=cfg.seed,
+            )
+            print(info["aggregated"])
+            env.close()
 
     # Save info
     with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
         json.dump(info, f, indent=2)
 
-    env.close()
-
     logging.info("End of eval")
 
 
+def eval_policy_multitask(
+    envs: dict[str, dict[int, gym.vector.VectorEnv]],
+    policy,
+    n_episodes: int,
+    max_episodes_rendered: int = 0,
+    videos_dir: Path | None = None,
+    return_episode_data: bool = False,
+    start_seed: int | None = None,
+    max_parallel_tasks: int = 5,
+) -> dict:
+    """Evaluate a policy on several task suites, running up to max_parallel_tasks tasks concurrently.
+
+    Returns a dict with one entry per task suite plus an "overall" entry, each holding aggregated metrics.
+    """
+    global_start = time.time()
+    results = {}
+
+    overall_rewards, overall_max_rewards, overall_successes = [], [], []
+    overall_video_paths = []
+    overall_episode_data = None
+
+    def eval_task(task_group, task_id, env):
+        """Evaluate a single task; called from the thread pool."""
+        print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
+        # Give each task its own video dir so concurrent tasks do not overwrite each other's files.
+        task_videos_dir = videos_dir / f"{task_group}_task_{task_id}" if videos_dir is not None else None
+        task_result = eval_policy(
+            env,
+            policy,
+            n_episodes,
+            max_episodes_rendered,
+            task_videos_dir,
+            return_episode_data,
+            start_seed,
+        )
+
+        per_episode = task_result["per_episode"]
+        return {
+            "task_group": task_group,
+            "task_id": task_id,
+            "sum_rewards": [ep["sum_reward"] for ep in per_episode],
+            "max_rewards": [ep["max_reward"] for ep in per_episode],
+            "successes": [ep["success"] for ep in per_episode],
+            "video_paths": task_result.get("video_paths", []),
+        }
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
+        future_to_task = {
+            executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
+            for task_group, tasks in envs.items()
+            for task_id, env in tasks.items()
+        }
+
+        task_group_results = {}
+
+        for future in concurrent.futures.as_completed(future_to_task):
+            task_result = future.result()
+            task_group = task_result["task_group"]
+
+            if task_group not in task_group_results:
+                task_group_results[task_group] = {
+                    "sum_rewards": [],
+                    "max_rewards": [],
+                    "successes": [],
+                    "video_paths": [],
+                }
+
+            task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
+            task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
+            task_group_results[task_group]["successes"].extend(task_result["successes"])
+            task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
+
+    # Process results per task group
+    for task_group, data in task_group_results.items():
+        suite_rewards = data["sum_rewards"]
+        suite_max_rewards = data["max_rewards"]
+        suite_successes = data["successes"]
+        suite_video_paths = data["video_paths"]
+
+        suite_eval_s = time.time() - global_start
+        suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards))
+
+        results[task_group] = {
+            "aggregated": {
+                "avg_sum_reward": float(np.nanmean(suite_rewards)),
+                "avg_max_reward": float(np.nanmean(suite_max_rewards)),
+                "pc_success": float(np.nanmean(suite_successes) * 100),
+                "eval_s": suite_eval_s,
+                "eval_ep_s": suite_eval_ep_s,
+            },
+            "video_paths": suite_video_paths,
+            "episodes": None,  # Modify if episode data is needed
+        }
+
+        overall_rewards.extend(suite_rewards)
+        overall_max_rewards.extend(suite_max_rewards)
+        overall_successes.extend(suite_successes)
+        overall_video_paths.extend(suite_video_paths)
+
+    # Global metrics
+    global_eval_s = time.time() - global_start
+    global_eval_ep_s = global_eval_s / max(1, len(overall_rewards))
+
+    results["overall"] = {
+        "aggregated": {
+            "avg_sum_reward": float(np.nanmean(overall_rewards)),
+            "avg_max_reward": float(np.nanmean(overall_max_rewards)),
+            "pc_success": float(np.nanmean(overall_successes) * 100),
+            "eval_s": global_eval_s,
+            "eval_ep_s": global_eval_ep_s,
+        },
+        "video_paths": overall_video_paths,
+        "episodes": overall_episode_data,
+    }
+
+    return results
+
+
-def main():
-    init_logging()
-    eval_main()
-
-
 if __name__ == "__main__":
-    main()
+    init_logging()
+    eval_main()
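The structure of the dict returned by eval_policy_multitask is easiest to see with toy numbers;
this standalone sketch mirrors the aggregation above (suite names and success values are
illustrative, only pc_success is shown):

    import numpy as np

    task_group_results = {
        "libero_object": {"successes": [True, False, True, True]},
        "libero_goal": {"successes": [False, False, True, True]},
    }

    results, overall = {}, []
    for task_group, data in task_group_results.items():
        results[task_group] = {"aggregated": {"pc_success": float(np.nanmean(data["successes"]) * 100)}}
        overall.extend(data["successes"])
    results["overall"] = {"aggregated": {"pc_success": float(np.nanmean(overall) * 100)}}

    print(results["libero_object"]["aggregated"])  # {'pc_success': 75.0}
    print(results["overall"]["aggregated"])        # {'pc_success': 62.5}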
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"]) + + # Process results per task group + for task_group, data in task_group_results.items(): + suite_rewards = data["sum_rewards"] + suite_max_rewards = data["max_rewards"] + suite_successes = data["successes"] + suite_video_paths = data["video_paths"] + + suite_eval_s = time.time() - global_start + suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards)) + + results[task_group] = { + "aggregated": { + "avg_sum_reward": float(np.nanmean(suite_rewards)), + "avg_max_reward": float(np.nanmean(suite_max_rewards)), + "pc_success": float(np.nanmean(suite_successes) * 100), + "eval_s": suite_eval_s, + "eval_ep_s": suite_eval_ep_s, + }, + "video_paths": suite_video_paths, + "episodes": None, # Modify if episode data is needed + } + + overall_rewards.extend(suite_rewards) + overall_max_rewards.extend(suite_max_rewards) + overall_successes.extend(suite_successes) + overall_video_paths.extend(suite_video_paths) + + # Global metrics + global_eval_s = time.time() - global_start + global_eval_ep_s = global_eval_s / max(1, len(overall_rewards)) + + results["overall"] = { + "aggregated": { + "avg_sum_reward": float(np.nanmean(overall_rewards)), + "avg_max_reward": float(np.nanmean(overall_max_rewards)), + "pc_success": float(np.nanmean(overall_successes) * 100), + "eval_s": global_eval_s, + "eval_ep_s": global_eval_ep_s, + }, + "video_paths": overall_video_paths, + "episodes": overall_episode_data, + } + + return results if __name__ == "__main__": - main() + init_logging() + eval_main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 235352cd8..9b287d957 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -34,7 +34,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters -from lerobot.scripts.eval import eval_policy +from lerobot.scripts.eval import eval_policy, eval_policy_multitask from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( @@ -252,14 +252,32 @@ def train(cfg: TrainPipelineConfig): torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): - eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) + if cfg.env.multitask_eval: + eval_info = eval_policy_multitask( + eval_env, + policy, + cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + ) + aggregated_results = eval_info["overall"]["aggregated"] + # Print per-suite stats + for task_group, task_group_info in eval_info.items(): + if task_group == "overall": + continue # Skip the overall stats since we already printed it + print(f"\nAggregated Metrics for {task_group}:") + print(task_group_info["aggregated"]) + else: + eval_info = eval_policy( + eval_env, + policy, + cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + ) eval_metrics = { "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),