mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
|
|||||||
# --num_trials_per_task 10 \
|
# --num_trials_per_task 10 \
|
||||||
# --video_out_path "data/libero/videos" \
|
# --video_out_path "data/libero/videos" \
|
||||||
# --device "cuda" \
|
# --device "cuda" \
|
||||||
# --seed 7
|
# --seed 7
|
||||||
|
|||||||
@@ -61,9 +61,9 @@ def make_env(
|
|||||||
|
|
||||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
|
||||||
if "libero" in cfg.type:
|
if "libero" in cfg.type:
|
||||||
from lerobot.envs.libero import create_libero_envs
|
from lerobot.envs.libero import create_libero_envs
|
||||||
|
|
||||||
return create_libero_envs(
|
return create_libero_envs(
|
||||||
task=cfg.task,
|
task=cfg.task,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
@@ -74,17 +74,16 @@ def make_env(
|
|||||||
multitask_eval=cfg.multitask_eval,
|
multitask_eval=cfg.multitask_eval,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
package_name = f"gym_{cfg.type}"
|
package_name = f"gym_{cfg.type}"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(package_name)
|
importlib.import_module(package_name)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\""
|
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.task}"
|
gym_handle = f"{package_name}/{cfg.task}"
|
||||||
|
|
||||||
def _make_one():
|
def _make_one():
|
||||||
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
||||||
|
|
||||||
@@ -93,4 +92,3 @@ def make_env(
|
|||||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||||
return {suite_name: {0: vec}}
|
return {suite_name: {0: vec}}
|
||||||
|
|
||||||
|
|||||||
+35
-20
@@ -1,11 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any
|
from typing import Any, Dict, List
|
||||||
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -14,16 +16,12 @@ from gymnasium import spaces
|
|||||||
from libero.libero import benchmark, get_libero_path
|
from libero.libero import benchmark, get_libero_path
|
||||||
from libero.libero.envs import OffScreenRenderEnv
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---- Helpers -----------------------------------------------------------------
|
# ---- Helpers -----------------------------------------------------------------
|
||||||
|
|
||||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> List[str]:
|
|
||||||
|
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||||
"""Normalize camera_name into a non-empty list of strings."""
|
"""Normalize camera_name into a non-empty list of strings."""
|
||||||
if isinstance(camera_name, str):
|
if isinstance(camera_name, str):
|
||||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||||
@@ -47,14 +45,14 @@ def _get_suite(name: str):
|
|||||||
return suite
|
return suite
|
||||||
|
|
||||||
|
|
||||||
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> List[int]:
|
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
|
||||||
"""Validate/normalize task ids. If None → all tasks."""
|
"""Validate/normalize task ids. If None → all tasks."""
|
||||||
if task_ids is None:
|
if task_ids is None:
|
||||||
return list(range(total_tasks))
|
return list(range(total_tasks))
|
||||||
ids = sorted(set(int(t) for t in task_ids))
|
ids = sorted({int(t) for t in task_ids})
|
||||||
for t in ids:
|
for t in ids:
|
||||||
if t < 0 or t >= total_tasks:
|
if t < 0 or t >= total_tasks:
|
||||||
raise ValueError(f"task_id {t} out of range [0, {total_tasks-1}].")
|
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
@@ -64,16 +62,25 @@ def _make_env_fns(
|
|||||||
suite_name: str,
|
suite_name: str,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
n_envs: int,
|
n_envs: int,
|
||||||
camera_names: List[str],
|
camera_names: list[str],
|
||||||
init_states: bool,
|
init_states: bool,
|
||||||
gym_kwargs: Mapping[str, Any],
|
gym_kwargs: Mapping[str, Any],
|
||||||
LiberoEnv: type, # injected to avoid forward ref issues if needed
|
LiberoEnv: type, # injected to avoid forward ref issues if needed
|
||||||
) -> List[Callable[[], "LiberoEnv"]]:
|
) -> list[Callable[[], LiberoEnv]]:
|
||||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||||
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
|
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
|
||||||
fns: List[Callable[[], "LiberoEnv"]] = []
|
fns: list[Callable[[], LiberoEnv]] = []
|
||||||
for i in range(n_envs):
|
for i in range(n_envs):
|
||||||
def _mk(i=i, suite=suite, task_id=task_id, suite_name=suite_name, joined_cams=joined_cams, init_states=init_states, gym_kwargs=dict(gym_kwargs)):
|
|
||||||
|
def _mk(
|
||||||
|
i=i,
|
||||||
|
suite=suite,
|
||||||
|
task_id=task_id,
|
||||||
|
suite_name=suite_name,
|
||||||
|
joined_cams=joined_cams,
|
||||||
|
init_states=init_states,
|
||||||
|
gym_kwargs=dict(gym_kwargs),
|
||||||
|
):
|
||||||
return LiberoEnv(
|
return LiberoEnv(
|
||||||
task_suite=suite,
|
task_suite=suite,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -83,11 +90,14 @@ def _make_env_fns(
|
|||||||
episode_index=i,
|
episode_index=i,
|
||||||
**gym_kwargs,
|
**gym_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
fns.append(_mk)
|
fns.append(_mk)
|
||||||
return fns
|
return fns
|
||||||
|
|
||||||
|
|
||||||
# ---- Main API ----------------------------------------------------------------
|
# ---- Main API ----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def create_libero_envs(
|
def create_libero_envs(
|
||||||
task: str,
|
task: str,
|
||||||
n_envs: int,
|
n_envs: int,
|
||||||
@@ -130,12 +140,15 @@ def create_libero_envs(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
|
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
|
||||||
suite_names, n_envs, init_states, bool(multitask_eval)
|
suite_names,
|
||||||
|
n_envs,
|
||||||
|
init_states,
|
||||||
|
bool(multitask_eval),
|
||||||
)
|
)
|
||||||
if task_ids_filter is not None:
|
if task_ids_filter is not None:
|
||||||
logger.info("Restricting to task_ids=%s", task_ids_filter)
|
logger.info("Restricting to task_ids=%s", task_ids_filter)
|
||||||
|
|
||||||
out: Dict[str, Dict[int, Any]] = defaultdict(dict)
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
|
|
||||||
for suite_name in suite_names:
|
for suite_name in suite_names:
|
||||||
suite = _get_suite(suite_name)
|
suite = _get_suite(suite_name)
|
||||||
@@ -161,6 +174,8 @@ def create_libero_envs(
|
|||||||
|
|
||||||
# return plain dicts for predictability
|
# return plain dicts for predictability
|
||||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||||
|
|
||||||
|
|
||||||
def quat2axisangle(quat):
|
def quat2axisangle(quat):
|
||||||
"""
|
"""
|
||||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||||
@@ -256,10 +271,10 @@ class LiberoEnv(gym.Env):
|
|||||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||||
"libero_spatial": 220, # longest training demo has 193 steps
|
"libero_spatial": 220, # longest training demo has 193 steps
|
||||||
"libero_object": 280, # longest training demo has 254 steps
|
"libero_object": 280, # longest training demo has 254 steps
|
||||||
"libero_goal": 300, # longest training demo has 270 steps
|
"libero_goal": 300, # longest training demo has 270 steps
|
||||||
"libero_10": 520, # longest training demo has 505 steps
|
"libero_10": 520, # longest training demo has 505 steps
|
||||||
"libero_90": 400, # longest training demo has 373 steps
|
"libero_90": 400, # longest training demo has 373 steps
|
||||||
}
|
}
|
||||||
default_steps = 500
|
default_steps = 500
|
||||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
def preprocess_observation1(
|
def preprocess_observation1(
|
||||||
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
||||||
) -> dict[str, Tensor]:
|
) -> dict[str, Tensor]:
|
||||||
@@ -130,6 +131,8 @@ def preprocess_observation1(
|
|||||||
if "task" in observations:
|
if "task" in observations:
|
||||||
return_observations["task"] = observations["task"]
|
return_observations["task"] = observations["task"]
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
@@ -183,6 +186,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
|
|||||||
observation["task"] = ["" for _ in range(num_envs)]
|
observation["task"] = ["" for _ in range(num_envs)]
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
def _close_single_env(env: Any) -> None:
|
def _close_single_env(env: Any) -> None:
|
||||||
"""Try to close a single env object if it exposes .close()."""
|
"""Try to close a single env object if it exposes .close()."""
|
||||||
try:
|
try:
|
||||||
@@ -193,6 +197,7 @@ def _close_single_env(env: Any) -> None:
|
|||||||
# Best-effort close: log but don't raise
|
# Best-effort close: log but don't raise
|
||||||
LOG.debug("Exception while closing env %s: %s", env, exc)
|
LOG.debug("Exception while closing env %s: %s", env, exc)
|
||||||
|
|
||||||
|
|
||||||
def close_envs(env_or_collection: Any) -> None:
|
def close_envs(env_or_collection: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Close a single env or any nested structure of envs.
|
Close a single env or any nested structure of envs.
|
||||||
@@ -225,4 +230,4 @@ def close_envs(env_or_collection: Any) -> None:
|
|||||||
|
|
||||||
# Fallback: try to close if possible
|
# Fallback: try to close if possible
|
||||||
if hasattr(env_or_collection, "close"):
|
if hasattr(env_or_collection, "close"):
|
||||||
_close_single_env(env_or_collection)
|
_close_single_env(env_or_collection)
|
||||||
|
|||||||
@@ -31,10 +31,10 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
|||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
|
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
|
|||||||
@@ -309,8 +309,8 @@ class NormalizePerRobotType(nn.Module):
|
|||||||
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
|
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
|
||||||
]
|
]
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))],dim=0)
|
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
|
||||||
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))],dim=0)
|
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
|
||||||
if batch[key].ndim == 3:
|
if batch[key].ndim == 3:
|
||||||
mean = mean.unsqueeze(1)
|
mean = mean.unsqueeze(1)
|
||||||
std = std.unsqueeze(1)
|
std = std.unsqueeze(1)
|
||||||
@@ -332,6 +332,8 @@ class NormalizePerRobotType(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||||
# and remove the `Normalize` and `Unnormalize` classes.
|
# and remove the `Normalize` and `Unnormalize` classes.
|
||||||
def _initialize_stats_buffers(
|
def _initialize_stats_buffers(
|
||||||
|
|||||||
@@ -14,12 +14,12 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.optim.optimizers import AdamWConfig
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
from lerobot.optim.schedulers import (
|
from lerobot.optim.schedulers import (
|
||||||
CosineDecayWithWarmupSchedulerConfig,
|
CosineDecayWithWarmupSchedulerConfig,
|
||||||
)
|
)
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -52,7 +52,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
|||||||
max_action_dim: int = 32
|
max_action_dim: int = 32
|
||||||
|
|
||||||
# Image preprocessing
|
# Image preprocessing
|
||||||
resize_imgs_with_padding: tuple[int, int] = (512, 512) #(224, 224)
|
resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224)
|
||||||
|
|
||||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||||
# left and right wrist cameras in addition to the top camera.
|
# left and right wrist cameras in addition to the top camera.
|
||||||
@@ -107,14 +107,14 @@ class SMOLPI0Config(PreTrainedConfig):
|
|||||||
|
|
||||||
add_image_special_tokens: bool = False
|
add_image_special_tokens: bool = False
|
||||||
add_prompt_template: bool = False
|
add_prompt_template: bool = False
|
||||||
prefix_prompt_template: str = f"<|im_start|>User: What action should the robot take to"
|
prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
|
||||||
suffix_prompt_template: str = f"?\nAssistant:"
|
suffix_prompt_template: str = "?\nAssistant:"
|
||||||
|
|
||||||
attention_mode: str = "self_attn"
|
attention_mode: str = "self_attn"
|
||||||
|
|
||||||
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
|
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
|
||||||
|
|
||||||
past_obs_keys: str = f"image"
|
past_obs_keys: str = "image"
|
||||||
|
|
||||||
add_local_special_image_tokens: bool = False
|
add_local_special_image_tokens: bool = False
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
|||||||
|
|
||||||
state_to_prefix: bool = False
|
state_to_prefix: bool = False
|
||||||
|
|
||||||
pad_language_to: str = "longest" # "max_length"
|
pad_language_to: str = "longest" # "max_length"
|
||||||
|
|
||||||
num_expert_layers: int = -1
|
num_expert_layers: int = -1
|
||||||
num_vlm_layers: int = -1
|
num_vlm_layers: int = -1
|
||||||
@@ -144,9 +144,9 @@ class SMOLPI0Config(PreTrainedConfig):
|
|||||||
|
|
||||||
shuffle_camera_positions: bool = False
|
shuffle_camera_positions: bool = False
|
||||||
vlm_img_size: int = -1
|
vlm_img_size: int = -1
|
||||||
|
|
||||||
regression_loss: bool = False
|
regression_loss: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
if self.vlm_img_size > 0:
|
if self.vlm_img_size > 0:
|
||||||
@@ -198,7 +198,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations
|
def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations
|
||||||
return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1]
|
return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def flex_attention_forward(
|
|||||||
|
|
||||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||||
|
|
||||||
block_size = 128 # limitation of flex attention
|
block_size = 128 # limitation of flex attention
|
||||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ def flex_attention_forward(
|
|||||||
pad_k = kv_len_rounded - kv_len
|
pad_k = kv_len_rounded - kv_len
|
||||||
if pad_q > 0 or pad_k > 0:
|
if pad_q > 0 or pad_k > 0:
|
||||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||||
else:
|
else:
|
||||||
padded_causal_mask = causal_mask
|
padded_causal_mask = causal_mask
|
||||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||||
|
|
||||||
@@ -107,21 +107,21 @@ def flex_attention_forward(
|
|||||||
KV_LEN=kv_len_rounded,
|
KV_LEN=kv_len_rounded,
|
||||||
device=causal_mask.device,
|
device=causal_mask.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||||
# FIXME(mshukor): compile mask torch.compile(create_block_mask)
|
# FIXME(mshukor): compile mask torch.compile(create_block_mask)
|
||||||
create_block_mask_compiled = torch.compile(create_block_mask)
|
create_block_mask_compiled = torch.compile(create_block_mask)
|
||||||
block_mask = create_block_mask_compiled(
|
block_mask = create_block_mask_compiled(
|
||||||
mask_mod=mask_mod_fn_padded,
|
mask_mod=mask_mod_fn_padded,
|
||||||
B=b_mask,
|
B=b_mask,
|
||||||
H=None, #
|
H=None, #
|
||||||
Q_LEN=q_len_rounded,
|
Q_LEN=q_len_rounded,
|
||||||
KV_LEN=kv_len_rounded,
|
KV_LEN=kv_len_rounded,
|
||||||
BLOCK_SIZE=block_size,
|
BLOCK_SIZE=block_size,
|
||||||
device=causal_mask.device,
|
device=causal_mask.device,
|
||||||
_compile=False,
|
_compile=False,
|
||||||
)
|
)
|
||||||
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
|
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
|
||||||
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
|
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
|
||||||
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
|
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
|
||||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||||
|
|||||||
@@ -50,9 +50,10 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
@@ -66,12 +67,11 @@ from lerobot.policies.normalize import (
|
|||||||
Unnormalize,
|
Unnormalize,
|
||||||
UnnormalizePerRobotType,
|
UnnormalizePerRobotType,
|
||||||
)
|
)
|
||||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
|
||||||
from lerobot.policies.smolpi0.smolvlm_with_expert import (
|
|
||||||
SmolVLMWithExpertModel
|
|
||||||
)
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||||
|
from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||||
from lerobot.utils.utils import get_safe_dtype
|
from lerobot.utils.utils import get_safe_dtype
|
||||||
|
|
||||||
OBS_IMAGE = "observation.image"
|
OBS_IMAGE = "observation.image"
|
||||||
OBS_IMAGES = "observation.images"
|
OBS_IMAGES = "observation.images"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
@@ -86,10 +86,13 @@ IMAGES_ORDER = {
|
|||||||
OBS_IMAGE_3: 2,
|
OBS_IMAGE_3: 2,
|
||||||
OBS_IMAGE_4: 3,
|
OBS_IMAGE_4: 3,
|
||||||
}
|
}
|
||||||
|
import random
|
||||||
|
|
||||||
from lerobot.policies.utils import (
|
from lerobot.policies.utils import (
|
||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
import random
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
@@ -171,7 +174,10 @@ def resize_with_pad(img, width, height, pad_value=-1):
|
|||||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||||
return padded_img
|
return padded_img
|
||||||
|
|
||||||
|
|
||||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||||
|
|
||||||
|
|
||||||
def canonicalise(k: str) -> str:
|
def canonicalise(k: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||||
@@ -179,6 +185,7 @@ def canonicalise(k: str) -> str:
|
|||||||
"""
|
"""
|
||||||
return _VARIANT_RE.sub(".buffer_", k)
|
return _VARIANT_RE.sub(".buffer_", k)
|
||||||
|
|
||||||
|
|
||||||
def standardise_state_dict(
|
def standardise_state_dict(
|
||||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||||
@@ -209,6 +216,7 @@ def standardise_state_dict(
|
|||||||
out.update({k: checkpoint[k] for k in unmatched})
|
out.update({k: checkpoint[k] for k in unmatched})
|
||||||
return out, unmatched
|
return out, unmatched
|
||||||
|
|
||||||
|
|
||||||
def load_smolvla(
|
def load_smolvla(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
filename: str | os.PathLike,
|
filename: str | os.PathLike,
|
||||||
@@ -237,6 +245,8 @@ def load_smolvla(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def pad_vector(vector, new_dim):
|
def pad_vector(vector, new_dim):
|
||||||
"""Can be (batch_size x sequence_length x features_dimension)
|
"""Can be (batch_size x sequence_length x features_dimension)
|
||||||
or (batch_size x features_dimension)
|
or (batch_size x features_dimension)
|
||||||
@@ -286,6 +296,7 @@ def aloha_gripper_to_angular(value):
|
|||||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||||
return normalize(value, min_val=0.4, max_val=1.5)
|
return normalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
|
||||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||||
"""
|
"""
|
||||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||||
@@ -307,6 +318,8 @@ def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
|||||||
k = k.replace(old_key, new_key)
|
k = k.replace(old_key, new_key)
|
||||||
new_checkpoint[k] = v
|
new_checkpoint[k] = v
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def aloha_gripper_from_angular(value):
|
def aloha_gripper_from_angular(value):
|
||||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||||
# Note that the units are still angular but the range is different.
|
# Note that the units are still angular but the range is different.
|
||||||
@@ -324,6 +337,7 @@ def aloha_gripper_from_angular_inv(value):
|
|||||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||||
return normalize(value, min_val=0.4, max_val=1.5)
|
return normalize(value, min_val=0.4, max_val=1.5)
|
||||||
|
|
||||||
|
|
||||||
class SMOLPI0Policy(PreTrainedPolicy):
|
class SMOLPI0Policy(PreTrainedPolicy):
|
||||||
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||||
|
|
||||||
@@ -374,7 +388,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||||
self.model = VLAFlowMatching(config)
|
self.model = VLAFlowMatching(config)
|
||||||
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(",")
|
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(
|
||||||
|
","
|
||||||
|
)
|
||||||
self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
|
self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
|
||||||
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -389,31 +405,20 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
for k in self.config.input_features:
|
for k in self.config.input_features:
|
||||||
if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]):
|
if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]):
|
||||||
self._queues[k] = deque(maxlen=self.config.n_obs_steps)
|
self._queues[k] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr:
|
if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr:
|
||||||
params = [
|
params = [
|
||||||
|
{"params": [p for n, p in self.named_parameters() if ".vlm." not in n and p.requires_grad]},
|
||||||
{
|
{
|
||||||
"params": [
|
"params": [p for n, p in self.named_parameters() if ".vlm." in n and p.requires_grad],
|
||||||
p
|
|
||||||
for n, p in self.named_parameters()
|
|
||||||
if not ".vlm." in n and p.requires_grad
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p
|
|
||||||
for n, p in self.named_parameters()
|
|
||||||
if ".vlm." in n and p.requires_grad
|
|
||||||
],
|
|
||||||
"lr": self.config.optimizer_lr_vlm,
|
"lr": self.config.optimizer_lr_vlm,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
return params
|
return params
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
|
||||||
def merge_peft_model_weights(self) -> None:
|
def merge_peft_model_weights(self) -> None:
|
||||||
if "lora" in self.config.peft_method:
|
if "lora" in self.config.peft_method:
|
||||||
@@ -438,9 +443,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||||
|
|
||||||
actions = self.model.sample_actions(
|
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
|
||||||
)
|
|
||||||
# Unpad actions
|
# Unpad actions
|
||||||
original_action_dim = self.config.action_feature.shape[0]
|
original_action_dim = self.config.action_feature.shape[0]
|
||||||
actions = actions[:, :, :original_action_dim]
|
actions = actions[:, :, :original_action_dim]
|
||||||
@@ -469,6 +472,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
device=map_location,
|
device=map_location,
|
||||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
@@ -564,8 +568,12 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||||
|
|
||||||
present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")), reverse=self.config.reverse_images_order)
|
present_img_keys = sorted(
|
||||||
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
|
present_img_keys,
|
||||||
|
key=lambda k: IMAGES_ORDER.get(k, float("inf")),
|
||||||
|
reverse=self.config.reverse_images_order,
|
||||||
|
)
|
||||||
|
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
|
||||||
present_img_keys = random.sample(present_img_keys, len(present_img_keys))
|
present_img_keys = random.sample(present_img_keys, len(present_img_keys))
|
||||||
if len(present_img_keys) == 0:
|
if len(present_img_keys) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -609,7 +617,10 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||||
|
|
||||||
if self.config.add_prompt_template:
|
if self.config.add_prompt_template:
|
||||||
tasks = [f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" for task in tasks]
|
tasks = [
|
||||||
|
f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}"
|
||||||
|
for task in tasks
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||||
tokenized_prompt = self.language_tokenizer.__call__(
|
tokenized_prompt = self.language_tokenizer.__call__(
|
||||||
@@ -618,7 +629,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
padding_side="right",
|
padding_side="right",
|
||||||
max_length=self.config.tokenizer_max_length,
|
max_length=self.config.tokenizer_max_length,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
truncation=True, # FIXME(mshukor)
|
truncation=True, # FIXME(mshukor)
|
||||||
)
|
)
|
||||||
|
|
||||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||||
@@ -655,7 +666,11 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def prepare_state(self, batch):
|
def prepare_state(self, batch):
|
||||||
"""Pad state"""
|
"""Pad state"""
|
||||||
state = batch[OBS_STATE][:, -1, :] if (batch[OBS_STATE].ndim > 2 and not self.include_past_states) else batch[OBS_STATE] # FIXME(mshukor): no state history for now
|
state = (
|
||||||
|
batch[OBS_STATE][:, -1, :]
|
||||||
|
if (batch[OBS_STATE].ndim > 2 and not self.include_past_states)
|
||||||
|
else batch[OBS_STATE]
|
||||||
|
) # FIXME(mshukor): no state history for now
|
||||||
state = pad_vector(state, self.config.max_state_dim)
|
state = pad_vector(state, self.config.max_state_dim)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@@ -666,7 +681,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
if self.config.relative_actions_mode == "first":
|
if self.config.relative_actions_mode == "first":
|
||||||
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1)
|
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1)
|
||||||
elif self.config.relative_actions_mode == "state":
|
elif self.config.relative_actions_mode == "state":
|
||||||
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], "Relative action mode 'state' requires the action and state to have the same dimension."
|
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], (
|
||||||
|
"Relative action mode 'state' requires the action and state to have the same dimension."
|
||||||
|
)
|
||||||
if state.ndim == 2:
|
if state.ndim == 2:
|
||||||
state = state.unsqueeze(1)
|
state = state.unsqueeze(1)
|
||||||
actions = actions - state
|
actions = actions - state
|
||||||
@@ -674,6 +691,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
|||||||
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1)
|
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|
||||||
def pad_tensor(tensor, max_len, pad_value=0):
|
def pad_tensor(tensor, max_len, pad_value=0):
|
||||||
"""
|
"""
|
||||||
Efficiently pads a tensor along sequence dimension to match max_len.
|
Efficiently pads a tensor along sequence dimension to match max_len.
|
||||||
@@ -687,13 +705,16 @@ def pad_tensor(tensor, max_len, pad_value=0):
|
|||||||
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||||
"""
|
"""
|
||||||
B, L = tensor.shape[:2]
|
B, L = tensor.shape[:2]
|
||||||
|
|
||||||
# Create a padded tensor of max_len and copy the existing values
|
# Create a padded tensor of max_len and copy the existing values
|
||||||
padded_tensor = torch.full((B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device)
|
padded_tensor = torch.full(
|
||||||
|
(B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||||
|
)
|
||||||
padded_tensor[:, :L] = tensor # Efficient in-place copy
|
padded_tensor[:, :L] = tensor # Efficient in-place copy
|
||||||
|
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
|
|
||||||
class VLAFlowMatching(nn.Module):
|
class VLAFlowMatching(nn.Module):
|
||||||
"""
|
"""
|
||||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||||
@@ -725,7 +746,8 @@ class VLAFlowMatching(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.vlm_with_expert = SmolVLMWithExpertModel(model_id=self.config.vlm_model_name,
|
self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||||
|
model_id=self.config.vlm_model_name,
|
||||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||||
train_expert_only=self.config.train_expert_only,
|
train_expert_only=self.config.train_expert_only,
|
||||||
attention_implementation=self.config.attention_implementation,
|
attention_implementation=self.config.attention_implementation,
|
||||||
@@ -736,50 +758,64 @@ class VLAFlowMatching(nn.Module):
|
|||||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||||
self_attn_only_actions=self.config.self_attn_only_actions,
|
self_attn_only_actions=self.config.self_attn_only_actions,
|
||||||
)
|
)
|
||||||
# self.paligemma_with_expert = self.configure_peft(paligemma_with_expert)
|
# self.paligemma_with_expert = self.configure_peft(paligemma_with_expert)
|
||||||
self.vlm_with_expert.configure_peft(config=self.config)
|
self.vlm_with_expert.configure_peft(config=self.config)
|
||||||
# Projections are float32
|
# Projections are float32
|
||||||
self.state_to_prefix = self.config.state_to_prefix
|
self.state_to_prefix = self.config.state_to_prefix
|
||||||
if self.state_to_prefix:
|
if self.state_to_prefix:
|
||||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size)
|
self.state_proj = nn.Linear(
|
||||||
|
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size)
|
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size)
|
||||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||||
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||||
|
|
||||||
self.action_time_mlp_in = nn.Linear(self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size)
|
self.action_time_mlp_in = nn.Linear(
|
||||||
self.action_time_mlp_out = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size)
|
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||||
|
)
|
||||||
|
self.action_time_mlp_out = nn.Linear(
|
||||||
|
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
self.set_requires_grad()
|
self.set_requires_grad()
|
||||||
# SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ...
|
# SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ...
|
||||||
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
|
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
|
||||||
if "SmolVLM-Instruct" in self.config.vlm_model_name:
|
if "SmolVLM-Instruct" in self.config.vlm_model_name:
|
||||||
self.fake_image_token = 49152
|
self.fake_image_token = 49152
|
||||||
self.global_image_token = [44, 13906, 29, 6266, 46]
|
self.global_image_token = [44, 13906, 29, 6266, 46]
|
||||||
self.global_image_start_token = torch.tensor([self.fake_image_token] + self.global_image_token, dtype=torch.long)
|
self.global_image_start_token = torch.tensor(
|
||||||
|
[self.fake_image_token] + self.global_image_token, dtype=torch.long
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.fake_image_token = 49189
|
self.fake_image_token = 49189
|
||||||
self.global_image_token = 49152
|
self.global_image_token = 49152
|
||||||
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
|
self.global_image_start_token = torch.tensor(
|
||||||
|
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||||
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||||
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
|
self.global_image_start_token = torch.tensor(
|
||||||
|
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||||
|
)
|
||||||
|
|
||||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||||
self.add_local_special_image_tokens = self.config.add_local_special_image_tokens
|
self.add_local_special_image_tokens = self.config.add_local_special_image_tokens
|
||||||
self.local_image_tokens = [torch.tensor([self.fake_image_token, tok], dtype=torch.long) for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]] # assume 3 x 3 grid
|
self.local_image_tokens = [
|
||||||
|
torch.tensor([self.fake_image_token, tok], dtype=torch.long)
|
||||||
|
for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]
|
||||||
|
] # assume 3 x 3 grid
|
||||||
|
|
||||||
self.local_image_start_token = self.global_image_start_token
|
self.local_image_start_token = self.global_image_start_token
|
||||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||||
self.prefix_length = self.config.prefix_length
|
self.prefix_length = self.config.prefix_length
|
||||||
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
|
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(
|
||||||
|
","
|
||||||
|
)
|
||||||
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
||||||
self.causal_attention_on_history = self.config.causal_attention_on_history
|
self.causal_attention_on_history = self.config.causal_attention_on_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def configure_peft(self, model):
|
# def configure_peft(self, model):
|
||||||
# # return model
|
# # return model
|
||||||
@@ -845,14 +881,30 @@ class VLAFlowMatching(nn.Module):
|
|||||||
img,
|
img,
|
||||||
img_mask,
|
img_mask,
|
||||||
) in enumerate(zip(images, img_masks, strict=False)):
|
) in enumerate(zip(images, img_masks, strict=False)):
|
||||||
# FIXME(mshukor): add special tokens for the history each history_steps or not
|
# FIXME(mshukor): add special tokens for the history each history_steps or not
|
||||||
if self.add_image_special_tokens:
|
if self.add_image_special_tokens:
|
||||||
if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1:
|
if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1:
|
||||||
local_token_idx = img_idx % num_images
|
local_token_idx = img_idx % num_images
|
||||||
image_start_token = self.vlm_with_expert.embed_language_tokens(self.local_image_tokens[local_token_idx].to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
image_start_token = (
|
||||||
|
self.vlm_with_expert.embed_language_tokens(
|
||||||
|
self.local_image_tokens[local_token_idx].to(
|
||||||
|
device=self.vlm_with_expert.vlm.device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(img.shape[0], -1, -1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
image_start_token = self.vlm_with_expert.embed_language_tokens(self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
image_start_token = (
|
||||||
image_start_mask = torch.ones_like(image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device)
|
self.vlm_with_expert.embed_language_tokens(
|
||||||
|
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(img.shape[0], -1, -1)
|
||||||
|
)
|
||||||
|
image_start_mask = torch.ones_like(
|
||||||
|
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||||
|
)
|
||||||
if self.causal_attention_on_history and img_idx % num_images == 0:
|
if self.causal_attention_on_history and img_idx % num_images == 0:
|
||||||
att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1)
|
att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1)
|
||||||
else:
|
else:
|
||||||
@@ -861,7 +913,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
pad_masks.append(image_start_mask)
|
pad_masks.append(image_start_mask)
|
||||||
|
|
||||||
img_emb = self.vlm_with_expert.embed_image(img)
|
img_emb = self.vlm_with_expert.embed_image(img)
|
||||||
img_emb = img_emb #.to(dtype=self.vlm_with_expert.type)
|
img_emb = img_emb # .to(dtype=self.vlm_with_expert.type)
|
||||||
|
|
||||||
# Normalize image embeddings
|
# Normalize image embeddings
|
||||||
img_emb_dim = img_emb.shape[-1]
|
img_emb_dim = img_emb.shape[-1]
|
||||||
@@ -880,16 +932,26 @@ class VLAFlowMatching(nn.Module):
|
|||||||
|
|
||||||
att_masks += [0] * (num_img_embs)
|
att_masks += [0] * (num_img_embs)
|
||||||
if self.add_image_special_tokens:
|
if self.add_image_special_tokens:
|
||||||
if not self.add_local_special_image_tokens or (self.add_local_special_image_tokens and img_idx % num_images == num_images - 1):
|
if not self.add_local_special_image_tokens or (
|
||||||
image_end_token = self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
self.add_local_special_image_tokens and img_idx % num_images == num_images - 1
|
||||||
image_end_mask = torch.ones_like(image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device)
|
):
|
||||||
|
image_end_token = (
|
||||||
|
self.vlm_with_expert.embed_language_tokens(
|
||||||
|
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(img.shape[0], -1, -1)
|
||||||
|
)
|
||||||
|
image_end_mask = torch.ones_like(
|
||||||
|
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||||
|
)
|
||||||
embs.append(image_end_token)
|
embs.append(image_end_token)
|
||||||
pad_masks.append(image_end_mask)
|
pad_masks.append(image_end_mask)
|
||||||
att_masks += [0] * (image_end_mask.shape[1])
|
att_masks += [0] * (image_end_mask.shape[1])
|
||||||
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||||
# Normalize language embeddings
|
# Normalize language embeddings
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
lang_emb_dim = lang_emb.shape[-1]
|
||||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
|
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
|
||||||
|
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
pad_masks.append(lang_masks)
|
pad_masks.append(lang_masks)
|
||||||
@@ -900,7 +962,9 @@ class VLAFlowMatching(nn.Module):
|
|||||||
|
|
||||||
if state is not None and self.state_to_prefix:
|
if state is not None and self.state_to_prefix:
|
||||||
state_emb = self.state_proj(state)
|
state_emb = self.state_proj(state)
|
||||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
|
state_emb = (
|
||||||
|
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||||
|
) # .to(dtype=self.vlm_with_expert.type)
|
||||||
embs.append(state_emb)
|
embs.append(state_emb)
|
||||||
bsize = state_emb.shape[0]
|
bsize = state_emb.shape[0]
|
||||||
dtype = state_emb.dtype
|
dtype = state_emb.dtype
|
||||||
@@ -912,7 +976,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
|
|
||||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||||
# att_masks += [1] + [0]*(states_seq_len - 1)
|
# att_masks += [1] + [0]*(states_seq_len - 1)
|
||||||
att_masks += [1]*(states_seq_len)
|
att_masks += [1] * (states_seq_len)
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||||
@@ -937,7 +1001,9 @@ class VLAFlowMatching(nn.Module):
|
|||||||
# Embed state
|
# Embed state
|
||||||
if not self.state_to_prefix:
|
if not self.state_to_prefix:
|
||||||
state_emb = self.state_proj(state)
|
state_emb = self.state_proj(state)
|
||||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
|
state_emb = (
|
||||||
|
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||||
|
) # .to(dtype=self.vlm_with_expert.type)
|
||||||
embs.append(state_emb)
|
embs.append(state_emb)
|
||||||
bsize = state_emb.shape[0]
|
bsize = state_emb.shape[0]
|
||||||
dtype = state_emb.dtype
|
dtype = state_emb.dtype
|
||||||
@@ -948,8 +1014,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
pad_masks.append(state_mask)
|
pad_masks.append(state_mask)
|
||||||
|
|
||||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||||
att_masks += [1] + [0]*(states_seq_len - 1)
|
att_masks += [1] + [0] * (states_seq_len - 1)
|
||||||
|
|
||||||
|
|
||||||
# Fuse timestep + action information using an MLP
|
# Fuse timestep + action information using an MLP
|
||||||
action_emb = self.action_in_proj(noisy_actions)
|
action_emb = self.action_in_proj(noisy_actions)
|
||||||
@@ -1010,7 +1075,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
images, img_masks, lang_tokens, lang_masks, state=state
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
)
|
)
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||||
|
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||||
|
|
||||||
@@ -1061,12 +1126,12 @@ class VLAFlowMatching(nn.Module):
|
|||||||
x_t = torch.zeros_like(noise, dtype=torch.float32, device=device)
|
x_t = torch.zeros_like(noise, dtype=torch.float32, device=device)
|
||||||
expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device)
|
expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device)
|
||||||
x_t = self.denoise_step(
|
x_t = self.denoise_step(
|
||||||
state,
|
state,
|
||||||
prefix_pad_masks,
|
prefix_pad_masks,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
x_t,
|
x_t,
|
||||||
expanded_time,
|
expanded_time,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dt = -1.0 / self.config.num_steps
|
dt = -1.0 / self.config.num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||||
|
|||||||
@@ -12,31 +12,28 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
from functools import partial
|
|
||||||
import copy
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.version
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torch.version
|
||||||
from peft import LoraConfig, TaskType, get_peft_model
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
from pytest import Cache
|
from pytest import Cache
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
GemmaForCausalLM,
|
|
||||||
AutoModelForImageTextToText,
|
|
||||||
AutoProcessor,
|
|
||||||
PretrainedConfig,
|
|
||||||
PreTrainedModel,
|
|
||||||
SmolVLMForConditionalGeneration,
|
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
AutoModelForImageTextToText,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
|
AutoProcessor,
|
||||||
|
SmolVLMForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
|
||||||
from transformers import SmolVLMModel, SmolVLMConfig
|
|
||||||
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
|
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
|
||||||
|
|
||||||
|
|
||||||
def _round_up_to_multiple(x, multiple):
|
def _round_up_to_multiple(x, multiple):
|
||||||
return (x + multiple - 1) // multiple * multiple
|
return (x + multiple - 1) // multiple * multiple
|
||||||
|
|
||||||
@@ -180,20 +177,31 @@ def apply_rope(x, positions, max_wavelength=10_000):
|
|||||||
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||||
hidden_dim = int(2 * hidden_dim / 3)
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
return hidden_dim
|
return hidden_dim
|
||||||
|
|
||||||
|
|
||||||
class SmolVLMWithExpertModel(nn.Module):
|
class SmolVLMWithExpertModel(nn.Module):
|
||||||
# config_class = PaliGemmaWithExpertConfig
|
# config_class = PaliGemmaWithExpertConfig
|
||||||
|
|
||||||
def __init__(self, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
def __init__(
|
||||||
load_vlm_weights: bool = True, train_expert_only: bool = True, freeze_vision_encoder: bool = False,
|
self,
|
||||||
attention_implementation: str = "eager", attention_mode: str = "self_attn", num_expert_layers: int = -1,
|
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||||
num_vlm_layers: int = -1, self_attn_every_n_layers: int = -1, expert_width_multiplier: float = 0.5, self_attn_only_actions: bool = False):
|
load_vlm_weights: bool = True,
|
||||||
|
train_expert_only: bool = True,
|
||||||
|
freeze_vision_encoder: bool = False,
|
||||||
|
attention_implementation: str = "eager",
|
||||||
|
attention_mode: str = "self_attn",
|
||||||
|
num_expert_layers: int = -1,
|
||||||
|
num_vlm_layers: int = -1,
|
||||||
|
self_attn_every_n_layers: int = -1,
|
||||||
|
expert_width_multiplier: float = 0.5,
|
||||||
|
self_attn_only_actions: bool = False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if load_vlm_weights:
|
if load_vlm_weights:
|
||||||
print(f"Loading {model_id} weights ...")
|
print(f"Loading {model_id} weights ...")
|
||||||
@@ -227,15 +235,17 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
# Smaller lm expert
|
# Smaller lm expert
|
||||||
lm_expert_config = copy.deepcopy(config.text_config)
|
lm_expert_config = copy.deepcopy(config.text_config)
|
||||||
hidden_size = lm_expert_config.hidden_size
|
hidden_size = lm_expert_config.hidden_size
|
||||||
lm_expert_config.hidden_size = int(hidden_size*expert_width_multiplier) #hidden_size // 2
|
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size*expert_width_multiplier))
|
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||||
if num_expert_layers > 0 :
|
if num_expert_layers > 0:
|
||||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||||
|
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||||
|
)
|
||||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||||
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2
|
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2
|
||||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||||
|
|
||||||
self.num_expert_layers = len(self.lm_expert.layers)
|
self.num_expert_layers = len(self.lm_expert.layers)
|
||||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||||
self.self_attn_only_actions = self_attn_only_actions
|
self.self_attn_only_actions = self_attn_only_actions
|
||||||
@@ -245,10 +255,14 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||||
continue
|
continue
|
||||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||||
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
|
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||||
|
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||||
|
bias=lm_expert_config.attention_bias,
|
||||||
)
|
)
|
||||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||||
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
|
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||||
|
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||||
|
bias=lm_expert_config.attention_bias,
|
||||||
)
|
)
|
||||||
# Remove unused embed_tokens
|
# Remove unused embed_tokens
|
||||||
self.lm_expert.embed_tokens = None
|
self.lm_expert.embed_tokens = None
|
||||||
@@ -308,8 +322,10 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.vlm = self.vlm.merge_and_unload()
|
self.vlm = self.vlm.merge_and_unload()
|
||||||
|
|
||||||
def get_vlm_model(self,):
|
def get_vlm_model(
|
||||||
if hasattr(self.vlm.model, "model"): # When using peft
|
self,
|
||||||
|
):
|
||||||
|
if hasattr(self.vlm.model, "model"): # When using peft
|
||||||
return self.vlm.model.model
|
return self.vlm.model.model
|
||||||
else:
|
else:
|
||||||
return self.vlm.model
|
return self.vlm.model
|
||||||
@@ -326,22 +342,20 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# To avoid unused params issue with distributed training
|
# To avoid unused params issue with distributed training
|
||||||
last_layers = [self.num_vlm_layers - 1]
|
last_layers = [self.num_vlm_layers - 1]
|
||||||
if self.num_vlm_layers != self.num_expert_layers and self.num_vlm_layers % self.num_expert_layers == 0:
|
if (
|
||||||
|
self.num_vlm_layers != self.num_expert_layers
|
||||||
|
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||||
|
):
|
||||||
last_layers.append(self.num_vlm_layers - 2)
|
last_layers.append(self.num_vlm_layers - 2)
|
||||||
frozen_layers = [
|
frozen_layers = [
|
||||||
"lm_head",
|
"lm_head",
|
||||||
"text_model.model.norm.weight",
|
"text_model.model.norm.weight",
|
||||||
]
|
]
|
||||||
for layer in last_layers:
|
for layer in last_layers:
|
||||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||||
|
|
||||||
for name, params in self.vlm.named_parameters():
|
for name, params in self.vlm.named_parameters():
|
||||||
if any(
|
if any([k in name for k in frozen_layers]):
|
||||||
[
|
|
||||||
k in name
|
|
||||||
for k in frozen_layers
|
|
||||||
]
|
|
||||||
):
|
|
||||||
params.requires_grad = False
|
params.requires_grad = False
|
||||||
# To avoid unused params issue with distributed training
|
# To avoid unused params issue with distributed training
|
||||||
for name, params in self.lm_expert.named_parameters():
|
for name, params in self.lm_expert.named_parameters():
|
||||||
@@ -410,19 +424,34 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
|
|
||||||
# FIXME(mshukor): add special image tokens specific to smolvlm
|
# FIXME(mshukor): add special image tokens specific to smolvlm
|
||||||
# Get sequence from the vision encoder
|
# Get sequence from the vision encoder
|
||||||
image_hidden_states = self.get_vlm_model().vision_model(
|
image_hidden_states = (
|
||||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
self.get_vlm_model()
|
||||||
patch_attention_mask=patch_attention_mask,
|
.vision_model(
|
||||||
).last_hidden_state
|
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||||
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
.last_hidden_state
|
||||||
|
)
|
||||||
# Modality projection & resampling
|
# Modality projection & resampling
|
||||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||||
return image_hidden_states
|
return image_hidden_states
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values=None) -> list[torch.Tensor]:
|
def forward_attn_layer(
|
||||||
|
self,
|
||||||
|
model_layers,
|
||||||
|
inputs_embeds,
|
||||||
|
layer_idx,
|
||||||
|
position_ids,
|
||||||
|
attention_mask,
|
||||||
|
batch_size,
|
||||||
|
head_dim,
|
||||||
|
use_cache: bool = True,
|
||||||
|
fill_kv_cache: bool = True,
|
||||||
|
past_key_values=None,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
@@ -430,7 +459,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
layer = model_layers[i][layer_idx]
|
layer = model_layers[i][layer_idx]
|
||||||
if hidden_states is None or layer is None:
|
if hidden_states is None or layer is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||||
# hidden_states = hidden_states * normalizer
|
# hidden_states = hidden_states * normalizer
|
||||||
hidden_states = layer.input_layernorm(hidden_states)
|
hidden_states = layer.input_layernorm(hidden_states)
|
||||||
@@ -468,12 +497,16 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
if inputs_embeds[1] is not None:
|
if inputs_embeds[1] is not None:
|
||||||
suffix_len = inputs_embeds[1].shape[1]
|
suffix_len = inputs_embeds[1].shape[1]
|
||||||
attention_mask_[:, -suffix_len:, :-suffix_len] = False
|
attention_mask_[:, -suffix_len:, :-suffix_len] = False
|
||||||
position_ids_[:, -suffix_len:] = _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
|
position_ids_[:, -suffix_len:] = (
|
||||||
|
_position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
attention_mask_ = _attention_mask
|
attention_mask_ = _attention_mask
|
||||||
position_ids_ = _position_ids
|
position_ids_ = _position_ids
|
||||||
|
|
||||||
query_states = apply_rope(query_states, position_ids_) # FIXME(mshukor): this assumes we have always the vlm features?
|
query_states = apply_rope(
|
||||||
|
query_states, position_ids_
|
||||||
|
) # FIXME(mshukor): this assumes we have always the vlm features?
|
||||||
key_states = apply_rope(key_states, position_ids_)
|
key_states = apply_rope(key_states, position_ids_)
|
||||||
|
|
||||||
if use_cache and past_key_values is None:
|
if use_cache and past_key_values is None:
|
||||||
@@ -491,9 +524,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||||
# in `transformers`. (molbap)
|
# in `transformers`. (molbap)
|
||||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||||
value_states = torch.cat(
|
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_interface = self.get_attention_interface()
|
attention_interface = self.get_attention_interface()
|
||||||
|
|
||||||
@@ -504,20 +535,32 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
|
|
||||||
return [att_output], past_key_values
|
return [att_output], past_key_values
|
||||||
|
|
||||||
|
def forward_cross_attn_layer(
|
||||||
def forward_cross_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values = None) -> list[torch.Tensor]:
|
self,
|
||||||
|
model_layers,
|
||||||
|
inputs_embeds,
|
||||||
|
layer_idx,
|
||||||
|
position_ids,
|
||||||
|
attention_mask,
|
||||||
|
batch_size,
|
||||||
|
head_dim,
|
||||||
|
use_cache: bool = True,
|
||||||
|
fill_kv_cache: bool = True,
|
||||||
|
past_key_values=None,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
attention_interface = self.get_attention_interface()
|
attention_interface = self.get_attention_interface()
|
||||||
|
|
||||||
att_outputs = []
|
att_outputs = []
|
||||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||||
|
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||||
|
)
|
||||||
|
|
||||||
if len(inputs_embeds) == 2 and not past_key_values:
|
if len(inputs_embeds) == 2 and not past_key_values:
|
||||||
# Prefix attention
|
# Prefix attention
|
||||||
seq_len = inputs_embeds[0].shape[1]
|
seq_len = inputs_embeds[0].shape[1]
|
||||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||||
|
|
||||||
layer = model_layers[0][layer_idx]
|
layer = model_layers[0][layer_idx]
|
||||||
|
|
||||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||||
@@ -558,7 +601,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
key_states = past_key_values[layer_idx]["key_states"]
|
key_states = past_key_values[layer_idx]["key_states"]
|
||||||
value_states = past_key_values[layer_idx]["value_states"]
|
value_states = past_key_values[layer_idx]["value_states"]
|
||||||
|
|
||||||
|
|
||||||
# Expert
|
# Expert
|
||||||
expert_layer = model_layers[1][layer_idx]
|
expert_layer = model_layers[1][layer_idx]
|
||||||
if expert_layer is not None:
|
if expert_layer is not None:
|
||||||
@@ -570,21 +612,37 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||||
|
|
||||||
|
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||||
|
*key_states.shape[:2], -1
|
||||||
|
)
|
||||||
|
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||||
|
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||||
|
) # k_proj should have same dim as kv
|
||||||
|
|
||||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(*key_states.shape[:2], -1)
|
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim) # k_proj should have same dim as kv
|
*value_states.shape[:2], -1
|
||||||
|
)
|
||||||
|
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||||
|
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||||
|
)
|
||||||
|
|
||||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(*value_states.shape[:2], -1)
|
expert_position_id = (
|
||||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim)
|
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||||
|
) # start from 0
|
||||||
expert_position_id = expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values # start from 0
|
expert_attention_mask = attention_mask[
|
||||||
expert_attention_mask = attention_mask[:, -inputs_embeds[1].shape[1]:, :expert_key_states.shape[1]:] # take into account kv
|
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||||
|
] # take into account kv
|
||||||
|
|
||||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||||
# expert_key_states = apply_rope(expert_key_state, expert_position_id)
|
# expert_key_states = apply_rope(expert_key_state, expert_position_id)
|
||||||
|
|
||||||
att_output = attention_interface(
|
att_output = attention_interface(
|
||||||
expert_attention_mask, batch_size, head_dim, expert_query_states, expert_key_states, expert_value_states
|
expert_attention_mask,
|
||||||
|
batch_size,
|
||||||
|
head_dim,
|
||||||
|
expert_query_states,
|
||||||
|
expert_key_states,
|
||||||
|
expert_value_states,
|
||||||
)
|
)
|
||||||
att_outputs.append(att_output)
|
att_outputs.append(att_output)
|
||||||
else:
|
else:
|
||||||
@@ -592,8 +650,8 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
|
|
||||||
# att_output = att_output.to(dtype=models[i].dtype)
|
# att_output = att_output.to(dtype=models[i].dtype)
|
||||||
return att_outputs, past_key_values
|
return att_outputs, past_key_values
|
||||||
|
|
||||||
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
|
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
|
||||||
vlm_layers = []
|
vlm_layers = []
|
||||||
expert_layers = []
|
expert_layers = []
|
||||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||||
@@ -606,15 +664,16 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
vlm_layers.append(models[0].layers[i])
|
vlm_layers.append(models[0].layers[i])
|
||||||
expert_layers.append(expert_layer)
|
expert_layers.append(expert_layer)
|
||||||
return [vlm_layers, expert_layers]
|
return [vlm_layers, expert_layers]
|
||||||
|
|
||||||
# TODO: break down this huge forward into modules or functions
|
# TODO: break down this huge forward into modules or functions
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||||
inputs_embeds: List[torch.FloatTensor] = None,
|
inputs_embeds: list[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: bool | None = None,
|
||||||
fill_kv_cache: Optional[bool] = None,
|
fill_kv_cache: bool | None = None,
|
||||||
):
|
):
|
||||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||||
model_layers = self.get_model_layers(models)
|
model_layers = self.get_model_layers(models)
|
||||||
@@ -626,26 +685,58 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
continue
|
continue
|
||||||
batch_size = hidden_states.shape[0]
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
|
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
|
||||||
if self.attention_implementation == "flex":
|
if self.attention_implementation == "flex":
|
||||||
if inputs_embeds[0] is not None and inputs_embeds[1] is not None and attention_mask.shape[-1] == attention_mask.shape[-2] and past_key_values is None: # Now only during training
|
if (
|
||||||
|
inputs_embeds[0] is not None
|
||||||
|
and inputs_embeds[1] is not None
|
||||||
|
and attention_mask.shape[-1] == attention_mask.shape[-2]
|
||||||
|
and past_key_values is None
|
||||||
|
): # Now only during training
|
||||||
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
|
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
|
||||||
padded_seq_len = _round_up_to_multiple(seq_len, 128) # FIXME(mshukor): more efficient to have a fixed seq len?
|
padded_seq_len = _round_up_to_multiple(
|
||||||
|
seq_len, 128
|
||||||
|
) # FIXME(mshukor): more efficient to have a fixed seq len?
|
||||||
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
|
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
|
||||||
pad = padded_seq_len - q_len
|
pad = padded_seq_len - q_len
|
||||||
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
|
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
|
||||||
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
|
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
|
||||||
position_ids = F.pad(position_ids, (0, pad), value=0)
|
position_ids = F.pad(position_ids, (0, pad), value=0)
|
||||||
|
|
||||||
|
|
||||||
# RMSNorm
|
# RMSNorm
|
||||||
num_layers = self.num_vlm_layers
|
num_layers = self.num_vlm_layers
|
||||||
head_dim = self.vlm.config.text_config.head_dim
|
head_dim = self.vlm.config.text_config.head_dim
|
||||||
for layer_idx in range(num_layers):
|
for layer_idx in range(num_layers):
|
||||||
if fill_kv_cache or "cross" not in self.attention_mode or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0):
|
if (
|
||||||
att_outputs, past_key_values = self.forward_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
|
fill_kv_cache
|
||||||
|
or "cross" not in self.attention_mode
|
||||||
|
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||||
|
):
|
||||||
|
att_outputs, past_key_values = self.forward_attn_layer(
|
||||||
|
model_layers,
|
||||||
|
inputs_embeds,
|
||||||
|
layer_idx,
|
||||||
|
position_ids,
|
||||||
|
attention_mask,
|
||||||
|
batch_size,
|
||||||
|
head_dim,
|
||||||
|
use_cache=use_cache,
|
||||||
|
fill_kv_cache=fill_kv_cache,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
att_outputs, past_key_values = self.forward_cross_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
|
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||||
|
model_layers,
|
||||||
|
inputs_embeds,
|
||||||
|
layer_idx,
|
||||||
|
position_ids,
|
||||||
|
attention_mask,
|
||||||
|
batch_size,
|
||||||
|
head_dim,
|
||||||
|
use_cache=use_cache,
|
||||||
|
fill_kv_cache=fill_kv_cache,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
# query_states = []
|
# query_states = []
|
||||||
# key_states = []
|
# key_states = []
|
||||||
# value_states = []
|
# value_states = []
|
||||||
@@ -703,7 +794,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
# attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
# attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
# att_output = att_output.to(dtype=models[i].dtype)
|
# att_output = att_output.to(dtype=models[i].dtype)
|
||||||
|
|
||||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||||
@@ -712,7 +802,9 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
# layer = models[i].layers[layer_idx]
|
# layer = models[i].layers[layer_idx]
|
||||||
layer = model_layers[i][layer_idx]
|
layer = model_layers[i][layer_idx]
|
||||||
att_output = att_outputs[i] if i < len(att_outputs) else att_outputs[0] # in case of self_attn
|
att_output = (
|
||||||
|
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||||
|
) # in case of self_attn
|
||||||
if hidden_states is not None:
|
if hidden_states is not None:
|
||||||
if layer is None:
|
if layer is None:
|
||||||
outputs_embeds.append(hidden_states)
|
outputs_embeds.append(hidden_states)
|
||||||
@@ -759,7 +851,11 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
if self.attention_implementation == "fa2":
|
if self.attention_implementation == "fa2":
|
||||||
attention_interface = self.flash_attention_forward
|
attention_interface = self.flash_attention_forward
|
||||||
elif self.attention_implementation == "flex":
|
elif self.attention_implementation == "flex":
|
||||||
attention_interface = partial(flex_attention_forward, num_att_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads)
|
attention_interface = partial(
|
||||||
|
flex_attention_forward,
|
||||||
|
num_att_heads=self.num_attention_heads,
|
||||||
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
attention_interface = self.eager_attention_forward
|
attention_interface = self.eager_attention_forward
|
||||||
return attention_interface
|
return attention_interface
|
||||||
@@ -807,7 +903,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
att_weights *= head_dim**-0.5
|
att_weights *= head_dim**-0.5
|
||||||
|
|
||||||
att_weights = att_weights.to(dtype=torch.float32)
|
att_weights = att_weights.to(dtype=torch.float32)
|
||||||
big_neg = torch.finfo(att_weights.dtype).min #-2.3819763e38 # See gemma/modules.py
|
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||||
probs = probs.to(dtype=value_states.dtype)
|
probs = probs.to(dtype=value_states.dtype)
|
||||||
|
|||||||
@@ -1027,6 +1027,7 @@ from lerobot.policies.utils import (
|
|||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
from lerobot.utils.utils import get_safe_dtype
|
from lerobot.utils.utils import get_safe_dtype
|
||||||
|
|
||||||
# OBS_STATE = 'state'
|
# OBS_STATE = 'state'
|
||||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||||
@@ -1891,4 +1892,4 @@ class VLAFlowMatching(nn.Module):
|
|||||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||||
v_t = self.action_out_proj(suffix_out)
|
v_t = self.action_out_proj(suffix_out)
|
||||||
return v_t
|
return v_t
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
c
|
c
|
||||||
|
|||||||
@@ -547,4 +547,4 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
# we use -1 because sequence length can change
|
# we use -1 because sequence length can change
|
||||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||||
|
|
||||||
return att_output
|
return att_output
|
||||||
|
|||||||
+60
-35
@@ -45,17 +45,21 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
|||||||
|
|
||||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||||
"""
|
"""
|
||||||
import concurrent
|
|
||||||
|
import concurrent.futures as cf
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
from typing import Dict, List, Tuple, TypedDict
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
@@ -68,7 +72,11 @@ from tqdm import trange
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.eval import EvalPipelineConfig
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
|
from lerobot.envs.utils import (
|
||||||
|
add_envs_task,
|
||||||
|
check_env_attributes_and_types,
|
||||||
|
preprocess_observation,
|
||||||
|
)
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
@@ -79,9 +87,6 @@ from lerobot.utils.utils import (
|
|||||||
init_logging,
|
init_logging,
|
||||||
inside_slurm,
|
inside_slurm,
|
||||||
)
|
)
|
||||||
from typing import TypedDict, Dict, List, Tuple, Iterator
|
|
||||||
from collections import defaultdict
|
|
||||||
import concurrent.futures as cf
|
|
||||||
|
|
||||||
|
|
||||||
def rollout(
|
def rollout(
|
||||||
@@ -485,8 +490,12 @@ def _compile_episode_data(
|
|||||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||||
|
|
||||||
return data_dict
|
return data_dict
|
||||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
|
||||||
|
|
||||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||||
"""Recreate normalization layers with proper stats from the dataset."""
|
"""Recreate normalization layers with proper stats from the dataset."""
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
@@ -518,7 +527,8 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
|||||||
|
|
||||||
def load_smolvla(cfg, dataset_repo: str, policy):
|
def load_smolvla(cfg, dataset_repo: str, policy):
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
|
|
||||||
|
dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/")
|
||||||
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
||||||
return policy.to("cuda"), dataset
|
return policy.to("cuda"), dataset
|
||||||
|
|
||||||
@@ -529,8 +539,8 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
#login to hf
|
# login to hf
|
||||||
from huggingface_hub import login
|
|
||||||
# login()
|
# login()
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@@ -549,7 +559,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
breakpoint()
|
breakpoint()
|
||||||
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
||||||
# rename "image" -> "observation.image"
|
# rename "image" -> "observation.image"
|
||||||
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy_all(
|
info = eval_policy_all(
|
||||||
@@ -584,10 +594,11 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
|
|
||||||
# ---- typed payload returned by one task eval ----
|
# ---- typed payload returned by one task eval ----
|
||||||
class TaskMetrics(TypedDict):
|
class TaskMetrics(TypedDict):
|
||||||
sum_rewards: List[float]
|
sum_rewards: list[float]
|
||||||
max_rewards: List[float]
|
max_rewards: list[float]
|
||||||
successes: List[bool]
|
successes: list[bool]
|
||||||
video_paths: List[str]
|
video_paths: list[str]
|
||||||
|
|
||||||
|
|
||||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
||||||
|
|
||||||
@@ -610,7 +621,7 @@ def eval_policy_all(
|
|||||||
"""
|
"""
|
||||||
global_start = time.time()
|
global_start = time.time()
|
||||||
|
|
||||||
# inner: evaluate a single (suite, task)
|
# inner: evaluate a single (suite, task)
|
||||||
def eval_one(
|
def eval_one(
|
||||||
task_group: str,
|
task_group: str,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
@@ -650,27 +661,36 @@ def eval_policy_all(
|
|||||||
video_paths=task_result.get("video_paths", []),
|
video_paths=task_result.get("video_paths", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
# result producer: sequential or threaded, same consumer
|
# result producer: sequential or threaded, same consumer
|
||||||
def iter_task_results() -> Iterator[Tuple[str, int, TaskMetrics]]:
|
def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]:
|
||||||
if max_parallel_tasks == 1:
|
if max_parallel_tasks == 1:
|
||||||
for task_group, tasks in envs.items():
|
for task_group, tasks in envs.items():
|
||||||
for task_id, vec in tasks.items():
|
for task_id, vec in tasks.items():
|
||||||
yield task_group, task_id, eval_one(
|
yield (
|
||||||
task_group, task_id, vec,
|
task_group,
|
||||||
policy=policy,
|
task_id,
|
||||||
n_episodes=n_episodes,
|
eval_one(
|
||||||
max_episodes_rendered=max_episodes_rendered,
|
task_group,
|
||||||
videos_dir=videos_dir,
|
task_id,
|
||||||
return_episode_data=return_episode_data,
|
vec,
|
||||||
start_seed=start_seed,
|
policy=policy,
|
||||||
|
n_episodes=n_episodes,
|
||||||
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
|
videos_dir=videos_dir,
|
||||||
|
return_episode_data=return_episode_data,
|
||||||
|
start_seed=start_seed,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
fut2key: Dict[cf.Future, Tuple[str, int]] = {}
|
fut2key: dict[cf.Future, tuple[str, int]] = {}
|
||||||
for task_group, tasks in envs.items():
|
for task_group, tasks in envs.items():
|
||||||
for task_id, vec in tasks.items():
|
for task_id, vec in tasks.items():
|
||||||
fut = executor.submit(
|
fut = executor.submit(
|
||||||
eval_one, task_group, task_id, vec,
|
eval_one,
|
||||||
|
task_group,
|
||||||
|
task_id,
|
||||||
|
vec,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
n_episodes=n_episodes,
|
n_episodes=n_episodes,
|
||||||
max_episodes_rendered=max_episodes_rendered,
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
@@ -683,9 +703,9 @@ def eval_policy_all(
|
|||||||
task_group, task_id = fut2key[fut]
|
task_group, task_id = fut2key[fut]
|
||||||
yield task_group, task_id, fut.result()
|
yield task_group, task_id, fut.result()
|
||||||
|
|
||||||
# single accumulator path on the main thread
|
# single accumulator path on the main thread
|
||||||
group_acc: Dict[str, Dict[str, List]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
|
group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
|
||||||
overall: Dict[str, List] = {k: [] for k in ACC_KEYS}
|
overall: dict[str, list] = {k: [] for k in ACC_KEYS}
|
||||||
|
|
||||||
for task_group, task_id, metrics in iter_task_results():
|
for task_group, task_id, metrics in iter_task_results():
|
||||||
acc = group_acc[task_group]
|
acc = group_acc[task_group]
|
||||||
@@ -694,7 +714,7 @@ def eval_policy_all(
|
|||||||
overall[k].extend(metrics[k])
|
overall[k].extend(metrics[k])
|
||||||
|
|
||||||
# build outputs
|
# build outputs
|
||||||
results: Dict[str, dict] = {}
|
results: dict[str, dict] = {}
|
||||||
for task_group, data in group_acc.items():
|
for task_group, data in group_acc.items():
|
||||||
suite_rewards = data["sum_rewards"]
|
suite_rewards = data["sum_rewards"]
|
||||||
suite_max = data["max_rewards"]
|
suite_max = data["max_rewards"]
|
||||||
@@ -720,9 +740,15 @@ def eval_policy_all(
|
|||||||
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
|
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
|
||||||
results["overall"] = {
|
results["overall"] = {
|
||||||
"aggregated": {
|
"aggregated": {
|
||||||
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) if overall["sum_rewards"] else float("nan"),
|
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"]))
|
||||||
"avg_max_reward": float(np.nanmean(overall["max_rewards"])) if overall["max_rewards"] else float("nan"),
|
if overall["sum_rewards"]
|
||||||
"pc_success": float(np.nanmean(overall["successes"]) * 100) if overall["successes"] else float("nan"),
|
else float("nan"),
|
||||||
|
"avg_max_reward": float(np.nanmean(overall["max_rewards"]))
|
||||||
|
if overall["max_rewards"]
|
||||||
|
else float("nan"),
|
||||||
|
"pc_success": float(np.nanmean(overall["successes"]) * 100)
|
||||||
|
if overall["successes"]
|
||||||
|
else float("nan"),
|
||||||
"eval_s": global_eval_s,
|
"eval_s": global_eval_s,
|
||||||
"eval_ep_s": global_eval_ep_s,
|
"eval_ep_s": global_eval_ep_s,
|
||||||
},
|
},
|
||||||
@@ -732,7 +758,6 @@ def eval_policy_all(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
eval_main()
|
eval_main()
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ def update_policy(
|
|||||||
train_metrics.update_s = time.perf_counter() - start_time
|
train_metrics.update_s = time.perf_counter() - start_time
|
||||||
return train_metrics, output_dict
|
return train_metrics, output_dict
|
||||||
|
|
||||||
|
|
||||||
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||||
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||||
# from lerobot.policies.normalize import Normalize, Unnormalize
|
# from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
@@ -132,6 +133,7 @@ def update_policy(
|
|||||||
|
|
||||||
# print("✅ Normalization layers injected with dataset stats.")
|
# print("✅ Normalization layers injected with dataset stats.")
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train(cfg: TrainPipelineConfig):
|
def train(cfg: TrainPipelineConfig):
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
@@ -271,9 +273,12 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_eval_step:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with torch.no_grad(), (torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()):
|
with (
|
||||||
|
torch.no_grad(),
|
||||||
|
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||||
|
):
|
||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
eval_env, # dict[suite][task_id] -> vec_env
|
eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy,
|
policy,
|
||||||
cfg.eval.n_episodes,
|
cfg.eval.n_episodes,
|
||||||
videos_dir=videos_dir,
|
videos_dir=videos_dir,
|
||||||
@@ -295,15 +300,15 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
# meters/tracker
|
# meters/tracker
|
||||||
eval_metrics = {
|
eval_metrics = {
|
||||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||||
"pc_success": AverageMeter("success", ":.1f"),
|
"pc_success": AverageMeter("success", ":.1f"),
|
||||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||||
}
|
}
|
||||||
eval_tracker = MetricsTracker(
|
eval_tracker = MetricsTracker(
|
||||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||||
)
|
)
|
||||||
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
|
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
|
||||||
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
|
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
|
||||||
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
|
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ def update_policy(
|
|||||||
train_metrics.update_s = time.perf_counter() - start_time
|
train_metrics.update_s = time.perf_counter() - start_time
|
||||||
return train_metrics, output_dict
|
return train_metrics, output_dict
|
||||||
|
|
||||||
|
|
||||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||||
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
@@ -115,15 +116,15 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
|||||||
stats = {}
|
stats = {}
|
||||||
for key, stat_dict in dataset_meta.stats.items():
|
for key, stat_dict in dataset_meta.stats.items():
|
||||||
stats[key] = {
|
stats[key] = {
|
||||||
stat_type: torch.as_tensor(stat_array)
|
stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||||
if isinstance(stat_array, np.ndarray)
|
|
||||||
else stat_array
|
|
||||||
for stat_type, stat_array in stat_dict.items()
|
for stat_type, stat_array in stat_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||||
unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
unnormalize_outputs = Unnormalize(
|
||||||
|
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||||
|
)
|
||||||
|
|
||||||
policy.normalize_inputs = normalize_inputs
|
policy.normalize_inputs = normalize_inputs
|
||||||
policy.normalize_targets = normalize_targets
|
policy.normalize_targets = normalize_targets
|
||||||
@@ -131,6 +132,7 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
|||||||
|
|
||||||
print("✅ Normalization layers injected with dataset stats.")
|
print("✅ Normalization layers injected with dataset stats.")
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train(cfg: TrainPipelineConfig):
|
def train(cfg: TrainPipelineConfig):
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|||||||
@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
|
|||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.scripts.eval import eval_policy
|
||||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
from lerobot.utils.train_utils import (
|
from lerobot.utils.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
|
|||||||
has_method,
|
has_method,
|
||||||
init_logging,
|
init_logging,
|
||||||
)
|
)
|
||||||
from lerobot.configs import parser
|
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
from lerobot.scripts.eval import eval_policy
|
|
||||||
|
|
||||||
|
|
||||||
def update_policy(
|
def update_policy(
|
||||||
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
# Initialize accelerator
|
# Initialize accelerator
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
# added by jade 2 lines
|
# added by jade 2 lines
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
||||||
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
|
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
|
||||||
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
accelerator.end_training() # added by jade
|
accelerator.end_training() # added by jade
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user