mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
More things
This commit is contained in:
@@ -14,11 +14,11 @@ export HF_DATASETS_OFFLINE=1
|
|||||||
export HF_HUB_OFFLINE=1
|
export HF_HUB_OFFLINE=1
|
||||||
export TOKENIZERS_PARALLELISM=false
|
export TOKENIZERS_PARALLELISM=false
|
||||||
export MUJOCO_GL=egl
|
export MUJOCO_GL=egl
|
||||||
export CUDA_VISIBLE_DEVICES=3
|
export CUDA_VISIBLE_DEVICES=2
|
||||||
|
|
||||||
# CONFIGURATION
|
# CONFIGURATION
|
||||||
POLICY_PATH="/raid/jade/logs/lerobot/lerobot_2_HuggingFaceVLA_libero_smolvla_lr1e-4bs32steps100000/checkpoints/100000/pretrained_model"
|
POLICY_PATH="/raid/jade/logs/lerobot/lerobot_2_HuggingFaceVLA_libero_smolvla_lr1e-4bs32steps100000/checkpoints/100000/pretrained_model"
|
||||||
POLICY_PATH="/raid/jade/models/smolvlamust"
|
POLICY_PATH="/raid/jade/logs/lerobot/lerobot_new_HuggingfaceVLA_libero_smolvla_lr1e-4bs32steps100000/checkpoints/100000/pretrained_model"
|
||||||
TASK=libero_spatial
|
TASK=libero_spatial
|
||||||
ENV_TYPE="libero"
|
ENV_TYPE="libero"
|
||||||
BATCH_SIZE=10
|
BATCH_SIZE=10
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -320,8 +320,6 @@ class LiberoEnv(EnvConfig):
|
|||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {
|
return {
|
||||||
# "task": self.task,
|
|
||||||
"obs_type": self.obs_type,
|
"obs_type": self.obs_type,
|
||||||
"render_mode": self.render_mode,
|
"render_mode": self.render_mode,
|
||||||
# "max_episode_steps": self.episode_length,
|
|
||||||
}
|
}
|
||||||
|
|||||||
+23
-24
@@ -56,37 +56,36 @@ def make_env(
|
|||||||
names to indexed vectorized environments (when multitask eval is used).
|
names to indexed vectorized environments (when multitask eval is used).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if n_envs < 1:
|
if n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs` must be at least 1")
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
|
||||||
|
|
||||||
if "libero" in cfg.type:
|
|
||||||
from lerobot.envs.libero import create_libero_envs
|
if "libero" in cfg.type:
|
||||||
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
return create_libero_envs(
|
||||||
|
task=cfg.task,
|
||||||
|
n_envs=n_envs,
|
||||||
|
camera_name=cfg.camera_name,
|
||||||
|
init_states=cfg.init_states,
|
||||||
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
|
env_cls=env_cls,
|
||||||
|
multitask_eval=cfg.multitask_eval,
|
||||||
|
)
|
||||||
|
|
||||||
env = create_libero_envs(
|
|
||||||
task=cfg.task,
|
|
||||||
n_envs=n_envs,
|
|
||||||
camera_name=cfg.camera_name,
|
|
||||||
init_states=cfg.init_states,
|
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
|
||||||
env_cls=env_cls,
|
|
||||||
multitask_eval=cfg.multitask_eval,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
package_name = f"gym_{cfg.type}"
|
package_name = f"gym_{cfg.type}"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(package_name)
|
importlib.import_module(package_name)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print(
|
raise ModuleNotFoundError(
|
||||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
|
f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\""
|
||||||
)
|
) from e
|
||||||
raise e
|
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.task}"
|
gym_handle = f"{package_name}/{cfg.task}"
|
||||||
env = env_cls(
|
|
||||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
def _make_one():
|
||||||
)
|
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
||||||
|
|
||||||
return env
|
return env_cls([_make_one for _ in range(n_envs)])
|
||||||
|
|||||||
@@ -0,0 +1,326 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium import spaces
|
||||||
|
from libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
|
|
||||||
|
def _make_libero_env_fn(
    task_suite,
    task_id,
    task_suite_name,
    camera_name,
    init_states,
    episode_index,
    gym_kwargs,
):
    """Return a zero-argument factory that builds one ``LiberoEnv``.

    Every value is bound through this function's own arguments, so factories
    created inside a loop never observe a later value of the loop variables.
    """

    def _make_env():
        return LiberoEnv(
            task_suite=task_suite,
            task_id=task_id,
            task_suite_name=task_suite_name,
            camera_name=camera_name,
            init_states=init_states,
            episode_index=episode_index,
            **gym_kwargs,
        )

    return _make_env


def create_libero_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] = None,
    camera_name: str = "agentview_image,robot0_eye_in_hand_image",
    init_states: bool = True,
    env_cls: Callable = None,
    multitask_eval: bool = True,
) -> dict[str, dict[str, Any]]:
    """Build vectorized LIBERO environments for one or several task suites.

    Args:
        task: Suite name (e.g. ``libero_spatial``). With ``multitask_eval`` it
            may be a comma-separated list of suite names.
        n_envs: With ``multitask_eval``, the number of environments (rollouts)
            created for *each* task. Otherwise the total number of
            environments spread over the tasks of the single suite.
        gym_kwargs: Extra keyword arguments forwarded to ``LiberoEnv``.
        camera_name: Comma-separated camera observation names.
        init_states: Forwarded to ``LiberoEnv``.
        env_cls: Callable that turns a list of environment factories into a
            vectorized environment. Required despite the ``None`` default.
        multitask_eval: Selects the return layout described below.

    Returns:
        With ``multitask_eval``: ``{suite_name: {task_id: vector_env}}`` where
        ``task_id`` is the integer index of the task in its suite and each
        vector env holds ``n_envs`` environments using episode indices
        ``0..n_envs-1``.
        Without it: a single ``env_cls`` instance (not a dict) holding
        ``n_envs`` environments.

    Raises:
        ValueError: If ``env_cls`` is missing, or if ``n_envs`` cannot be
            mapped onto the tasks of the suite in single-suite mode.
    """
    if env_cls is None:
        raise ValueError("`env_cls` is required (e.g. gym.vector.SyncVectorEnv)")
    gym_kwargs = {} if gym_kwargs is None else gym_kwargs
    benchmark_dict = benchmark.get_benchmark_dict()

    if not multitask_eval:
        task_suite = benchmark_dict[task]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        n_tasks = len(task_suite.tasks)

        if n_tasks == 1:
            # A single task: every env runs it with a different episode index.
            tasks_id = [0] * n_envs
            episode_indices = list(range(n_envs))
        elif n_tasks < n_envs and n_envs % n_tasks == 0:
            # n_envs is a multiple of the task count: repeat each task
            # n_repeat times with episode indices 0..n_repeat-1.
            n_repeat = n_envs // n_tasks
            tasks_id = [tid for tid in range(n_tasks) for _ in range(n_repeat)]
            episode_indices = list(range(n_repeat)) * n_tasks
        elif n_envs < n_tasks:
            # Fewer envs than tasks: only the first n_envs tasks are evaluated.
            # NOTE(review): env i gets episode index i here, whereas the
            # one-env-per-task case below uses index 0 - confirm intent.
            tasks_id = list(range(n_envs))
            episode_indices = list(range(n_envs))
            print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
        else:
            # One env per task, all starting from episode index 0. A larger
            # n_envs that is not a multiple of the task count is rejected below.
            tasks_id = list(range(n_tasks))
            episode_indices = [0] * n_tasks

        print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
        if n_envs != len(tasks_id):
            raise ValueError(
                f"n_envs ({n_envs}) does not match the number of selected tasks ({len(tasks_id)}); "
                "use n_envs <= number of tasks in the suite, or a multiple of it"
            )
        return env_cls(
            [
                _make_libero_env_fn(
                    task_suite,
                    tasks_id[i],
                    task,
                    camera_name,
                    init_states,
                    episode_indices[i],
                    gym_kwargs,
                )
                for i in range(n_envs)
            ]
        )

    envs = defaultdict(dict)
    for suite_name in task.split(","):
        task_suite = benchmark_dict[suite_name]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        for task_id in range(len(task_suite.tasks)):
            episode_indices = list(range(n_envs))
            print(
                f"Creating Libero envs with task ids {task_id} from suite {suite_name}, episode_indices: {episode_indices}"
            )
            envs[suite_name][task_id] = env_cls(
                [
                    _make_libero_env_fn(
                        task_suite,
                        task_id,
                        suite_name,
                        camera_name,
                        init_states,
                        episode_index,
                        gym_kwargs,
                    )
                    for episode_index in episode_indices
                ]
            )
    return envs
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    Unlike the robosuite original, the input is never modified: the scalar
    part is clamped on a local copy, so observation arrays passed in by the
    caller keep their values. Any sequence of four numbers is accepted.

    Args:
        quat (np.array): (x,y,z,w) vec4 float angles

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    # Clamp w into acos' domain [-1, 1] without writing back into `quat`.
    w = min(max(float(quat[3]), -1.0), 1.0)

    den = math.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (np.asarray(quat[:3]) * 2.0 * math.acos(w)) / den
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_init_states(task_suite, i):
    """Load the stored initial states of task ``i`` from ``task_suite``.

    The file is located under the LIBERO ``init_states`` directory, inside the
    task's problem folder, and is read with ``torch.load``.
    """
    states_root = get_libero_path("init_states")
    task_entry = task_suite.tasks[i]
    states_file = os.path.join(states_root, task_entry.problem_folder, task_entry.init_states_file)
    # weights_only=False unpickles arbitrary objects: only load files that
    # ship with a trusted LIBERO installation.
    return torch.load(states_file, weights_only=False)  # nosec B614
|
||||||
|
|
||||||
|
|
||||||
|
def get_libero_dummy_action():
    """Return the no-op action used to advance the simulation while the robot does nothing.

    The action is a fresh 7-element list: six zeros followed by ``-1``
    (presumably the gripper command - confirm against the LIBERO action layout).
    """
    idle_motion = [0] * 6
    return [*idle_motion, -1]
|
||||||
|
|
||||||
|
|
||||||
|
OBS_STATE_DIM = 8
|
||||||
|
ACTION_DIM = 7
|
||||||
|
|
||||||
|
|
||||||
|
class LiberoEnv(gym.Env):
    """Gymnasium environment wrapping one LIBERO task rendered off-screen.

    Observations are dictionaries with a ``"pixels"`` entry (one image per
    selected camera) and, for ``obs_type="pixels_agent_pos"``, an
    ``"agent_pos"`` vector concatenating the end-effector position, its
    orientation as axis-angle and the gripper joint positions.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    # Step budget per task suite.
    _MAX_EPISODE_STEPS = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }

    def __init__(
        self,
        task_suite,
        task_id,
        task_suite_name,
        camera_name="agentview_image,robot0_eye_in_hand_image",
        obs_type="pixels",
        render_mode="rgb_array",
        observation_width=256,
        observation_height=256,
        visualization_width=640,
        visualization_height=480,
        init_states=True,
        episode_index=0,
    ):
        """Create the simulator for one task and declare the gym spaces.

        Args:
            task_suite: LIBERO task suite object providing ``get_task``/``tasks``.
            task_id: Index of the task inside ``task_suite``.
            task_suite_name: Suite name; selects the episode step budget.
            camera_name: Comma-separated camera observation names.
            obs_type: ``"pixels"`` or ``"pixels_agent_pos"``.
            render_mode: Stored on the instance.
            observation_width: Width of the camera images.
            observation_height: Height of the camera images.
            visualization_width: Stored on the instance.
            visualization_height: Stored on the instance.
            init_states: Whether to load the stored initial state selected by
                ``episode_index`` when the simulator is created.
            episode_index: Index into the task's stored initial states.

        Raises:
            NotImplementedError: If ``obs_type`` is ``"state"``.
            ValueError: If ``obs_type``, ``task_suite_name`` or a camera name
                is not supported.
        """
        super().__init__()

        # Reject bad arguments before the simulator is created, so a typo does
        # not leave a half-initialised environment behind.
        if obs_type == "state":
            raise NotImplementedError()
        if obs_type not in ("pixels", "pixels_agent_pos"):
            raise ValueError(f"Unsupported obs_type {obs_type!r}; expected 'pixels' or 'pixels_agent_pos'")
        if task_suite_name not in self._MAX_EPISODE_STEPS:
            raise ValueError(
                f"Unknown task suite {task_suite_name!r}; expected one of {sorted(self._MAX_EPISODE_STEPS)}"
            )

        self.task_id = task_id
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.observation_width = observation_width
        self.observation_height = observation_height
        self.visualization_width = visualization_width
        self.visualization_height = visualization_height
        self.init_states = init_states
        # agentview_image (main) or robot0_eye_in_hand_image (wrist)
        self.camera_name = camera_name.split(",")

        # Map raw camera names to "image" and "image2".
        # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
        # following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
        # This ensures the policy consistently receives observations in the
        # expected format regardless of the original camera naming.
        self.camera_name_mapping = {
            "agentview_image": "image",
            "robot0_eye_in_hand_image": "image2",
        }
        unknown_cameras = [cam for cam in self.camera_name if cam not in self.camera_name_mapping]
        if unknown_cameras:
            raise ValueError(
                f"Unknown camera name(s) {unknown_cameras}; expected a subset of {list(self.camera_name_mapping)}"
            )

        # Do nothing for the first few timesteps, waiting for the simulator to drop the objects.
        self.num_steps_wait = 10
        self.episode_index = episode_index

        self._env = self._make_envs_task(task_suite, self.task_id)
        self._max_episode_steps = self._MAX_EPISODE_STEPS[task_suite_name]

        images = {
            self.camera_name_mapping[cam]: spaces.Box(
                low=0,
                high=255,
                shape=(self.observation_height, self.observation_width, 3),
                dtype=np.uint8,
            )
            for cam in self.camera_name
        }
        observation_spaces = {"pixels": spaces.Dict(images)}
        if self.obs_type == "pixels_agent_pos":
            observation_spaces["agent_pos"] = spaces.Box(
                low=-1000.0,
                high=1000.0,
                shape=(OBS_STATE_DIM,),
                dtype=np.float64,
            )
        self.observation_space = spaces.Dict(observation_spaces)

        self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)

    def render(self):
        """Return the current image of the main camera.

        Falls back to the first selected camera when ``agentview_image`` is
        not part of ``camera_name`` instead of raising ``KeyError``.
        """
        raw_obs = self._env.env._get_observations()
        pixels = self._format_raw_obs(raw_obs)["pixels"]
        main_key = self.camera_name_mapping["agentview_image"]
        if main_key in pixels:
            return pixels[main_key]
        return next(iter(pixels.values()))

    def _make_envs_task(self, task_suite, task_id: int = 0):
        """Create the off-screen simulator for ``task_id`` and return it.

        Also records the task name and language description on the instance.
        """
        task = task_suite.get_task(task_id)
        self.task = task.name
        self.task_description = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": self.observation_height,
            "camera_widths": self.observation_width,
        }
        env = OffScreenRenderEnv(**env_args)
        env.reset()
        if self.init_states:
            # For benchmarking purposes we use a fixed set of initial states.
            # FIXME(mshukor): should be in the reset()?
            init_states = get_task_init_states(task_suite, task_id)
            init_state_id = self.episode_index  # episode index
            env.set_init_state(init_states[init_state_id])

        return env

    def _format_raw_obs(self, raw_obs):
        """Convert a raw simulator observation into the gym observation dict.

        Raises:
            NotImplementedError: If ``obs_type`` is ``"state"``.
            ValueError: If ``obs_type`` is not a supported value.
        """
        images = {}
        for camera_name in self.camera_name:
            image = raw_obs[camera_name]
            image = image[::-1, ::-1]  # rotate 180 degrees
            images[self.camera_name_mapping[camera_name]] = image

        agent_pos = np.concatenate(
            (
                raw_obs["robot0_eef_pos"],
                quat2axisangle(raw_obs["robot0_eef_quat"]),
                raw_obs["robot0_gripper_qpos"],
            )
        )

        if self.obs_type == "state":
            raise NotImplementedError()
        if self.obs_type == "pixels":
            return {"pixels": images}
        if self.obs_type == "pixels_agent_pos":
            return {"pixels": images, "agent_pos": agent_pos}
        raise ValueError(f"Unsupported obs_type {self.obs_type!r}; expected 'pixels' or 'pixels_agent_pos'")

    def reset(self, seed=None, **kwargs):
        """Reset the simulator and return ``(observation, info)``.

        Extra keyword arguments are accepted and ignored.
        """
        super().reset(seed=seed)

        self._env.seed(seed)
        # NOTE(review): the stored initial state is applied only when the
        # simulator is created, not here - confirm whether it should be
        # re-applied after every reset (see the FIXME in `_make_envs_task`).
        raw_obs = self._env.reset()
        # Do nothing for the first few timesteps, waiting for the simulator to drop the objects.
        for _ in range(self.num_steps_wait):
            raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
        observation = self._format_raw_obs(raw_obs)
        info = {"is_success": False}
        return observation, info

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        ``truncated`` is always ``False``. When the simulator reports ``done``
        the environment resets itself; the returned observation is the one
        produced by this step, before that reset.

        Raises:
            ValueError: If ``action`` is not one-dimensional.
        """
        if action.ndim != 1:
            raise ValueError(f"Expected a 1-D action, got an array with {action.ndim} dimensions")
        raw_obs, reward, done, info = self._env.step(action)

        is_success = self._env.check_success()
        terminated = done or is_success
        # NOTE(review): success is reported from the simulator's `done` flag,
        # not from `check_success()` - confirm the two always agree.
        info["is_success"] = done

        observation = self._format_raw_obs(raw_obs)
        if done:
            self.reset()
            print(self.task, self.task_id, done, is_success)
        truncated = False
        return observation, reward, terminated, truncated, info

    def close(self):
        """Close the underlying simulator."""
        self._env.close()
|
||||||
@@ -245,9 +245,8 @@ class LiberoEnv(gym.Env):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
raw_obs = self._env.env._get_observations()
|
raw_obs = self._env.env._get_observations()
|
||||||
formatted = self._format_raw_obs(raw_obs)
|
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||||
# grab the "main" camera
|
return image
|
||||||
return formatted["pixels"]["image"]
|
|
||||||
|
|
||||||
def _make_envs_task(self, task_suite, task_id: int = 0):
|
def _make_envs_task(self, task_suite, task_id: int = 0):
|
||||||
task = task_suite.get_task(task_id)
|
task = task_suite.get_task(task_id)
|
||||||
@@ -277,7 +276,6 @@ class LiberoEnv(gym.Env):
|
|||||||
image = raw_obs[camera_name]
|
image = raw_obs[camera_name]
|
||||||
image = image[::-1, ::-1] # rotate 180 degrees
|
image = image[::-1, ::-1] # rotate 180 degrees
|
||||||
images[self.camera_name_mapping[camera_name]] = image
|
images[self.camera_name_mapping[camera_name]] = image
|
||||||
# images = image if len(images) == 1 else images
|
|
||||||
state = np.concatenate(
|
state = np.concatenate(
|
||||||
(
|
(
|
||||||
raw_obs["robot0_eef_pos"],
|
raw_obs["robot0_eef_pos"],
|
||||||
@@ -311,14 +309,17 @@ class LiberoEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
assert action.ndim == 1
|
assert action.ndim == 1
|
||||||
action[-1] = 1.0 - action[-1]
|
|
||||||
raw_obs, reward, done, info = self._env.step(action)
|
raw_obs, reward, done, info = self._env.step(action)
|
||||||
|
|
||||||
is_success = self._env.check_success()
|
is_success = self._env.check_success()
|
||||||
terminated = done or is_success
|
terminated = done or is_success
|
||||||
info["is_success"] = is_success
|
info["is_success"] = done # is_success
|
||||||
|
|
||||||
observation = self._format_raw_obs(raw_obs)
|
observation = self._format_raw_obs(raw_obs)
|
||||||
|
if done:
|
||||||
|
self.reset()
|
||||||
|
print(self.task, self.task_id, done, is_success)
|
||||||
truncated = False
|
truncated = False
|
||||||
# note if it is unable to complete get libero error after many steps
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -0,0 +1,308 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from gymnasium import spaces
|
||||||
|
from libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
|
|
||||||
|
OBS_IMAGE = "observation.image"
|
||||||
|
OBS_IMAGE_2 = "observation.image2"
|
||||||
|
def _make_libero_env_fn(
    task_suite,
    task_id,
    task_suite_name,
    camera_name,
    init_states,
    episode_index,
    gym_kwargs,
):
    """Return a zero-argument factory that builds one ``LiberoEnv``.

    Every value is bound through this function's own arguments, so factories
    created inside a loop never observe a later value of the loop variables
    (the late-binding closure pitfall).
    """

    def _make_env():
        return LiberoEnv(
            task_suite=task_suite,
            task_id=task_id,
            task_suite_name=task_suite_name,
            camera_name=camera_name,
            init_states=init_states,
            episode_index=episode_index,
            **gym_kwargs,
        )

    return _make_env


def create_libero_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] = None,
    camera_name: str = "agentview_image,robot0_eye_in_hand_image",
    init_states: bool = True,
    env_cls: Callable = None,
    multitask_eval: bool = True,
) -> dict[str, dict[str, Any]]:
    """Build vectorized LIBERO environments for one or several task suites.

    Args:
        task: Suite name (e.g. ``libero_spatial``). With ``multitask_eval`` it
            may be a comma-separated list of suite names.
        n_envs: With ``multitask_eval``, the number of environments (rollouts)
            created for *each* task. Otherwise the total number of
            environments spread over the tasks of the single suite.
        gym_kwargs: Extra keyword arguments forwarded to ``LiberoEnv``.
        camera_name: Comma-separated camera observation names.
        init_states: Forwarded to ``LiberoEnv``.
        env_cls: Callable that turns a list of environment factories into a
            vectorized environment. Required despite the ``None`` default.
        multitask_eval: Selects the return layout described below.

    Returns:
        With ``multitask_eval``: ``{suite_name: {task_id: vector_env}}`` where
        ``task_id`` is the integer index of the task in its suite and each
        vector env holds ``n_envs`` environments using episode indices
        ``0..n_envs-1``.
        Without it: a single ``env_cls`` instance (not a dict) holding
        ``n_envs`` environments.

    Raises:
        ValueError: If ``env_cls`` is missing, or if ``n_envs`` cannot be
            mapped onto the tasks of the suite in single-suite mode.
    """
    if env_cls is None:
        raise ValueError("`env_cls` is required (e.g. gym.vector.SyncVectorEnv)")
    gym_kwargs = {} if gym_kwargs is None else gym_kwargs
    benchmark_dict = benchmark.get_benchmark_dict()

    if not multitask_eval:
        task_suite = benchmark_dict[task]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        n_tasks = len(task_suite.tasks)

        if n_tasks == 1:
            # A single task: every env runs it with a different episode index.
            tasks_id = [0] * n_envs
            episode_indices = list(range(n_envs))
        elif n_tasks < n_envs and n_envs % n_tasks == 0:
            # n_envs is a multiple of the task count: repeat each task
            # n_repeat times with episode indices 0..n_repeat-1.
            n_repeat = n_envs // n_tasks
            tasks_id = [tid for tid in range(n_tasks) for _ in range(n_repeat)]
            episode_indices = list(range(n_repeat)) * n_tasks
        elif n_envs < n_tasks:
            # Fewer envs than tasks: only the first n_envs tasks are evaluated.
            # NOTE(review): env i gets episode index i here, whereas the
            # one-env-per-task case below uses index 0 - confirm intent.
            tasks_id = list(range(n_envs))
            episode_indices = list(range(n_envs))
            print(f"WARNING: n_envs < len(tasks_id), evaluating only on {tasks_id}")
        else:
            # One env per task, all starting from episode index 0. A larger
            # n_envs that is not a multiple of the task count is rejected below.
            tasks_id = list(range(n_tasks))
            episode_indices = [0] * n_tasks

        print(f"Creating Libero envs with task ids {tasks_id} from suite {task}")
        if n_envs != len(tasks_id):
            raise ValueError(
                f"n_envs ({n_envs}) does not match the number of selected tasks ({len(tasks_id)}); "
                "use n_envs <= number of tasks in the suite, or a multiple of it"
            )
        return env_cls(
            [
                _make_libero_env_fn(
                    task_suite,
                    tasks_id[i],
                    task,
                    camera_name,
                    init_states,
                    episode_indices[i],
                    gym_kwargs,
                )
                for i in range(n_envs)
            ]
        )

    envs = defaultdict(dict)
    for suite_name in task.split(","):
        task_suite = benchmark_dict[suite_name]()  # can also choose libero_spatial, libero_object, libero_10 etc.
        for task_id in range(len(task_suite.tasks)):
            episode_indices = list(range(n_envs))
            print(
                f"Creating Libero envs with task ids {task_id} from suite {suite_name}, episode_indices: {episode_indices}"
            )
            envs[suite_name][task_id] = env_cls(
                [
                    _make_libero_env_fn(
                        task_suite,
                        task_id,
                        suite_name,
                        camera_name,
                        init_states,
                        episode_index,
                        gym_kwargs,
                    )
                    for episode_index in episode_indices
                ]
            )
    return envs
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    Unlike the robosuite original, the input is never modified: the scalar
    part is clamped on a local copy, so observation arrays passed in by the
    caller keep their values. Any sequence of four numbers is accepted.

    Args:
        quat (np.array): (x,y,z,w) vec4 float angles

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    # Clamp w into acos' domain [-1, 1] without writing back into `quat`.
    w = min(max(float(quat[3]), -1.0), 1.0)

    den = math.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (np.asarray(quat[:3]) * 2.0 * math.acos(w)) / den
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_init_states(task_suite, i):
    """Load the stored initial states of task ``i`` from ``task_suite``.

    The file is located under the LIBERO ``init_states`` directory, inside the
    task's problem folder, and is read with ``torch.load``.
    """
    states_root = get_libero_path("init_states")
    task_entry = task_suite.tasks[i]
    states_file = os.path.join(states_root, task_entry.problem_folder, task_entry.init_states_file)
    # weights_only=False unpickles arbitrary objects: only load files that
    # ship with a trusted LIBERO installation.
    return torch.load(states_file, weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_libero_dummy_action():
    """Return the no-op action used to advance the simulation while the robot does nothing.

    The action is a fresh 7-element list: six zeros followed by ``-1``
    (presumably the gripper command - confirm against the LIBERO action layout).
    """
    idle_motion = [0] * 6
    return [*idle_motion, -1]
|
||||||
|
|
||||||
|
|
||||||
|
class LiberoEnv(gym.Env):
    """Gymnasium environment wrapping one LIBERO task rendered off-screen.

    Observations are dictionaries with a ``"pixels"`` entry (one image per
    selected camera, keyed by ``OBS_IMAGE`` / ``OBS_IMAGE_2``) and, for
    ``obs_type="pixels_agent_pos"``, an ``"agent_pos"`` vector concatenating
    the end-effector position, its orientation as axis-angle and the gripper
    joint positions.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    # Step budget per task suite.
    _MAX_EPISODE_STEPS = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }

    def __init__(
        self,
        task_suite,
        task_id,
        task_suite_name,
        camera_name="agentview_image,robot0_eye_in_hand_image",
        obs_type="pixels",
        render_mode="rgb_array",
        observation_width=256,
        observation_height=256,
        visualization_width=640,
        visualization_height=480,
        init_states=True,
        episode_index=0,
    ):
        """Create the simulator for one task and declare the gym spaces.

        Args:
            task_suite: LIBERO task suite object providing ``get_task``/``tasks``.
            task_id: Index of the task inside ``task_suite``.
            task_suite_name: Suite name; selects the episode step budget.
            camera_name: Comma-separated camera observation names.
            obs_type: ``"pixels"`` or ``"pixels_agent_pos"``.
            render_mode: Stored on the instance.
            observation_width: Width of the camera images.
            observation_height: Height of the camera images.
            visualization_width: Stored on the instance.
            visualization_height: Stored on the instance.
            init_states: Whether to load the stored initial state selected by
                ``episode_index`` when the simulator is created.
            episode_index: Index into the task's stored initial states.

        Raises:
            NotImplementedError: If ``obs_type`` is ``"state"``.
            ValueError: If ``obs_type``, ``task_suite_name`` or a camera name
                is not supported.
        """
        super().__init__()

        # Reject bad arguments before the simulator is created, so a typo does
        # not leave a half-initialised environment behind.
        if obs_type == "state":
            raise NotImplementedError()
        if obs_type not in ("pixels", "pixels_agent_pos"):
            raise ValueError(f"Unsupported obs_type {obs_type!r}; expected 'pixels' or 'pixels_agent_pos'")
        if task_suite_name not in self._MAX_EPISODE_STEPS:
            raise ValueError(
                f"Unknown task suite {task_suite_name!r}; expected one of {sorted(self._MAX_EPISODE_STEPS)}"
            )

        self.task_id = task_id
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.observation_width = observation_width
        self.observation_height = observation_height
        self.visualization_width = visualization_width
        self.visualization_height = visualization_height
        self.init_states = init_states
        # agentview_image (main) or robot0_eye_in_hand_image (wrist)
        self.camera_name = camera_name.split(",")
        # Raw camera names are exposed under fixed observation keys.
        self.camera_name_mapping = {
            "agentview_image": OBS_IMAGE,
            "robot0_eye_in_hand_image": OBS_IMAGE_2,
        }
        unknown_cameras = [cam for cam in self.camera_name if cam not in self.camera_name_mapping]
        if unknown_cameras:
            raise ValueError(
                f"Unknown camera name(s) {unknown_cameras}; expected a subset of {list(self.camera_name_mapping)}"
            )

        # Do nothing for the first few timesteps, waiting for the simulator to drop the objects.
        self.num_steps_wait = 10
        self.episode_index = episode_index

        self._env = self._make_envs_task(task_suite, self.task_id)
        self._max_episode_steps = self._MAX_EPISODE_STEPS[task_suite_name]

        images = {
            self.camera_name_mapping[cam]: spaces.Box(
                low=0,
                high=255,
                shape=(self.observation_height, self.observation_width, 3),
                dtype=np.uint8,
            )
            for cam in self.camera_name
        }
        observation_spaces = {"pixels": spaces.Dict(images)}
        if self.obs_type == "pixels_agent_pos":
            observation_spaces["agent_pos"] = spaces.Box(
                low=-1000.0,
                high=1000.0,
                shape=(8,),
                dtype=np.float64,
            )
        self.observation_space = spaces.Dict(observation_spaces)

        self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)

    def render(self):
        """Return the current image of the main camera.

        Falls back to the first selected camera when ``agentview_image`` is
        not part of ``camera_name`` instead of raising ``KeyError``.
        """
        raw_obs = self._env.env._get_observations()
        pixels = self._format_raw_obs(raw_obs)["pixels"]
        if OBS_IMAGE in pixels:
            return pixels[OBS_IMAGE]
        return next(iter(pixels.values()))

    def _make_envs_task(self, task_suite, task_id: int = 0):
        """Create the off-screen simulator for ``task_id`` and return it.

        Also records the task name and language description on the instance.
        """
        task = task_suite.get_task(task_id)
        self.task = task.name
        self.task_description = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": self.observation_height,
            "camera_widths": self.observation_width,
        }
        env = OffScreenRenderEnv(**env_args)
        env.reset()
        if self.init_states:
            # For benchmarking purposes we use a fixed set of initial states.
            # FIXME(mshukor): should be in the reset()?
            init_states = get_task_init_states(task_suite, task_id)
            init_state_id = self.episode_index  # episode index
            env.set_init_state(init_states[init_state_id])

        return env

    def _format_raw_obs(self, raw_obs):
        """Convert a raw simulator observation into the gym observation dict.

        Raises:
            NotImplementedError: If ``obs_type`` is ``"state"``.
            ValueError: If ``obs_type`` is not a supported value.
        """
        images = {}
        for camera_name in self.camera_name:
            image = raw_obs[camera_name]
            image = image[::-1, ::-1]  # rotate 180 degrees
            images[self.camera_name_mapping[camera_name]] = image

        agent_pos = np.concatenate(
            (
                raw_obs["robot0_eef_pos"],
                quat2axisangle(raw_obs["robot0_eef_quat"]),
                raw_obs["robot0_gripper_qpos"],
            )
        )

        if self.obs_type == "state":
            raise NotImplementedError()
        if self.obs_type == "pixels":
            return {"pixels": images}
        if self.obs_type == "pixels_agent_pos":
            return {"pixels": images, "agent_pos": agent_pos}
        raise ValueError(f"Unsupported obs_type {self.obs_type!r}; expected 'pixels' or 'pixels_agent_pos'")

    def reset(self, seed=None, **kwargs):
        """Reset the simulator and return ``(observation, info)``.

        Extra keyword arguments are accepted and ignored.
        """
        super().reset(seed=seed)

        self._env.seed(seed)
        # NOTE(review): the stored initial state is applied only when the
        # simulator is created, not here - confirm whether it should be
        # re-applied after every reset (see the FIXME in `_make_envs_task`).
        raw_obs = self._env.reset()
        # Do nothing for the first few timesteps, waiting for the simulator to drop the objects.
        for _ in range(self.num_steps_wait):
            raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
        observation = self._format_raw_obs(raw_obs)
        info = {"is_success": False}
        return observation, info

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        ``truncated`` is always ``False``. When the simulator reports ``done``
        the environment resets itself; the returned observation is the one
        produced by this step, before that reset.

        Raises:
            ValueError: If ``action`` is not one-dimensional.
        """
        if action.ndim != 1:
            raise ValueError(f"Expected a 1-D action, got an array with {action.ndim} dimensions")
        raw_obs, reward, done, info = self._env.step(action)

        is_success = self._env.check_success()
        terminated = done or is_success
        # NOTE(review): success is reported from the simulator's `done` flag,
        # not from `check_success()` - confirm the two always agree.
        info["is_success"] = done

        observation = self._format_raw_obs(raw_obs)
        if done:
            self.reset()
            print(self.task, self.task_id, done, is_success)
        truncated = False
        return observation, reward, terminated, truncated, info

    def close(self):
        """Close the underlying simulator."""
        self._env.close()
|
||||||
@@ -80,7 +80,56 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
def preprocess_observation1(
|
||||||
|
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||||
|
"""Convert environment observation to LeRobot format observation.
|
||||||
|
Args:
|
||||||
|
observation: Dictionary of observation batches from a Gym vector environment.
|
||||||
|
Returns:
|
||||||
|
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||||
|
"""
|
||||||
|
# map to expected inputs for the policy
|
||||||
|
return_observations = {}
|
||||||
|
image_key = list(cfg.image_features.keys())[0] if cfg else "observation.image"
|
||||||
|
state_key = cfg.robot_state_feature_key if cfg else "observation.state"
|
||||||
|
if "pixels" in observations:
|
||||||
|
if isinstance(observations["pixels"], dict):
|
||||||
|
# imgs = {f"{image_key}.{key}": img for key, img in observations["pixels"].items()}
|
||||||
|
imgs = observations["pixels"] # keys should be OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
|
||||||
|
else:
|
||||||
|
imgs = {f"{image_key}": observations["pixels"]}
|
||||||
|
|
||||||
|
for imgkey, img in imgs.items():
|
||||||
|
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||||
|
img = torch.from_numpy(img)
|
||||||
|
|
||||||
|
# sanity check that images are channel last
|
||||||
|
_, h, w, c = img.shape
|
||||||
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
|
|
||||||
|
# sanity check that images are uint8
|
||||||
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
|
||||||
|
# convert to channel first of type float32 in range [0,1]
|
||||||
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
|
img = img.type(torch.float32)
|
||||||
|
img /= 255
|
||||||
|
|
||||||
|
return_observations[imgkey] = img
|
||||||
|
|
||||||
|
if "environment_state" in observations:
|
||||||
|
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||||
|
observations["environment_state"]
|
||||||
|
).float()
|
||||||
|
|
||||||
|
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||||
|
# requirement for "agent_pos"
|
||||||
|
return_observations[state_key] = torch.from_numpy(observations["agent_pos"]).float()
|
||||||
|
if "task" in observations:
|
||||||
|
return_observations["task"] = observations["task"]
|
||||||
|
return return_observations
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
|
|||||||
@@ -177,6 +177,6 @@ def make_policy(
|
|||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
policy.to(cfg.device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
breakpoint()
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@@ -169,7 +171,72 @@ def resize_with_pad(img, width, height, pad_value=-1):
|
|||||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||||
return padded_img
|
return padded_img
|
||||||
|
|
||||||
|
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
    """Collapse a dataset-variant normalisation-buffer key to its generic form.

    A ``.so<digits>`` marker, optionally followed by ``-<suffix>``, that sits
    directly before ``_buffer_`` is replaced so that the key reads
    ``.buffer_`` instead. Keys without such a marker are returned unchanged.
    """
    canonical_key = re.sub(_VARIANT_RE, ".buffer_", k)
    return canonical_key
|
||||||
|
|
||||||
|
def standardise_state_dict(
    checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
    """Re-key ``checkpoint`` so that its entries match the reference key set.

    Each key is canonicalised. If the canonical name is in ``ref_keys`` the
    entry is stored under that name; when several keys collapse to the same
    canonical name the first one wins and the rest are recorded as
    collisions. Keys whose canonical name is not in ``ref_keys`` are kept
    under their original name.

    Args:
        checkpoint: Mapping of checkpoint keys to tensors.
        ref_keys: Keys the result should be matched against.
        verbose: When true, print collisions and the unmatched-key count.

    Returns:
        The re-keyed dict and the list of original keys that were not matched.
    """
    remapped = {}
    duplicates = {}
    leftovers = []

    for raw_key, tensor in checkpoint.items():
        canonical = canonicalise(raw_key)
        if canonical not in ref_keys:
            leftovers.append(raw_key)
            continue
        if canonical in remapped:
            # duplicate after collapsing: keep the first, remember the rest
            duplicates.setdefault(canonical, []).append(raw_key)
            continue
        remapped[canonical] = tensor

    if verbose:
        for canonical, variants in duplicates.items():
            print(f"[standardise_state_dict] '{canonical}' ← {variants}")
        if leftovers:
            print(f"[standardise_state_dict] kept {len(leftovers)} unmatched keys")

    # Unmatched entries are carried over under their original names.
    for raw_key in leftovers:
        remapped[raw_key] = checkpoint[raw_key]
    return remapped, leftovers
|
||||||
|
|
||||||
|
def load_smolvla(
    model: torch.nn.Module,
    filename: str | os.PathLike,
    *,
    device: str = "cpu",
    checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
    """Load a safetensors checkpoint into ``model``, tolerating key variants.

    The checkpoint is read from ``filename``, optionally renamed with
    ``checkpoint_keys_mapping``, re-keyed against the model's own state-dict
    keys, stripped of normalization entries, and loaded non-strictly.

    Args:
        model: Module that receives the weights (modified in place).
        filename: Path to the safetensors file.
        device: Device passed to the safetensors loader.
        checkpoint_keys_mapping: Rename spec in the form
            ``"old1//new1,old2//new2"``; ignored unless it contains ``"//"``.

    Returns:
        The same ``model`` instance.

    Raises:
        RuntimeError: If any key outside the normalization prefixes is
            missing, or if the checkpoint holds unexpected keys.
    """
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
    norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if not all(key.startswith(norm_keys) for key in missing) or unexpected:
        # Exceptions do not apply %-formatting to their arguments, so the
        # message is built explicitly.
        raise RuntimeError(
            f"SmolVLA {len(missing)} missing / {len(unexpected)} unexpected keys"
        )

    return model
|
||||||
def pad_vector(vector, new_dim):
|
def pad_vector(vector, new_dim):
|
||||||
"""Can be (batch_size x sequence_length x features_dimension)
|
"""Can be (batch_size x sequence_length x features_dimension)
|
||||||
or (batch_size x features_dimension)
|
or (batch_size x features_dimension)
|
||||||
@@ -219,7 +286,27 @@ def aloha_gripper_to_angular(value):
|
|||||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||||
return normalize(value, min_val=0.4, max_val=1.5)
|
return normalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
    """Return a copy of ``checkpoint`` with substrings of its keys replaced.

    Args:
        checkpoint (dict): The checkpoint dictionary.
        rename_str (str): Key mappings in the format "old1//new1,old2//new2".
            Mappings are applied in order to every key.

    Returns:
        dict: A new dictionary holding the same values under the renamed keys.
    """
    replacements = dict(spec.split("//") for spec in rename_str.split(","))

    renamed = {}
    for key, value in checkpoint.items():
        target = key
        for old_part, new_part in replacements.items():
            # str.replace leaves the key untouched when old_part is absent
            target = target.replace(old_part, new_part)
        renamed[target] = value
    return renamed
|
||||||
def aloha_gripper_from_angular(value):
|
def aloha_gripper_from_angular(value):
|
||||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||||
# Note that the units are still angular but the range is different.
|
# Note that the units are still angular but the range is different.
|
||||||
@@ -333,7 +420,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
self.model.vlm_with_expert.merge_lora_weights()
|
self.model.vlm_with_expert.merge_lora_weights()
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||||
@@ -364,7 +451,24 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
actions = self._pi_aloha_encode_actions(actions)
|
actions = self._pi_aloha_encode_actions(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||||
|
@classmethod
def _load_as_safetensor(
    cls,
    model: "SmolVLAPolicy",
    model_file: str,
    map_location: str,
    strict: bool,
    **kwargs,
):
    """Load ``model_file`` into ``model`` and return the loaded model.

    The weights are first loaded with ``safetensors.torch.load_model`` using
    the given ``strict`` flag, then loaded again through ``load_smolvla``
    with the ``"model._orig_mod."`` prefix renamed to ``"model."``.
    Extra keyword arguments are accepted but not used.
    """
    safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)

    key_mapping = "model._orig_mod.//model."
    loaded = load_smolvla(
        model, model_file, device=map_location, checkpoint_keys_mapping=key_mapping
    )
    return loaded
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|||||||
@@ -1027,7 +1027,7 @@ from lerobot.policies.utils import (
|
|||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
from lerobot.utils.utils import get_safe_dtype
|
from lerobot.utils.utils import get_safe_dtype
|
||||||
|
# OBS_STATE = 'state'
|
||||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||||
|
|
||||||
@@ -1347,6 +1347,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Unpad actions
|
# Unpad actions
|
||||||
original_action_dim = self.config.action_feature.shape[0]
|
original_action_dim = self.config.action_feature.shape[0]
|
||||||
|
original_action_dim = 7
|
||||||
actions = actions[:, :, :original_action_dim]
|
actions = actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ from tqdm import trange
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.eval import EvalPipelineConfig
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
@@ -125,6 +125,10 @@ def rollout(
|
|||||||
|
|
||||||
# Reset the policy and environments.
|
# Reset the policy and environments.
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
# added by jade
|
||||||
|
# for k in list(policy.config.input_features.keys()):
|
||||||
|
# if k.startswith("observation.image"):
|
||||||
|
# policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k)
|
||||||
observation, info = env.reset(seed=seeds)
|
observation, info = env.reset(seed=seeds)
|
||||||
if render_callback is not None:
|
if render_callback is not None:
|
||||||
render_callback(env)
|
render_callback(env)
|
||||||
@@ -149,6 +153,7 @@ def rollout(
|
|||||||
while not np.all(done) and step < max_steps:
|
while not np.all(done) and step < max_steps:
|
||||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||||
observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
|
# observation = preprocess_observation1(observation)
|
||||||
if return_observations:
|
if return_observations:
|
||||||
all_observations.append(deepcopy(observation))
|
all_observations.append(deepcopy(observation))
|
||||||
|
|
||||||
@@ -159,6 +164,26 @@ def rollout(
|
|||||||
# Infer "task" from attributes of environments.
|
# Infer "task" from attributes of environments.
|
||||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||||
observation = add_envs_task(env, observation)
|
observation = add_envs_task(env, observation)
|
||||||
|
# breakpoint()
|
||||||
|
# observation = {
|
||||||
|
# k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v
|
||||||
|
# for k, v in observation.items()
|
||||||
|
# # }
|
||||||
|
# if "observation.image" in observation:
|
||||||
|
# observation["image"] = observation.pop("observation.image").to(
|
||||||
|
# device, non_blocking=device.type == "cuda"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if "observation.image2" in observation:
|
||||||
|
# observation["wrist_image"] = observation.pop("observation.image2").to(
|
||||||
|
# device, non_blocking=device.type == "cuda"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if "observation.state" in observation:
|
||||||
|
# observation["state"] = observation.pop("observation.state").to(
|
||||||
|
# device, non_blocking=device.type == "cuda"
|
||||||
|
# )
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
@@ -489,12 +514,11 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
|||||||
print("Normalization layers recreated with dataset stats.")
|
print("Normalization layers recreated with dataset stats.")
|
||||||
|
|
||||||
|
|
||||||
def load_smolvla(cfg, dataset_repo: str):
|
def load_smolvla(cfg, dataset_repo: str, policy):
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
|
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
|
||||||
policy = make_policy(cfg=cfg, ds_meta=dataset.meta)
|
|
||||||
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
||||||
return policy, dataset
|
return policy.to("cuda"), dataset
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
@@ -505,7 +529,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
#login to hf
|
#login to hf
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
login()
|
# login()
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
@@ -520,9 +544,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
# breakpoint()
|
breakpoint()
|
||||||
load_smolvla(cfg.policy, "physical-intelligence/libero")
|
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
||||||
# breakpoint()
|
# rename "image" -> "observation.image"
|
||||||
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
if cfg.env.multitask_eval:
|
if cfg.env.multitask_eval:
|
||||||
|
|||||||
Reference in New Issue
Block a user