mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
add libero
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
#!/bin/bash
# Train a SmolVLA policy on the LIBERO dataset and evaluate it in the LIBERO
# environment. The training script is referenced by a relative path, so this
# must be launched from the directory that contains `src/`.

# Abort on the first failing command, on use of an unset variable, and on a
# failure anywhere in a pipeline.
set -euo pipefail

# config
REPO_ID=physical-intelligence/libero
TASK=libero_10
OUTPUT_DIR=./outputs/train_run/smolvla2_libero

# clean previous run
# `${VAR:?}` aborts instead of expanding to an empty string, and the quoting
# keeps a path with spaces or glob characters as a single argument, so this
# can never delete anything other than the configured output directory.
rm -rf -- "${OUTPUT_DIR:?OUTPUT_DIR must not be empty}"

# training params
STEPS=100000
BATCH_SIZE=4
EVAL_FREQ=2000
SAVE_FREQ=10000
NUM_WORKERS=0

# model params
POLICY=smolvla
USE_AMP=false
OPTIMIZER_LR=1e-4
PEFT_METHOD=lora
LOAD_VLM_WEIGHTS=true
# NOTE(review): this is forwarded verbatim as the literal string "None";
# confirm the config parser maps it to a null repo id.
VLM_REPO_ID=None
MAX_ACTION_DIM=32
MAX_STATE_DIM=32

# dataset/image params
USE_IMAGENET_STATS=false
ENABLE_IMG_TRANSFORM=true
MAX_NUM_IMAGES=2
MAX_IMAGE_DIM=1024

echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"

# launch
# The arguments live in an array so that optional flags can be toggled by
# (un)commenting a line without having to maintain trailing backslashes.
train_args=(
    --policy.type="$POLICY"
    --dataset.repo_id="$REPO_ID"
    --env.type=libero
    --env.task="$TASK"
    --output_dir="$OUTPUT_DIR"
    --steps="$STEPS"
    --batch_size="$BATCH_SIZE"
    --eval_freq="$EVAL_FREQ"
    --save_freq="$SAVE_FREQ"
    --num_workers="$NUM_WORKERS"
    --policy.max_action_dim="$MAX_ACTION_DIM"
    --policy.max_state_dim="$MAX_STATE_DIM"
    --policy.use_amp="$USE_AMP"
    --policy.optimizer_lr="$OPTIMIZER_LR"
    --policy.peft_method="$PEFT_METHOD"
    --policy.load_vlm_weights="$LOAD_VLM_WEIGHTS"
    --policy.repo_id="$VLM_REPO_ID"
    --dataset.use_imagenet_stats="$USE_IMAGENET_STATS"
    --dataset.image_transforms.enable="$ENABLE_IMG_TRANSFORM"
    --dataset.max_num_images="$MAX_NUM_IMAGES"
    --dataset.max_image_dim="$MAX_IMAGE_DIM"
    # Optional overrides, disabled by default:
    # --policy.exclude_image_keys=wrist_image
    # --policy.use_env_state=false
)

MUJOCO_GL=egl python src/lerobot/scripts/train.py "${train_args[@]}"
|
||||
@@ -20,6 +20,7 @@ from huggingface_hub.constants import HF_HOME
|
||||
OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
|
||||
+81
-27
@@ -14,12 +14,12 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
multitask_eval: bool = False
|
||||
max_parallel_tasks: int = 5
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
@@ -44,7 +46,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str | None = "AlohaInsertion-v0"
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -82,7 +84,7 @@ class AlohaEnv(EnvConfig):
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str | None = "PushT-v0"
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -124,7 +126,7 @@ class PushtEnv(EnvConfig):
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str | None = "XarmLift-v0"
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -179,10 +181,10 @@ class EnvTransformConfig:
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
@@ -195,25 +197,24 @@ class EnvTransformConfig:
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
robot: Optional[RobotConfig] = None
|
||||
teleop: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
@@ -223,8 +224,9 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
@@ -248,18 +250,18 @@ class HILEnvConfig(EnvConfig):
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
robot_config: Optional[RobotConfig] = None
|
||||
teleop_config: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
@@ -271,3 +273,55 @@ class HILEnvConfig(EnvConfig):
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero")
@dataclass
class LiberoEnv(EnvConfig):
    """Environment configuration registered under the ``libero`` choice name.

    The declared ``features`` always contain the 7-dimensional action. The
    observation features are added in ``__post_init__`` according to
    ``obs_type``.
    """

    # Suite name; can also choose libero_spatial, libero_object, etc.
    task: str = "libero_10"
    fps: int = 30
    episode_length: int = 520
    # "pixels" (camera images only) or "pixels_agent_pos" (images + state).
    obs_type: str = "pixels_agent_pos"
    render_mode: str = "rgb_array"
    # Comma-separated list of camera observation names.
    camera_name: str = "agentview_image,robot0_eye_in_hand_image"
    init_states: bool = True
    multitask_eval: bool = True
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}
    )
    # Maps environment feature keys onto the policy-side feature names.
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            "action": ACTION,
            "agent_pos": OBS_STATE,
            "pixels/agentview_image": OBS_IMAGE,
            "pixels/robot0_eye_in_hand_image": OBS_IMAGE_2,
        }
    )

    def __post_init__(self):
        """Register the observation features implied by ``obs_type``.

        Any ``obs_type`` other than "pixels" or "pixels_agent_pos" leaves
        ``features`` untouched.
        """
        if self.obs_type not in ("pixels", "pixels_agent_pos"):
            return

        # The state entry goes in first so the insertion order of `features`
        # is: agent_pos, then the two camera streams.
        if self.obs_type == "pixels_agent_pos":
            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))

        for camera_key in ("pixels/agentview_image", "pixels/robot0_eye_in_hand_image"):
            self.features[camera_key] = PolicyFeature(type=FeatureType.VISUAL, shape=(360, 360, 3))

    @property
    def gym_kwargs(self) -> dict:
        """Keyword arguments forwarded to the gym environment constructor.

        Only the observation type and render mode are forwarded; ``task`` and
        ``episode_length`` are deliberately not part of this mapping.
        """
        forwarded = {}
        forwarded["obs_type"] = self.obs_type
        forwarded["render_mode"] = self.render_mode
        return forwarded
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any, Callable
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.constants import (
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGE_2,
|
||||
)
|
||||
|
||||
|
||||
def create_libero_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] = None,
    camera_name: str = "agentview_image,robot0_eye_in_hand_image",
    init_states: bool = True,
    env_cls: Callable = None,
    multitask_eval: bool = True,
) -> dict[str, dict[str, Any]]:
    """
    Build vectorized LIBERO environments.

    Here n_envs is per task and equal to the number of rollouts.

    Args:
        task: Suite name (e.g. "libero_10"). When ``multitask_eval`` is true
            this may be a comma-separated list of suite names.
        n_envs: Number of sub-environments per vectorized environment.
        gym_kwargs: Extra keyword arguments forwarded to every ``LiberoEnv``.
        camera_name: Comma-separated camera names forwarded to ``LiberoEnv``.
        init_states: Forwarded to ``LiberoEnv``.
        env_cls: Callable that receives a list of zero-argument environment
            factories and returns the vectorized environment. Required.
        multitask_eval: Selects the return shape, see below.

    Returns:
        With ``multitask_eval`` true: a nested mapping
        ``{suite_name: {task_id: env_cls([...])}}`` with one vectorized
        environment of ``n_envs`` sub-environments per task of each suite.
        Otherwise: a single ``env_cls([...])`` whose ``n_envs``
        sub-environments are spread over the tasks of the one suite.

    Raises:
        ValueError: If ``env_cls`` is missing, or if (single-suite mode)
            ``n_envs`` cannot be mapped onto the tasks of the suite.
    """
    # Fail early with a clear message rather than with "NoneType is not
    # callable" after all the suites have already been instantiated.
    if env_cls is None:
        raise ValueError("env_cls must be provided (a callable taking a list of env factories)")
    if gym_kwargs is None:
        gym_kwargs = {}

    def _env_factory(suite, suite_name, task_id, episode_index):
        # A dedicated scope per factory binds every argument eagerly, so a
        # factory invoked after the surrounding loops have advanced still
        # builds the environment it was created for.
        def _make():
            return LiberoEnv(
                task_suite=suite,
                task_id=task_id,
                task_suite_name=suite_name,
                camera_name=camera_name,
                init_states=init_states,
                episode_index=episode_index,
                **gym_kwargs,
            )

        return _make

    benchmark_dict = benchmark.get_benchmark_dict()

    if not multitask_eval:
        task_suite = benchmark_dict[task]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        tasks_id = list(range(len(task_suite.tasks)))
        n_tasks = len(tasks_id)
        episode_indices = [0] * n_tasks
        if n_tasks == 1:
            # Single task: every sub-environment runs it, each on its own episode index.
            tasks_id = [tasks_id[0]] * n_envs
            episode_indices = list(range(n_envs))
        elif n_tasks < n_envs and n_envs % n_tasks == 0:
            # More sub-environments than tasks: repeat each task equally often,
            # numbering the repeats of one task 0..n_repeat-1.
            n_repeat = n_envs // n_tasks
            episode_indices = list(range(n_repeat)) * n_tasks
            tasks_id = [task_id for task_id in tasks_id for _ in range(n_repeat)]
        elif n_envs < n_tasks:
            # Fewer sub-environments than tasks: only the first n_envs tasks are used.
            tasks_id = tasks_id[:n_envs]
            episode_indices = list(range(n_envs))
            print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
        print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
        if n_envs != len(tasks_id):
            raise ValueError(
                f"len(n_envs) and tasks_id should be the same, got {n_envs} and {len(tasks_id)}"
            )
        return env_cls(
            [
                _env_factory(task_suite, task, task_id, episode_index)
                for task_id, episode_index in zip(tasks_id, episode_indices)
            ]
        )

    envs = defaultdict(dict)
    for suite_name in task.split(","):
        task_suite = benchmark_dict[suite_name]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        for task_id in range(len(task_suite.tasks)):
            episode_indices = list(range(n_envs))
            print(
                f"Creating Libero envs with task ids {task_id} from suite {suite_name}, episode_indices: {episode_indices}"
            )
            envs[suite_name][task_id] = env_cls(
                [
                    _env_factory(task_suite, suite_name, task_id, episode_index)
                    for episode_index in episode_indices
                ]
            )
    return envs
|
||||
|
||||
|
||||
def quat2axisangle(quat):
    """
    Adapted from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    Unlike the robosuite version, the input is left untouched: the scalar
    component is clipped on a local copy, so passing an array that is still
    used elsewhere (e.g. a raw observation) is safe. Plain sequences are
    accepted as well.

    Args:
        quat (np.array): (x,y,z,w) vec4 float angles

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    quat = np.asarray(quat)

    # Clip the scalar part into [-1, 1] so that acos/sqrt stay in their domain
    # even when numerical noise pushes it slightly outside.
    w = min(max(float(quat[3]), -1.0), 1.0)

    den = math.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(w)) / den
|
||||
|
||||
|
||||
def get_task_init_states(task_suite, i):
    """Load the stored initial states of the ``i``-th task of ``task_suite``.

    The file location is assembled from the LIBERO "init_states" directory
    and the task's ``problem_folder`` and ``init_states_file`` attributes.

    Args:
        task_suite: Suite object exposing a ``tasks`` sequence.
        i: Index of the task inside ``task_suite.tasks``.

    Returns:
        Whatever object is stored in the init-states file.
    """
    selected_task = task_suite.tasks[i]
    states_file = os.path.join(
        get_libero_path("init_states"),
        selected_task.problem_folder,
        selected_task.init_states_file,
    )
    # NOTE(security): weights_only=False performs full unpickling, which can
    # execute arbitrary code; only load init-state files from a trusted source.
    return torch.load(states_file, weights_only=False)
|
||||
|
||||
|
||||
def get_libero_dummy_action():
    """Return a fresh dummy/no-op action for stepping the simulation while the robot does nothing.

    The action is a 7-element list: six zeros followed by -1
    (presumably the gripper command -- confirm against the controller).
    """
    idle_action = [0] * 6
    idle_action.append(-1)
    return idle_action
|
||||
|
||||
|
||||
class LiberoEnv(gym.Env):
    """Gymnasium environment wrapping one task of a LIBERO task suite.

    Observations are a dict with a "pixels" entry (one image per configured
    camera, keyed by the policy-side image name) and, for the
    "pixels_agent_pos" observation type, an 8-dimensional "agent_pos" vector.
    Actions are 7-dimensional vectors in [-1, 1].
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    # Step budget per suite, chosen slightly above the longest training demo.
    _MAX_STEPS_BY_SUITE = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }

    def __init__(
        self,
        task_suite,
        task_id,
        task_suite_name,
        camera_name="agentview_image,robot0_eye_in_hand_image",
        obs_type="pixels",
        render_mode="rgb_array",
        observation_width=256,
        observation_height=256,
        visualization_width=640,
        visualization_height=480,
        init_states=True,
        episode_index=0,
    ):
        """Create the simulator for one task and declare the gym spaces.

        Args:
            task_suite: LIBERO suite object the task is taken from.
            task_id: Index of the task inside ``task_suite``.
            task_suite_name: Suite name, used to pick the episode step budget.
            camera_name: Comma-separated camera observation names.
            obs_type: "pixels" or "pixels_agent_pos".
            render_mode: Stored on the instance.
            observation_width: Camera image width.
            observation_height: Camera image height.
            visualization_width: Stored on the instance.
            visualization_height: Stored on the instance.
            init_states: Whether to apply a stored initial state on creation.
            episode_index: Index of the stored initial state to apply.

        Raises:
            NotImplementedError: If ``obs_type`` is "state".
            ValueError: If ``obs_type`` or ``task_suite_name`` is unknown.
        """
        super().__init__()

        # Validate before the simulator is created, so a bad argument neither
        # leaks a simulator instance nor surfaces later as an unrelated
        # UnboundLocalError.
        if obs_type == "state":
            raise NotImplementedError()
        if obs_type not in ("pixels", "pixels_agent_pos"):
            raise ValueError(
                f"Unsupported obs_type {obs_type!r}; expected 'pixels' or 'pixels_agent_pos'"
            )
        if task_suite_name not in self._MAX_STEPS_BY_SUITE:
            raise ValueError(
                f"Unknown task suite {task_suite_name!r}; expected one of {sorted(self._MAX_STEPS_BY_SUITE)}"
            )

        self.task_id = task_id
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.observation_width = observation_width
        self.observation_height = observation_height
        self.visualization_width = visualization_width
        self.visualization_height = visualization_height
        self.init_states = init_states
        self.camera_name = camera_name.split(
            ","
        )  # agentview_image (main) or robot0_eye_in_hand_image (wrist)
        self.camera_name_mapping = {
            "agentview_image": OBS_IMAGE,
            "robot0_eye_in_hand_image": OBS_IMAGE_2,
        }
        self.num_steps_wait = (
            10  # Do nothing for the first few timesteps to wait for the simulator drops objects
        )
        self.episode_index = episode_index

        self._env = self._make_envs_task(task_suite, self.task_id)
        self._max_episode_steps = self._MAX_STEPS_BY_SUITE[task_suite_name]

        images = {}
        for cam in self.camera_name:
            images[self.camera_name_mapping[cam]] = spaces.Box(
                low=0,
                high=255,
                shape=(self.observation_height, self.observation_width, 3),
                dtype=np.uint8,
            )

        if self.obs_type == "pixels":
            self.observation_space = spaces.Dict(
                {
                    "pixels": spaces.Dict(images),
                }
            )
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict(
                {
                    "pixels": spaces.Dict(images),
                    "agent_pos": spaces.Box(
                        low=-1000.0,
                        high=1000.0,
                        shape=(8,),
                        dtype=np.float64,
                    ),
                }
            )

        self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)

    def render(self):
        """Return the current image stored under the ``OBS_IMAGE`` key."""
        raw_obs = self._env.env._get_observations()
        image = self._format_raw_obs(raw_obs)["pixels"][OBS_IMAGE]
        return image

    def _make_envs_task(self, task_suite, task_id: int = 0):
        """Instantiate the off-screen simulator for ``task_id`` of ``task_suite``.

        Also records the task name and language description on the instance.
        """
        task = task_suite.get_task(task_id)
        self.task = task.name
        self.task_description = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": self.observation_height,
            "camera_widths": self.observation_width,
        }
        env = OffScreenRenderEnv(**env_args)
        env.reset()
        if self.init_states:
            init_states = get_task_init_states(
                task_suite, task_id
            )  # for benchmarking purpose, we fix the a set of initial states FIXME(mshukor): should be in the reset()?
            init_state_id = self.episode_index  # episode index
            env.set_init_state(init_states[init_state_id])

        return env

    def _format_raw_obs(self, raw_obs):
        """Convert a raw simulator observation into the gym observation dict."""
        images = {}
        for camera_name in self.camera_name:
            image = raw_obs[camera_name]
            image = image[::-1, ::-1]  # rotate 180 degrees
            images[self.camera_name_mapping[camera_name]] = image

        # End-effector position, orientation as axis-angle, and gripper joint positions.
        agent_pos = np.concatenate(
            (
                raw_obs["robot0_eef_pos"],
                quat2axisangle(raw_obs["robot0_eef_quat"]),
                raw_obs["robot0_gripper_qpos"],
            )
        )

        if self.obs_type == "state":
            raise NotImplementedError()
        if self.obs_type == "pixels":
            return {"pixels": images.copy()}
        if self.obs_type == "pixels_agent_pos":
            return {
                "pixels": images.copy(),
                "agent_pos": agent_pos,
            }
        raise ValueError(f"Unsupported obs_type {self.obs_type!r}")

    def reset(self, seed=None, **kwargs):
        """Reset the simulator and return ``(observation, info)``.

        Extra keyword arguments are accepted and ignored.
        """
        super().reset(seed=seed)

        self._env.seed(seed)
        raw_obs = self._env.reset()
        # Do nothing for the first few timesteps to wait for the simulator drops objects
        for _ in range(self.num_steps_wait):
            raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
        observation = self._format_raw_obs(raw_obs)
        info = {"is_success": False}
        return observation, info

    def step(self, action):
        """Apply a 1-D ``action`` and return the gymnasium 5-tuple.

        ``terminated`` is true when the simulator reports done or the task's
        success check passes; ``truncated`` is always false.
        """
        assert action.ndim == 1
        raw_obs, reward, done, info = self._env.step(action)

        is_success = self._env.check_success()
        terminated = done or is_success
        # NOTE(review): the reported success flag mirrors the simulator's
        # `done`, not `is_success` -- confirm these always coincide.
        info["is_success"] = done  # is_success

        observation = self._format_raw_obs(raw_obs)
        if done:
            # The observation returned below is the one from before this reset.
            self.reset()
            print(self.task, self.task_id, done, is_success)
        truncated = False
        return observation, reward, terminated, truncated, info

    def close(self):
        """Close the underlying simulator."""
        self._env.close()
|
||||
+147
-23
@@ -50,12 +50,12 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -456,55 +456,179 @@ def _compile_episode_data(
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_seed(cfg.seed)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
log_output_dir(cfg.output_dir)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Making policy.")
|
||||
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
if cfg.env.multitask_eval:
|
||||
info = eval_policy_multitask(
|
||||
env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
verbose=False,
|
||||
)
|
||||
# Print overall stats
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"]["aggregated"])
|
||||
|
||||
# Print per-suite stats
|
||||
for task_group, task_group_info in info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
for _task_group, v in env.items():
|
||||
for _env in v.values():
|
||||
_env.close()
|
||||
else:
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
env.close()
|
||||
|
||||
# Save info
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
|
||||
logging.info("End of eval")
|
||||
|
||||
def eval_policy_multitask(
|
||||
envs: dict[str, dict[str, gym.vector.VectorEnv]],
|
||||
policy,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
max_parallel_tasks: int = 5,
|
||||
verbose: bool = True,
|
||||
) -> dict:
|
||||
global_start = time.time()
|
||||
results = {}
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
eval_main()
|
||||
overall_rewards, overall_max_rewards, overall_successes = [], [], []
|
||||
overall_video_paths = []
|
||||
overall_episode_data = None
|
||||
|
||||
def eval_task(task_group, task_id, env):
|
||||
"""Evaluates a single task in parallel."""
|
||||
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
|
||||
task_result = eval_policy(
|
||||
env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
return {
|
||||
"task_group": task_group,
|
||||
"task_id": task_id,
|
||||
"sum_rewards": [ep["sum_reward"] for ep in per_episode],
|
||||
"max_rewards": [ep["max_reward"] for ep in per_episode],
|
||||
"successes": [ep["success"] for ep in per_episode],
|
||||
"video_paths": task_result.get("video_paths", []),
|
||||
}
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||
future_to_task = {
|
||||
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
|
||||
for task_group, tasks in envs.items()
|
||||
for task_id, env in tasks.items()
|
||||
}
|
||||
|
||||
task_group_results = {}
|
||||
|
||||
for future in concurrent.futures.as_completed(future_to_task):
|
||||
task_result = future.result()
|
||||
task_group = task_result["task_group"]
|
||||
|
||||
if task_group not in task_group_results:
|
||||
task_group_results[task_group] = {
|
||||
"sum_rewards": [],
|
||||
"max_rewards": [],
|
||||
"successes": [],
|
||||
"video_paths": [],
|
||||
}
|
||||
|
||||
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
|
||||
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
|
||||
task_group_results[task_group]["successes"].extend(task_result["successes"])
|
||||
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
|
||||
|
||||
# Process results per task group
|
||||
for task_group, data in task_group_results.items():
|
||||
suite_rewards = data["sum_rewards"]
|
||||
suite_max_rewards = data["max_rewards"]
|
||||
suite_successes = data["successes"]
|
||||
suite_video_paths = data["video_paths"]
|
||||
|
||||
suite_eval_s = time.time() - global_start
|
||||
suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards))
|
||||
|
||||
results[task_group] = {
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(suite_rewards)),
|
||||
"avg_max_reward": float(np.nanmean(suite_max_rewards)),
|
||||
"pc_success": float(np.nanmean(suite_successes) * 100),
|
||||
"eval_s": suite_eval_s,
|
||||
"eval_ep_s": suite_eval_ep_s,
|
||||
},
|
||||
"video_paths": suite_video_paths,
|
||||
"episodes": None, # Modify if episode data is needed
|
||||
}
|
||||
|
||||
overall_rewards.extend(suite_rewards)
|
||||
overall_max_rewards.extend(suite_max_rewards)
|
||||
overall_successes.extend(suite_successes)
|
||||
overall_video_paths.extend(suite_video_paths)
|
||||
|
||||
# Global metrics
|
||||
global_eval_s = time.time() - global_start
|
||||
global_eval_ep_s = global_eval_s / max(1, len(overall_rewards))
|
||||
|
||||
results["overall"] = {
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(overall_rewards)),
|
||||
"avg_max_reward": float(np.nanmean(overall_max_rewards)),
|
||||
"pc_success": float(np.nanmean(overall_successes) * 100),
|
||||
"eval_s": global_eval_s,
|
||||
"eval_ep_s": global_eval_ep_s,
|
||||
},
|
||||
"video_paths": overall_video_paths,
|
||||
"episodes": overall_episode_data,
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
init_logging()
|
||||
eval_main()
|
||||
|
||||
@@ -34,7 +34,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
from lerobot.scripts.eval import eval_policy, eval_policy_multitask
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.train_utils import (
|
||||
@@ -252,14 +252,32 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
if cfg.env.multitask_eval:
|
||||
eval_info = eval_policy_multitask(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated_results = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
else:
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
|
||||
Reference in New Issue
Block a user