mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
|
||||
# --num_trials_per_task 10 \
|
||||
# --video_out_path "data/libero/videos" \
|
||||
# --device "cuda" \
|
||||
# --seed 7
|
||||
# --seed 7
|
||||
|
||||
@@ -61,9 +61,9 @@ def make_env(
|
||||
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
|
||||
|
||||
if "libero" in cfg.type:
|
||||
from lerobot.envs.libero import create_libero_envs
|
||||
|
||||
return create_libero_envs(
|
||||
task=cfg.task,
|
||||
n_envs=n_envs,
|
||||
@@ -74,17 +74,16 @@ def make_env(
|
||||
multitask_eval=cfg.multitask_eval,
|
||||
)
|
||||
|
||||
|
||||
package_name = f"gym_{cfg.type}"
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\""
|
||||
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
|
||||
) from e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
|
||||
def _make_one():
|
||||
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
|
||||
|
||||
@@ -93,4 +92,3 @@ def make_env(
|
||||
# normalize to {suite: {task_id: vec_env}} for consistency
|
||||
suite_name = cfg.type # e.g., "pusht", "aloha"
|
||||
return {suite_name: {0: vec}}
|
||||
|
||||
|
||||
+35
-20
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@@ -14,16 +16,12 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---- Helpers -----------------------------------------------------------------
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> List[str]:
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
@@ -47,14 +45,14 @@ def _get_suite(name: str):
|
||||
return suite
|
||||
|
||||
|
||||
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> List[int]:
|
||||
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
|
||||
"""Validate/normalize task ids. If None → all tasks."""
|
||||
if task_ids is None:
|
||||
return list(range(total_tasks))
|
||||
ids = sorted(set(int(t) for t in task_ids))
|
||||
ids = sorted({int(t) for t in task_ids})
|
||||
for t in ids:
|
||||
if t < 0 or t >= total_tasks:
|
||||
raise ValueError(f"task_id {t} out of range [0, {total_tasks-1}].")
|
||||
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
|
||||
return ids
|
||||
|
||||
|
||||
@@ -64,16 +62,25 @@ def _make_env_fns(
|
||||
suite_name: str,
|
||||
task_id: int,
|
||||
n_envs: int,
|
||||
camera_names: List[str],
|
||||
camera_names: list[str],
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
LiberoEnv: type, # injected to avoid forward ref issues if needed
|
||||
) -> List[Callable[[], "LiberoEnv"]]:
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
|
||||
fns: List[Callable[[], "LiberoEnv"]] = []
|
||||
fns: list[Callable[[], LiberoEnv]] = []
|
||||
for i in range(n_envs):
|
||||
def _mk(i=i, suite=suite, task_id=task_id, suite_name=suite_name, joined_cams=joined_cams, init_states=init_states, gym_kwargs=dict(gym_kwargs)):
|
||||
|
||||
def _mk(
|
||||
i=i,
|
||||
suite=suite,
|
||||
task_id=task_id,
|
||||
suite_name=suite_name,
|
||||
joined_cams=joined_cams,
|
||||
init_states=init_states,
|
||||
gym_kwargs=dict(gym_kwargs),
|
||||
):
|
||||
return LiberoEnv(
|
||||
task_suite=suite,
|
||||
task_id=task_id,
|
||||
@@ -83,11 +90,14 @@ def _make_env_fns(
|
||||
episode_index=i,
|
||||
**gym_kwargs,
|
||||
)
|
||||
|
||||
fns.append(_mk)
|
||||
return fns
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
def create_libero_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
@@ -130,12 +140,15 @@ def create_libero_envs(
|
||||
|
||||
logger.info(
|
||||
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
|
||||
suite_names, n_envs, init_states, bool(multitask_eval)
|
||||
suite_names,
|
||||
n_envs,
|
||||
init_states,
|
||||
bool(multitask_eval),
|
||||
)
|
||||
if task_ids_filter is not None:
|
||||
logger.info("Restricting to task_ids=%s", task_ids_filter)
|
||||
|
||||
out: Dict[str, Dict[int, Any]] = defaultdict(dict)
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
@@ -161,6 +174,8 @@ def create_libero_envs(
|
||||
|
||||
# return plain dicts for predictability
|
||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||
|
||||
|
||||
def quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
@@ -256,10 +271,10 @@ class LiberoEnv(gym.Env):
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
"libero_spatial": 220, # longest training demo has 193 steps
|
||||
"libero_object": 280, # longest training demo has 254 steps
|
||||
"libero_goal": 300, # longest training demo has 270 steps
|
||||
"libero_10": 520, # longest training demo has 505 steps
|
||||
"libero_90": 400, # longest training demo has 373 steps
|
||||
"libero_object": 280, # longest training demo has 254 steps
|
||||
"libero_goal": 300, # longest training demo has 270 steps
|
||||
"libero_10": 520, # longest training demo has 505 steps
|
||||
"libero_90": 400, # longest training demo has 373 steps
|
||||
}
|
||||
default_steps = 500
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
@@ -80,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
def preprocess_observation1(
|
||||
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
||||
) -> dict[str, Tensor]:
|
||||
@@ -130,6 +131,8 @@ def preprocess_observation1(
|
||||
if "task" in observations:
|
||||
return_observations["task"] = observations["task"]
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
@@ -183,6 +186,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
|
||||
observation["task"] = ["" for _ in range(num_envs)]
|
||||
return observation
|
||||
|
||||
|
||||
def _close_single_env(env: Any) -> None:
|
||||
"""Try to close a single env object if it exposes .close()."""
|
||||
try:
|
||||
@@ -193,6 +197,7 @@ def _close_single_env(env: Any) -> None:
|
||||
# Best-effort close: log but don't raise
|
||||
LOG.debug("Exception while closing env %s: %s", env, exc)
|
||||
|
||||
|
||||
def close_envs(env_or_collection: Any) -> None:
|
||||
"""
|
||||
Close a single env or any nested structure of envs.
|
||||
@@ -225,4 +230,4 @@ def close_envs(env_or_collection: Any) -> None:
|
||||
|
||||
# Fallback: try to close if possible
|
||||
if hasattr(env_or_collection, "close"):
|
||||
_close_single_env(env_or_collection)
|
||||
_close_single_env(env_or_collection)
|
||||
|
||||
@@ -31,10 +31,10 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
|
||||
@@ -309,8 +309,8 @@ class NormalizePerRobotType(nn.Module):
|
||||
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
|
||||
]
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))],dim=0)
|
||||
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))],dim=0)
|
||||
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
|
||||
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
|
||||
if batch[key].ndim == 3:
|
||||
mean = mean.unsqueeze(1)
|
||||
std = std.unsqueeze(1)
|
||||
@@ -332,6 +332,8 @@ class NormalizePerRobotType(nn.Module):
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -52,7 +52,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512) #(224, 224)
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
@@ -107,14 +107,14 @@ class SMOLPI0Config(PreTrainedConfig):
|
||||
|
||||
add_image_special_tokens: bool = False
|
||||
add_prompt_template: bool = False
|
||||
prefix_prompt_template: str = f"<|im_start|>User: What action should the robot take to"
|
||||
suffix_prompt_template: str = f"?\nAssistant:"
|
||||
prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
|
||||
suffix_prompt_template: str = "?\nAssistant:"
|
||||
|
||||
attention_mode: str = "self_attn"
|
||||
|
||||
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
|
||||
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
|
||||
|
||||
past_obs_keys: str = f"image"
|
||||
past_obs_keys: str = "image"
|
||||
|
||||
add_local_special_image_tokens: bool = False
|
||||
|
||||
@@ -122,7 +122,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
||||
|
||||
state_to_prefix: bool = False
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
|
||||
num_expert_layers: int = -1
|
||||
num_vlm_layers: int = -1
|
||||
@@ -144,9 +144,9 @@ class SMOLPI0Config(PreTrainedConfig):
|
||||
|
||||
shuffle_camera_positions: bool = False
|
||||
vlm_img_size: int = -1
|
||||
|
||||
|
||||
regression_loss: bool = False
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.vlm_img_size > 0:
|
||||
@@ -198,7 +198,7 @@ class SMOLPI0Config(PreTrainedConfig):
|
||||
)
|
||||
|
||||
@property
def observation_delta_indices(self) -> list:  # FIXME(mshukor): support spacing between observations
    """Return the relative frame offsets of the observations to load.

    The offsets are the negated multiples of ``n_obs_gap`` produced by
    ``range(0, n_obs_steps * n_obs_gap, n_obs_gap)``, listed in reverse
    generation order. For a positive ``n_obs_gap`` this is ascending and
    ends at ``0``, e.g. ``n_obs_steps=3, n_obs_gap=2`` gives ``[-4, -2, 0]``.
    """
    offsets = range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)
    return [-k for k in reversed(offsets)]
|
||||
|
||||
@property
|
||||
|
||||
@@ -85,7 +85,7 @@ def flex_attention_forward(
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128 # limitation of flex attention
|
||||
block_size = 128 # limitation of flex attention
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
@@ -95,7 +95,7 @@ def flex_attention_forward(
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
if pad_q > 0 or pad_k > 0:
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
else:
|
||||
else:
|
||||
padded_causal_mask = causal_mask
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
@@ -107,21 +107,21 @@ def flex_attention_forward(
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
)
|
||||
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
# FIXME(mshukor): compile mask torch.compile(create_block_mask)
|
||||
create_block_mask_compiled = torch.compile(create_block_mask)
|
||||
block_mask = create_block_mask_compiled(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=None, #
|
||||
H=None, #
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
|
||||
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
|
||||
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
|
||||
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
|
||||
@@ -50,9 +50,10 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0")
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
@@ -66,12 +67,11 @@ from lerobot.policies.normalize import (
|
||||
Unnormalize,
|
||||
UnnormalizePerRobotType,
|
||||
)
|
||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||
from lerobot.policies.smolpi0.smolvlm_with_expert import (
|
||||
SmolVLMWithExpertModel
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||
from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
@@ -86,10 +86,13 @@ IMAGES_ORDER = {
|
||||
OBS_IMAGE_3: 2,
|
||||
OBS_IMAGE_4: 3,
|
||||
}
|
||||
import random
|
||||
|
||||
from lerobot.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
import random
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
@@ -171,7 +174,10 @@ def resize_with_pad(img, width, height, pad_value=-1):
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
def canonicalise(k: str) -> str:
|
||||
"""
|
||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
@@ -179,6 +185,7 @@ def canonicalise(k: str) -> str:
|
||||
"""
|
||||
return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
def standardise_state_dict(
|
||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
@@ -209,6 +216,7 @@ def standardise_state_dict(
|
||||
out.update({k: checkpoint[k] for k in unmatched})
|
||||
return out, unmatched
|
||||
|
||||
|
||||
def load_smolvla(
|
||||
model: torch.nn.Module,
|
||||
filename: str | os.PathLike,
|
||||
@@ -237,6 +245,8 @@ def load_smolvla(
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
@@ -286,6 +296,7 @@ def aloha_gripper_to_angular(value):
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
"""
|
||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
@@ -307,6 +318,8 @@ def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
k = k.replace(old_key, new_key)
|
||||
new_checkpoint[k] = v
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
@@ -324,6 +337,7 @@ def aloha_gripper_from_angular_inv(value):
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class SMOLPI0Policy(PreTrainedPolicy):
|
||||
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
@@ -374,7 +388,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(",")
|
||||
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(
|
||||
","
|
||||
)
|
||||
self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
|
||||
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
||||
self.reset()
|
||||
@@ -389,31 +405,20 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
for k in self.config.input_features:
|
||||
if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]):
|
||||
self._queues[k] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
|
||||
def get_optim_params(self) -> dict:
    """Return the parameters (or parameter groups) to hand to the optimizer.

    When ``config.optimizer_lr_vlm`` is positive and differs from
    ``config.optimizer_lr``, the trainable parameters are split into two
    groups by name:

    * parameters whose name does not contain ``".vlm."`` (no ``lr`` key);
    * parameters whose name contains ``".vlm."``, with
      ``"lr": config.optimizer_lr_vlm``.

    Each trainable parameter lands in exactly one group. Otherwise
    ``self.parameters()`` is returned unchanged.

    NOTE(review): the ``-> dict`` annotation is kept for interface
    compatibility, but neither branch returns a dict.
    """
    lr_vlm = self.config.optimizer_lr_vlm
    if lr_vlm > 0 and lr_vlm != self.config.optimizer_lr:
        vlm_params = []
        other_params = []
        # Single pass; frozen parameters are left out of both groups.
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if ".vlm." in name:
                vlm_params.append(param)
            else:
                other_params.append(param)
        return [
            {"params": other_params},
            {"params": vlm_params, "lr": lr_vlm},
        ]

    return self.parameters()
|
||||
|
||||
|
||||
def merge_peft_model_weights(self) -> None:
|
||||
if "lora" in self.config.peft_method:
|
||||
@@ -438,9 +443,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
@@ -469,6 +472,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
device=map_location,
|
||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
@@ -564,8 +568,12 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")), reverse=self.config.reverse_images_order)
|
||||
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
|
||||
present_img_keys = sorted(
|
||||
present_img_keys,
|
||||
key=lambda k: IMAGES_ORDER.get(k, float("inf")),
|
||||
reverse=self.config.reverse_images_order,
|
||||
)
|
||||
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
|
||||
present_img_keys = random.sample(present_img_keys, len(present_img_keys))
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
@@ -609,7 +617,10 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
if self.config.add_prompt_template:
|
||||
tasks = [f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" for task in tasks]
|
||||
tasks = [
|
||||
f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}"
|
||||
for task in tasks
|
||||
]
|
||||
else:
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
@@ -618,7 +629,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
truncation=True, # FIXME(mshukor)
|
||||
truncation=True, # FIXME(mshukor)
|
||||
)
|
||||
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
@@ -655,7 +666,11 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
|
||||
def prepare_state(self, batch):
    """Select the state from ``batch`` and pad it.

    Reads ``batch[OBS_STATE]``. If that tensor has more than two dimensions
    and ``self.include_past_states`` is false, only the last index along
    dimension 1 is kept. The result is passed through ``pad_vector`` with
    ``self.config.max_state_dim``.
    """
    state = batch[OBS_STATE]
    # FIXME(mshukor): no state history for now
    if state.ndim > 2 and not self.include_past_states:
        state = state[:, -1, :]
    return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
@@ -666,7 +681,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
if self.config.relative_actions_mode == "first":
|
||||
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1)
|
||||
elif self.config.relative_actions_mode == "state":
|
||||
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], "Relative action mode 'state' requires the action and state to have the same dimension."
|
||||
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], (
|
||||
"Relative action mode 'state' requires the action and state to have the same dimension."
|
||||
)
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
actions = actions - state
|
||||
@@ -674,6 +691,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
|
||||
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1)
|
||||
return actions
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
Efficiently pads a tensor along sequence dimension to match max_len.
|
||||
@@ -687,13 +705,16 @@ def pad_tensor(tensor, max_len, pad_value=0):
|
||||
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||
"""
|
||||
B, L = tensor.shape[:2]
|
||||
|
||||
|
||||
# Create a padded tensor of max_len and copy the existing values
|
||||
padded_tensor = torch.full((B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device)
|
||||
padded_tensor = torch.full(
|
||||
(B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, :L] = tensor # Efficient in-place copy
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
class VLAFlowMatching(nn.Module):
|
||||
"""
|
||||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||||
@@ -725,7 +746,8 @@ class VLAFlowMatching(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.vlm_with_expert = SmolVLMWithExpertModel(model_id=self.config.vlm_model_name,
|
||||
self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||
model_id=self.config.vlm_model_name,
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
attention_implementation=self.config.attention_implementation,
|
||||
@@ -736,50 +758,64 @@ class VLAFlowMatching(nn.Module):
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
self_attn_only_actions=self.config.self_attn_only_actions,
|
||||
)
|
||||
)
|
||||
# self.paligemma_with_expert = self.configure_peft(paligemma_with_expert)
|
||||
self.vlm_with_expert.configure_peft(config=self.config)
|
||||
# Projections are float32
|
||||
self.state_to_prefix = self.config.state_to_prefix
|
||||
if self.state_to_prefix:
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
)
|
||||
else:
|
||||
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_time_mlp_out = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
# SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ...
|
||||
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
|
||||
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
|
||||
if "SmolVLM-Instruct" in self.config.vlm_model_name:
|
||||
self.fake_image_token = 49152
|
||||
self.global_image_token = [44, 13906, 29, 6266, 46]
|
||||
self.global_image_start_token = torch.tensor([self.fake_image_token] + self.global_image_token, dtype=torch.long)
|
||||
self.fake_image_token = 49152
|
||||
self.global_image_token = [44, 13906, 29, 6266, 46]
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token] + self.global_image_token, dtype=torch.long
|
||||
)
|
||||
else:
|
||||
self.fake_image_token = 49189
|
||||
self.global_image_token = 49152
|
||||
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
|
||||
self.fake_image_token = 49189
|
||||
self.global_image_token = 49152
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
)
|
||||
else:
|
||||
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
)
|
||||
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.add_local_special_image_tokens = self.config.add_local_special_image_tokens
|
||||
self.local_image_tokens = [torch.tensor([self.fake_image_token, tok], dtype=torch.long) for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]] # assume 3 x 3 grid
|
||||
|
||||
self.local_image_tokens = [
|
||||
torch.tensor([self.fake_image_token, tok], dtype=torch.long)
|
||||
for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]
|
||||
] # assume 3 x 3 grid
|
||||
|
||||
self.local_image_start_token = self.global_image_start_token
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
|
||||
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(
|
||||
","
|
||||
)
|
||||
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
|
||||
self.causal_attention_on_history = self.config.causal_attention_on_history
|
||||
|
||||
|
||||
|
||||
|
||||
# def configure_peft(self, model):
|
||||
# # return model
|
||||
@@ -845,14 +881,30 @@ class VLAFlowMatching(nn.Module):
|
||||
img,
|
||||
img_mask,
|
||||
) in enumerate(zip(images, img_masks, strict=False)):
|
||||
# FIXME(mshukor): add special tokens for the history each history_steps or not
|
||||
# FIXME(mshukor): add special tokens for the history each history_steps or not
|
||||
if self.add_image_special_tokens:
|
||||
if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1:
|
||||
local_token_idx = img_idx % num_images
|
||||
image_start_token = self.vlm_with_expert.embed_language_tokens(self.local_image_tokens[local_token_idx].to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
image_start_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.local_image_tokens[local_token_idx].to(
|
||||
device=self.vlm_with_expert.vlm.device
|
||||
)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
else:
|
||||
image_start_token = self.vlm_with_expert.embed_language_tokens(self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
image_start_mask = torch.ones_like(image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device)
|
||||
image_start_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_start_mask = torch.ones_like(
|
||||
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||
)
|
||||
if self.causal_attention_on_history and img_idx % num_images == 0:
|
||||
att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1)
|
||||
else:
|
||||
@@ -861,7 +913,7 @@ class VLAFlowMatching(nn.Module):
|
||||
pad_masks.append(image_start_mask)
|
||||
|
||||
img_emb = self.vlm_with_expert.embed_image(img)
|
||||
img_emb = img_emb #.to(dtype=self.vlm_with_expert.type)
|
||||
img_emb = img_emb # .to(dtype=self.vlm_with_expert.type)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
@@ -880,16 +932,26 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
att_masks += [0] * (num_img_embs)
|
||||
if self.add_image_special_tokens:
|
||||
if not self.add_local_special_image_tokens or (self.add_local_special_image_tokens and img_idx % num_images == num_images - 1):
|
||||
image_end_token = self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
image_end_mask = torch.ones_like(image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device)
|
||||
if not self.add_local_special_image_tokens or (
|
||||
self.add_local_special_image_tokens and img_idx % num_images == num_images - 1
|
||||
):
|
||||
image_end_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_end_mask = torch.ones_like(
|
||||
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||
)
|
||||
embs.append(image_end_token)
|
||||
pad_masks.append(image_end_mask)
|
||||
att_masks += [0] * (image_end_mask.shape[1])
|
||||
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
@@ -900,7 +962,9 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
if state is not None and self.state_to_prefix:
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
|
||||
state_emb = (
|
||||
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
) # .to(dtype=self.vlm_with_expert.type)
|
||||
embs.append(state_emb)
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
@@ -912,7 +976,7 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
# att_masks += [1] + [0]*(states_seq_len - 1)
|
||||
att_masks += [1]*(states_seq_len)
|
||||
att_masks += [1] * (states_seq_len)
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
@@ -937,7 +1001,9 @@ class VLAFlowMatching(nn.Module):
|
||||
# Embed state
|
||||
if not self.state_to_prefix:
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
|
||||
state_emb = (
|
||||
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
) # .to(dtype=self.vlm_with_expert.type)
|
||||
embs.append(state_emb)
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
@@ -948,8 +1014,7 @@ class VLAFlowMatching(nn.Module):
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1] + [0]*(states_seq_len - 1)
|
||||
|
||||
att_masks += [1] + [0] * (states_seq_len - 1)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
@@ -1010,7 +1075,7 @@ class VLAFlowMatching(nn.Module):
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
|
||||
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
@@ -1061,12 +1126,12 @@ class VLAFlowMatching(nn.Module):
|
||||
x_t = torch.zeros_like(noise, dtype=torch.float32, device=device)
|
||||
expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device)
|
||||
x_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
else:
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
@@ -12,31 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from functools import partial
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.version
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.version
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
GemmaForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
SmolVLMForConditionalGeneration,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers import SmolVLMModel, SmolVLMConfig
|
||||
|
||||
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def _round_up_to_multiple(x, multiple):
|
||||
return (x + multiple - 1) // multiple * multiple
|
||||
|
||||
@@ -180,20 +177,31 @@ def apply_rope(x, positions, max_wavelength=10_000):
|
||||
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
# )
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Derive a feed-forward (intermediate) width from a hidden size.

    The hidden size is first scaled by 2/3 and truncated to an int, then
    scaled by ``ffn_dim_multiplier`` and truncated again, and finally rounded
    up to the next multiple of ``multiple_of``.

    Args:
        hidden_dim: Hidden size the intermediate width is derived from.
        ffn_dim_multiplier: Scale factor applied after the 2/3 reduction.
        multiple_of: Granularity the result is rounded up to.

    Returns:
        The intermediate size, a multiple of ``multiple_of``.
    """
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Ceiling division keeps the result aligned to `multiple_of`.
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
|
||||
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
# config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(self, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True, train_expert_only: bool = True, freeze_vision_encoder: bool = False,
|
||||
attention_implementation: str = "eager", attention_mode: str = "self_attn", num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1, self_attn_every_n_layers: int = -1, expert_width_multiplier: float = 0.5, self_attn_only_actions: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_implementation: str = "eager",
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
self_attn_only_actions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
@@ -227,15 +235,17 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size*expert_width_multiplier) #hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size*expert_width_multiplier))
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0 :
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
self.self_attn_only_actions = self_attn_only_actions
|
||||
@@ -245,10 +255,14 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
@@ -308,8 +322,10 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
else:
|
||||
self.vlm = self.vlm.merge_and_unload()
|
||||
|
||||
def get_vlm_model(self):
    """Return the inner VLM model, unwrapping one extra ``.model`` level if present.

    ``self.vlm.model`` is returned as-is unless it itself exposes a ``model``
    attribute, in which case that nested object is returned instead.

    Returns:
        ``self.vlm.model.model`` when it exists, otherwise ``self.vlm.model``.
    """
    wrapped = self.vlm.model
    if hasattr(wrapped, "model"):  # When using peft
        return wrapped.model
    return wrapped
|
||||
@@ -326,22 +342,20 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if self.num_vlm_layers != self.num_expert_layers and self.num_vlm_layers % self.num_expert_layers == 0:
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(
|
||||
[
|
||||
k in name
|
||||
for k in frozen_layers
|
||||
]
|
||||
):
|
||||
if any([k in name for k in frozen_layers]):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
@@ -410,19 +424,34 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
|
||||
# FIXME(mshukor): add special image tokens specific to smolvlm
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.get_vlm_model().vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
).last_hidden_state
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
    """Embed language token ids with the VLM text model's input-embedding layer.

    Args:
        tokens: Token ids passed straight to the embedding layer.

    Returns:
        Whatever the text model's input-embedding module returns for ``tokens``.
    """
    text_model = self.get_vlm_model().text_model
    return text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values=None) -> list[torch.Tensor]:
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
@@ -430,7 +459,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
|
||||
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
@@ -468,12 +497,16 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
if inputs_embeds[1] is not None:
|
||||
suffix_len = inputs_embeds[1].shape[1]
|
||||
attention_mask_[:, -suffix_len:, :-suffix_len] = False
|
||||
position_ids_[:, -suffix_len:] = _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
|
||||
position_ids_[:, -suffix_len:] = (
|
||||
_position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
|
||||
)
|
||||
else:
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(query_states, position_ids_) # FIXME(mshukor): this assumes we have always the vlm features?
|
||||
query_states = apply_rope(
|
||||
query_states, position_ids_
|
||||
) # FIXME(mshukor): this assumes we have always the vlm features?
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
@@ -491,9 +524,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
@@ -504,20 +535,32 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
|
||||
return [att_output], past_key_values
|
||||
|
||||
|
||||
def forward_cross_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values = None) -> list[torch.Tensor]:
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
@@ -558,7 +601,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
@@ -570,21 +612,37 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(*key_states.shape[:2], -1)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim) # k_proj should have same dim as kv
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(*value_states.shape[:2], -1)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_position_id = expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values # start from 0
|
||||
expert_attention_mask = attention_mask[:, -inputs_embeds[1].shape[1]:, :expert_key_states.shape[1]:] # take into account kv
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
# expert_key_states = apply_rope(expert_key_state, expert_position_id)
|
||||
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask, batch_size, head_dim, expert_query_states, expert_key_states, expert_value_states
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
@@ -592,8 +650,8 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
|
||||
|
||||
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
@@ -606,15 +664,16 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] = None,
|
||||
use_cache: bool | None = None,
|
||||
fill_kv_cache: bool | None = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
@@ -626,26 +685,58 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
|
||||
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
|
||||
if self.attention_implementation == "flex":
|
||||
if inputs_embeds[0] is not None and inputs_embeds[1] is not None and attention_mask.shape[-1] == attention_mask.shape[-2] and past_key_values is None: # Now only during training
|
||||
if (
|
||||
inputs_embeds[0] is not None
|
||||
and inputs_embeds[1] is not None
|
||||
and attention_mask.shape[-1] == attention_mask.shape[-2]
|
||||
and past_key_values is None
|
||||
): # Now only during training
|
||||
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
|
||||
padded_seq_len = _round_up_to_multiple(seq_len, 128) # FIXME(mshukor): more efficient to have a fixed seq len?
|
||||
padded_seq_len = _round_up_to_multiple(
|
||||
seq_len, 128
|
||||
) # FIXME(mshukor): more efficient to have a fixed seq len?
|
||||
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
|
||||
pad = padded_seq_len - q_len
|
||||
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
|
||||
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
|
||||
position_ids = F.pad(position_ids, (0, pad), value=0)
|
||||
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if fill_kv_cache or "cross" not in self.attention_mode or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
# query_states = []
|
||||
# key_states = []
|
||||
# value_states = []
|
||||
@@ -703,7 +794,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
# attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
# )
|
||||
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
@@ -712,7 +802,9 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
# layer = models[i].layers[layer_idx]
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = att_outputs[i] if i < len(att_outputs) else att_outputs[0] # in case of self_attn
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
@@ -759,7 +851,11 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
if self.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.attention_implementation == "flex":
|
||||
attention_interface = partial(flex_attention_forward, num_att_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads)
|
||||
attention_interface = partial(
|
||||
flex_attention_forward,
|
||||
num_att_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
)
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
@@ -807,7 +903,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min #-2.3819763e38 # See gemma/modules.py
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
@@ -1027,6 +1027,7 @@ from lerobot.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# OBS_STATE = 'state'
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
@@ -1891,4 +1892,4 @@ class VLAFlowMatching(nn.Module):
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
return v_t
|
||||
|
||||
@@ -1 +1 @@
|
||||
c
|
||||
c
|
||||
|
||||
@@ -547,4 +547,4 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
return att_output
|
||||
|
||||
+60
-35
@@ -45,17 +45,21 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
||||
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
import concurrent
|
||||
|
||||
import concurrent.futures as cf
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict, List, Tuple, TypedDict
|
||||
from collections.abc import Iterator
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -68,7 +72,11 @@ from tqdm import trange
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
|
||||
from lerobot.envs.utils import (
|
||||
add_envs_task,
|
||||
check_env_attributes_and_types,
|
||||
preprocess_observation,
|
||||
)
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -79,9 +87,6 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
from typing import TypedDict, Dict, List, Tuple, Iterator
|
||||
from collections import defaultdict
|
||||
import concurrent.futures as cf
|
||||
|
||||
|
||||
def rollout(
|
||||
@@ -485,8 +490,12 @@ def _compile_episode_data(
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
|
||||
return data_dict
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
"""Recreate normalization layers with proper stats from the dataset."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
@@ -518,7 +527,8 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
||||
|
||||
def load_smolvla(cfg, dataset_repo: str, policy):
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/")
|
||||
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
||||
return policy.to("cuda"), dataset
|
||||
|
||||
@@ -529,8 +539,8 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
#login to hf
|
||||
from huggingface_hub import login
|
||||
# login to hf
|
||||
|
||||
# login()
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -549,7 +559,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
breakpoint()
|
||||
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
||||
# rename "image" -> "observation.image"
|
||||
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
@@ -584,10 +594,11 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
|
||||
# ---- typed payload returned by one task eval ----
|
||||
class TaskMetrics(TypedDict):
|
||||
sum_rewards: List[float]
|
||||
max_rewards: List[float]
|
||||
successes: List[bool]
|
||||
video_paths: List[str]
|
||||
sum_rewards: list[float]
|
||||
max_rewards: list[float]
|
||||
successes: list[bool]
|
||||
video_paths: list[str]
|
||||
|
||||
|
||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
||||
|
||||
@@ -610,7 +621,7 @@ def eval_policy_all(
|
||||
"""
|
||||
global_start = time.time()
|
||||
|
||||
# inner: evaluate a single (suite, task)
|
||||
# inner: evaluate a single (suite, task)
|
||||
def eval_one(
|
||||
task_group: str,
|
||||
task_id: int,
|
||||
@@ -650,27 +661,36 @@ def eval_policy_all(
|
||||
video_paths=task_result.get("video_paths", []),
|
||||
)
|
||||
|
||||
# result producer: sequential or threaded, same consumer
|
||||
def iter_task_results() -> Iterator[Tuple[str, int, TaskMetrics]]:
|
||||
# result producer: sequential or threaded, same consumer
|
||||
def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]:
|
||||
if max_parallel_tasks == 1:
|
||||
for task_group, tasks in envs.items():
|
||||
for task_id, vec in tasks.items():
|
||||
yield task_group, task_id, eval_one(
|
||||
task_group, task_id, vec,
|
||||
policy=policy,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
yield (
|
||||
task_group,
|
||||
task_id,
|
||||
eval_one(
|
||||
task_group,
|
||||
task_id,
|
||||
vec,
|
||||
policy=policy,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
),
|
||||
)
|
||||
else:
|
||||
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||
fut2key: Dict[cf.Future, Tuple[str, int]] = {}
|
||||
fut2key: dict[cf.Future, tuple[str, int]] = {}
|
||||
for task_group, tasks in envs.items():
|
||||
for task_id, vec in tasks.items():
|
||||
fut = executor.submit(
|
||||
eval_one, task_group, task_id, vec,
|
||||
eval_one,
|
||||
task_group,
|
||||
task_id,
|
||||
vec,
|
||||
policy=policy,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
@@ -683,9 +703,9 @@ def eval_policy_all(
|
||||
task_group, task_id = fut2key[fut]
|
||||
yield task_group, task_id, fut.result()
|
||||
|
||||
# single accumulator path on the main thread
|
||||
group_acc: Dict[str, Dict[str, List]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
|
||||
overall: Dict[str, List] = {k: [] for k in ACC_KEYS}
|
||||
# single accumulator path on the main thread
|
||||
group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
|
||||
overall: dict[str, list] = {k: [] for k in ACC_KEYS}
|
||||
|
||||
for task_group, task_id, metrics in iter_task_results():
|
||||
acc = group_acc[task_group]
|
||||
@@ -694,7 +714,7 @@ def eval_policy_all(
|
||||
overall[k].extend(metrics[k])
|
||||
|
||||
# build outputs
|
||||
results: Dict[str, dict] = {}
|
||||
results: dict[str, dict] = {}
|
||||
for task_group, data in group_acc.items():
|
||||
suite_rewards = data["sum_rewards"]
|
||||
suite_max = data["max_rewards"]
|
||||
@@ -720,9 +740,15 @@ def eval_policy_all(
|
||||
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
|
||||
results["overall"] = {
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) if overall["sum_rewards"] else float("nan"),
|
||||
"avg_max_reward": float(np.nanmean(overall["max_rewards"])) if overall["max_rewards"] else float("nan"),
|
||||
"pc_success": float(np.nanmean(overall["successes"]) * 100) if overall["successes"] else float("nan"),
|
||||
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"]))
|
||||
if overall["sum_rewards"]
|
||||
else float("nan"),
|
||||
"avg_max_reward": float(np.nanmean(overall["max_rewards"]))
|
||||
if overall["max_rewards"]
|
||||
else float("nan"),
|
||||
"pc_success": float(np.nanmean(overall["successes"]) * 100)
|
||||
if overall["successes"]
|
||||
else float("nan"),
|
||||
"eval_s": global_eval_s,
|
||||
"eval_ep_s": global_eval_ep_s,
|
||||
},
|
||||
@@ -732,7 +758,6 @@ def eval_policy_all(
|
||||
return results
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
eval_main()
|
||||
|
||||
@@ -105,6 +105,7 @@ def update_policy(
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||
# from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
@@ -132,6 +133,7 @@ def update_policy(
|
||||
|
||||
# print("✅ Normalization layers injected with dataset stats.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
@@ -271,9 +273,12 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), (torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy_all(
|
||||
eval_env, # dict[suite][task_id] -> vec_env
|
||||
eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=videos_dir,
|
||||
@@ -295,15 +300,15 @@ def train(cfg: TrainPipelineConfig):
|
||||
# meters/tracker
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||
)
|
||||
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
|
||||
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
|
||||
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
|
||||
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
|
||||
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
|
||||
@@ -104,6 +104,7 @@ def update_policy(
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
@@ -115,15 +116,15 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
||||
stats = {}
|
||||
for key, stat_dict in dataset_meta.stats.items():
|
||||
stats[key] = {
|
||||
stat_type: torch.as_tensor(stat_array)
|
||||
if isinstance(stat_array, np.ndarray)
|
||||
else stat_array
|
||||
stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
@@ -131,6 +132,7 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
||||
|
||||
print("✅ Normalization layers injected with dataset stats.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
|
||||
has_method,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
|
||||
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
# Initialize accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# added by jade 2 lines
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
||||
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
|
||||
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
logging.info("End of training")
|
||||
accelerator.end_training() # added by jade
|
||||
accelerator.end_training() # added by jade
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user