[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-09-11 11:51:53 +00:00
parent 565c992589
commit a19d7fb6bf
17 changed files with 469 additions and 254 deletions
+1 -1
View File
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
# --num_trials_per_task 10 \
# --video_out_path "data/libero/videos" \
# --device "cuda" \
# --seed 7
# --seed 7
+3 -5
View File
@@ -61,9 +61,9 @@ def make_env(
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
return create_libero_envs(
task=cfg.task,
n_envs=n_envs,
@@ -74,17 +74,16 @@ def make_env(
multitask_eval=cfg.multitask_eval,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\""
f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
) from e
gym_handle = f"{package_name}/{cfg.task}"
def _make_one():
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
@@ -93,4 +92,3 @@ def make_env(
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"
return {suite_name: {0: vec}}
+35 -20
View File
@@ -1,11 +1,13 @@
from __future__ import annotations
import logging
import math
import os
from collections import defaultdict
from collections.abc import Callable
from itertools import chain
from typing import Any
from typing import Any, Dict, List
from collections.abc import Callable, Iterable, Mapping, Sequence
import gymnasium as gym
import numpy as np
@@ -14,16 +16,12 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence
logger = logging.getLogger(__name__)
# ---- Helpers -----------------------------------------------------------------
def _parse_camera_names(camera_name: str | Sequence[str]) -> List[str]:
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
@@ -47,14 +45,14 @@ def _get_suite(name: str):
return suite
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> List[int]:
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
"""Validate/normalize task ids. If None → all tasks."""
if task_ids is None:
return list(range(total_tasks))
ids = sorted(set(int(t) for t in task_ids))
ids = sorted({int(t) for t in task_ids})
for t in ids:
if t < 0 or t >= total_tasks:
raise ValueError(f"task_id {t} out of range [0, {total_tasks-1}].")
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
return ids
@@ -64,16 +62,25 @@ def _make_env_fns(
suite_name: str,
task_id: int,
n_envs: int,
camera_names: List[str],
camera_names: list[str],
init_states: bool,
gym_kwargs: Mapping[str, Any],
LiberoEnv: type, # injected to avoid forward ref issues if needed
) -> List[Callable[[], "LiberoEnv"]]:
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
fns: List[Callable[[], "LiberoEnv"]] = []
fns: list[Callable[[], LiberoEnv]] = []
for i in range(n_envs):
def _mk(i=i, suite=suite, task_id=task_id, suite_name=suite_name, joined_cams=joined_cams, init_states=init_states, gym_kwargs=dict(gym_kwargs)):
def _mk(
i=i,
suite=suite,
task_id=task_id,
suite_name=suite_name,
joined_cams=joined_cams,
init_states=init_states,
gym_kwargs=dict(gym_kwargs),
):
return LiberoEnv(
task_suite=suite,
task_id=task_id,
@@ -83,11 +90,14 @@ def _make_env_fns(
episode_index=i,
**gym_kwargs,
)
fns.append(_mk)
return fns
# ---- Main API ----------------------------------------------------------------
def create_libero_envs(
task: str,
n_envs: int,
@@ -130,12 +140,15 @@ def create_libero_envs(
logger.info(
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
suite_names, n_envs, init_states, bool(multitask_eval)
suite_names,
n_envs,
init_states,
bool(multitask_eval),
)
if task_ids_filter is not None:
logger.info("Restricting to task_ids=%s", task_ids_filter)
out: Dict[str, Dict[int, Any]] = defaultdict(dict)
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
@@ -161,6 +174,8 @@ def create_libero_envs(
# return plain dicts for predictability
return {suite: dict(task_map) for suite, task_map in out.items()}
def quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
@@ -256,10 +271,10 @@ class LiberoEnv(gym.Env):
self._env = self._make_envs_task(task_suite, self.task_id)
TASK_SUITE_MAX_STEPS: dict[str, int] = {
"libero_spatial": 220, # longest training demo has 193 steps
"libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps
"libero_90": 400, # longest training demo has 373 steps
"libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps
"libero_90": 400, # longest training demo has 373 steps
}
default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
+6 -1
View File
@@ -80,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return return_observations
def preprocess_observation1(
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
) -> dict[str, Tensor]:
@@ -130,6 +131,8 @@ def preprocess_observation1(
if "task" in observations:
return_observations["task"] = observations["task"]
return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)
@@ -183,6 +186,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
observation["task"] = ["" for _ in range(num_envs)]
return observation
def _close_single_env(env: Any) -> None:
"""Try to close a single env object if it exposes .close()."""
try:
@@ -193,6 +197,7 @@ def _close_single_env(env: Any) -> None:
# Best-effort close: log but don't raise
LOG.debug("Exception while closing env %s: %s", env, exc)
def close_envs(env_or_collection: Any) -> None:
"""
Close a single env or any nested structure of envs.
@@ -225,4 +230,4 @@ def close_envs(env_or_collection: Any) -> None:
# Fallback: try to close if possible
if hasattr(env_or_collection, "close"):
_close_single_env(env_or_collection)
_close_single_env(env_or_collection)
+1 -1
View File
@@ -31,10 +31,10 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
def get_policy_class(name: str) -> PreTrainedPolicy:
+4 -2
View File
@@ -309,8 +309,8 @@ class NormalizePerRobotType(nn.Module):
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
]
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))],dim=0)
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))],dim=0)
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
if batch[key].ndim == 3:
mean = mean.unsqueeze(1)
std = std.unsqueeze(1)
@@ -332,6 +332,8 @@ class NormalizePerRobotType(nn.Module):
else:
raise ValueError(norm_mode)
return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
@@ -14,12 +14,12 @@
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@dataclass
@@ -52,7 +52,7 @@ class SMOLPI0Config(PreTrainedConfig):
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512) #(224, 224)
resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
@@ -107,14 +107,14 @@ class SMOLPI0Config(PreTrainedConfig):
add_image_special_tokens: bool = False
add_prompt_template: bool = False
prefix_prompt_template: str = f"<|im_start|>User: What action should the robot take to"
suffix_prompt_template: str = f"?\nAssistant:"
prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
suffix_prompt_template: str = "?\nAssistant:"
attention_mode: str = "self_attn"
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
past_obs_keys: str = f"image"
past_obs_keys: str = "image"
add_local_special_image_tokens: bool = False
@@ -122,7 +122,7 @@ class SMOLPI0Config(PreTrainedConfig):
state_to_prefix: bool = False
pad_language_to: str = "longest" # "max_length"
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1
num_vlm_layers: int = -1
@@ -144,9 +144,9 @@ class SMOLPI0Config(PreTrainedConfig):
shuffle_camera_positions: bool = False
vlm_img_size: int = -1
regression_loss: bool = False
def __post_init__(self):
super().__post_init__()
if self.vlm_img_size > 0:
@@ -198,7 +198,7 @@ class SMOLPI0Config(PreTrainedConfig):
)
@property
def observation_delta_indices(self) -> list:  # FIXME(mshukor): support spacing between observations
    """Return the relative indices of the observations to load.

    Built from ``n_obs_steps`` and ``n_obs_gap``: for a positive gap the
    result has ``n_obs_steps`` entries spaced ``n_obs_gap`` apart, in
    ascending order and ending at ``0`` (e.g. steps=3, gap=2 gives
    ``[-4, -2, 0]``).
    """
    # Same range as before; walk it backwards so the largest offset is
    # negated first and the list finishes on 0.
    offsets = range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)
    return [-offset for offset in reversed(offsets)]
@property
@@ -85,7 +85,7 @@ def flex_attention_forward(
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128 # limitation of flex attention
block_size = 128 # limitation of flex attention
q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
@@ -95,7 +95,7 @@ def flex_attention_forward(
pad_k = kv_len_rounded - kv_len
if pad_q > 0 or pad_k > 0:
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
else:
else:
padded_causal_mask = causal_mask
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
@@ -107,21 +107,21 @@ def flex_attention_forward(
KV_LEN=kv_len_rounded,
device=causal_mask.device,
)
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
# FIXME(mshukor): compile mask torch.compile(create_block_mask)
create_block_mask_compiled = torch.compile(create_block_mask)
block_mask = create_block_mask_compiled(
mask_mod=mask_mod_fn_padded,
B=b_mask,
H=None, #
H=None, #
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size,
device=causal_mask.device,
_compile=False,
)
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
# mask is applied inside the kernel, ideally more efficiently than score_mod.
+138 -73
View File
@@ -50,9 +50,10 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0")
"""
import math
from collections import deque
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
@@ -66,12 +67,11 @@ from lerobot.policies.normalize import (
Unnormalize,
UnnormalizePerRobotType,
)
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolpi0.smolvlm_with_expert import (
SmolVLMWithExpertModel
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.utils.utils import get_safe_dtype
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
@@ -86,10 +86,13 @@ IMAGES_ORDER = {
OBS_IMAGE_3: 2,
OBS_IMAGE_4: 3,
}
import random
from lerobot.policies.utils import (
populate_queues,
)
import random
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -171,7 +174,10 @@ def resize_with_pad(img, width, height, pad_value=-1):
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
@@ -179,6 +185,7 @@ def canonicalise(k: str) -> str:
"""
return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
@@ -209,6 +216,7 @@ def standardise_state_dict(
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
@@ -237,6 +245,8 @@ def load_smolvla(
)
return model
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
@@ -286,6 +296,7 @@ def aloha_gripper_to_angular(value):
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
@@ -307,6 +318,8 @@ def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
@@ -324,6 +337,7 @@ def aloha_gripper_from_angular_inv(value):
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class SMOLPI0Policy(PreTrainedPolicy):
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
@@ -374,7 +388,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(",")
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(
","
)
self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
self.reset()
@@ -389,31 +405,20 @@ class SMOLPI0Policy(PreTrainedPolicy):
for k in self.config.input_features:
if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]):
self._queues[k] = deque(maxlen=self.config.n_obs_steps)
def get_optim_params(self) -> dict:
    """Return the parameters to optimize, optionally split into two groups.

    When ``config.optimizer_lr_vlm`` is positive and differs from
    ``config.optimizer_lr``, returns a list of two parameter groups:

    1. trainable parameters whose name does not contain ``".vlm."``
       (no explicit ``"lr"`` key);
    2. trainable parameters whose name contains ``".vlm."``, with
       ``"lr"`` set to ``config.optimizer_lr_vlm``.

    Otherwise returns ``self.parameters()`` unchanged.

    NOTE(review): the ``-> dict`` annotation is kept as-is for interface
    stability, but the actual return value is a list of dicts or whatever
    ``self.parameters()`` yields.
    """
    vlm_lr = self.config.optimizer_lr_vlm
    if vlm_lr > 0 and vlm_lr != self.config.optimizer_lr:
        # Walk the parameters once; both groups keep named_parameters() order.
        trainable = [(n, p) for n, p in self.named_parameters() if p.requires_grad]
        return [
            {"params": [p for n, p in trainable if ".vlm." not in n]},
            {
                "params": [p for n, p in trainable if ".vlm." in n],
                "lr": vlm_lr,
            },
        ]
    return self.parameters()
def merge_peft_model_weights(self) -> None:
if "lora" in self.config.peft_method:
@@ -438,9 +443,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
@@ -469,6 +472,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
@@ -564,8 +568,12 @@ class SMOLPI0Policy(PreTrainedPolicy):
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")), reverse=self.config.reverse_images_order)
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
present_img_keys = sorted(
present_img_keys,
key=lambda k: IMAGES_ORDER.get(k, float("inf")),
reverse=self.config.reverse_images_order,
)
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
present_img_keys = random.sample(present_img_keys, len(present_img_keys))
if len(present_img_keys) == 0:
raise ValueError(
@@ -609,7 +617,10 @@ class SMOLPI0Policy(PreTrainedPolicy):
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
if self.config.add_prompt_template:
tasks = [f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" for task in tasks]
tasks = [
f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}"
for task in tasks
]
else:
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
@@ -618,7 +629,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
truncation=True, # FIXME(mshukor)
truncation=True, # FIXME(mshukor)
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
@@ -655,7 +666,11 @@ class SMOLPI0Policy(PreTrainedPolicy):
def prepare_state(self, batch):
    """Pick the state tensor from ``batch`` and pad it.

    Reads ``batch[OBS_STATE]``. If that tensor has more than two dimensions
    and ``self.include_past_states`` is false, only the last entry along
    dim 1 is kept (presumably the most recent observation step; confirm
    against the caller). The result is passed through ``pad_vector`` with
    ``self.config.max_state_dim``.
    """
    state = batch[OBS_STATE]
    # FIXME(mshukor): no state history for now
    if state.ndim > 2 and not self.include_past_states:
        state = state[:, -1, :]
    return pad_vector(state, self.config.max_state_dim)
@@ -666,7 +681,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
if self.config.relative_actions_mode == "first":
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1)
elif self.config.relative_actions_mode == "state":
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], "Relative action mode 'state' requires the action and state to have the same dimension."
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], (
"Relative action mode 'state' requires the action and state to have the same dimension."
)
if state.ndim == 2:
state = state.unsqueeze(1)
actions = actions - state
@@ -674,6 +691,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1)
return actions
def pad_tensor(tensor, max_len, pad_value=0):
"""
Efficiently pads a tensor along sequence dimension to match max_len.
@@ -687,13 +705,16 @@ def pad_tensor(tensor, max_len, pad_value=0):
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
"""
B, L = tensor.shape[:2]
# Create a padded tensor of max_len and copy the existing values
padded_tensor = torch.full((B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device)
padded_tensor = torch.full(
(B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, :L] = tensor # Efficient in-place copy
return padded_tensor
class VLAFlowMatching(nn.Module):
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
@@ -725,7 +746,8 @@ class VLAFlowMatching(nn.Module):
super().__init__()
self.config = config
self.vlm_with_expert = SmolVLMWithExpertModel(model_id=self.config.vlm_model_name,
self.vlm_with_expert = SmolVLMWithExpertModel(
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
@@ -736,50 +758,64 @@ class VLAFlowMatching(nn.Module):
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
self_attn_only_actions=self.config.self_attn_only_actions,
)
)
# self.paligemma_with_expert = self.configure_peft(paligemma_with_expert)
self.vlm_with_expert.configure_peft(config=self.config)
# Projections are float32
self.state_to_prefix = self.config.state_to_prefix
if self.state_to_prefix:
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
)
else:
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size)
self.action_time_mlp_out = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size)
self.action_time_mlp_in = nn.Linear(
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
)
self.action_time_mlp_out = nn.Linear(
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
)
self.set_requires_grad()
# SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ...
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
if "SmolVLM-Instruct" in self.config.vlm_model_name:
self.fake_image_token = 49152
self.global_image_token = [44, 13906, 29, 6266, 46]
self.global_image_start_token = torch.tensor([self.fake_image_token] + self.global_image_token, dtype=torch.long)
self.fake_image_token = 49152
self.global_image_token = [44, 13906, 29, 6266, 46]
self.global_image_start_token = torch.tensor(
[self.fake_image_token] + self.global_image_token, dtype=torch.long
)
else:
self.fake_image_token = 49189
self.global_image_token = 49152
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
self.fake_image_token = 49189
self.global_image_token = 49152
self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
else:
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long)
self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
self.add_image_special_tokens = self.config.add_image_special_tokens
self.add_local_special_image_tokens = self.config.add_local_special_image_tokens
self.local_image_tokens = [torch.tensor([self.fake_image_token, tok], dtype=torch.long) for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]] # assume 3 x 3 grid
self.local_image_tokens = [
torch.tensor([self.fake_image_token, tok], dtype=torch.long)
for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]
] # assume 3 x 3 grid
self.local_image_start_token = self.global_image_start_token
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(
","
)
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
self.causal_attention_on_history = self.config.causal_attention_on_history
# def configure_peft(self, model):
# # return model
@@ -845,14 +881,30 @@ class VLAFlowMatching(nn.Module):
img,
img_mask,
) in enumerate(zip(images, img_masks, strict=False)):
# FIXME(mshukor): add special tokens for the history each history_steps or not
# FIXME(mshukor): add special tokens for the history each history_steps or not
if self.add_image_special_tokens:
if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1:
local_token_idx = img_idx % num_images
image_start_token = self.vlm_with_expert.embed_language_tokens(self.local_image_tokens[local_token_idx].to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.local_image_tokens[local_token_idx].to(
device=self.vlm_with_expert.vlm.device
)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
else:
image_start_token = self.vlm_with_expert.embed_language_tokens(self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
image_start_mask = torch.ones_like(image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device)
image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_start_mask = torch.ones_like(
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
)
if self.causal_attention_on_history and img_idx % num_images == 0:
att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1)
else:
@@ -861,7 +913,7 @@ class VLAFlowMatching(nn.Module):
pad_masks.append(image_start_mask)
img_emb = self.vlm_with_expert.embed_image(img)
img_emb = img_emb #.to(dtype=self.vlm_with_expert.type)
img_emb = img_emb # .to(dtype=self.vlm_with_expert.type)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
@@ -880,16 +932,26 @@ class VLAFlowMatching(nn.Module):
att_masks += [0] * (num_img_embs)
if self.add_image_special_tokens:
if not self.add_local_special_image_tokens or (self.add_local_special_image_tokens and img_idx % num_images == num_images - 1):
image_end_token = self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1)
image_end_mask = torch.ones_like(image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device)
if not self.add_local_special_image_tokens or (
self.add_local_special_image_tokens and img_idx % num_images == num_images - 1
):
image_end_token = (
self.vlm_with_expert.embed_language_tokens(
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_end_mask = torch.ones_like(
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
)
embs.append(image_end_token)
pad_masks.append(image_end_mask)
att_masks += [0] * (image_end_mask.shape[1])
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
embs.append(lang_emb)
pad_masks.append(lang_masks)
@@ -900,7 +962,9 @@ class VLAFlowMatching(nn.Module):
if state is not None and self.state_to_prefix:
state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
state_emb = (
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
) # .to(dtype=self.vlm_with_expert.type)
embs.append(state_emb)
bsize = state_emb.shape[0]
dtype = state_emb.dtype
@@ -912,7 +976,7 @@ class VLAFlowMatching(nn.Module):
# Set attention masks so that image and language inputs do not attend to state or actions
# att_masks += [1] + [0]*(states_seq_len - 1)
att_masks += [1]*(states_seq_len)
att_masks += [1] * (states_seq_len)
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
@@ -937,7 +1001,9 @@ class VLAFlowMatching(nn.Module):
# Embed state
if not self.state_to_prefix:
state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type)
state_emb = (
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
) # .to(dtype=self.vlm_with_expert.type)
embs.append(state_emb)
bsize = state_emb.shape[0]
dtype = state_emb.dtype
@@ -948,8 +1014,7 @@ class VLAFlowMatching(nn.Module):
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1] + [0]*(states_seq_len - 1)
att_masks += [1] + [0] * (states_seq_len - 1)
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
@@ -1010,7 +1075,7 @@ class VLAFlowMatching(nn.Module):
images, img_masks, lang_tokens, lang_masks, state=state
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -1061,12 +1126,12 @@ class VLAFlowMatching(nn.Module):
x_t = torch.zeros_like(noise, dtype=torch.float32, device=device)
expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device)
x_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
else:
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
@@ -12,31 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from functools import partial
import copy
from functools import partial
from typing import List, Optional, Union
import torch
import torch.version
import torch.nn.functional as F # noqa: N812
import torch.version
from peft import LoraConfig, TaskType, get_peft_model
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
GemmaForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedModel,
SmolVLMForConditionalGeneration,
AutoModel,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from transformers.models.auto import CONFIG_MAPPING
from transformers import SmolVLMModel, SmolVLMConfig
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
def _round_up_to_multiple(x, multiple):
return (x + multiple - 1) // multiple * multiple
@@ -180,20 +177,31 @@ def apply_rope(x, positions, max_wavelength=10_000):
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
# )
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Compute a feed-forward width derived from ``hidden_dim``.

    The width is two thirds of ``hidden_dim`` (truncated), scaled by
    ``ffn_dim_multiplier`` (truncated again), then rounded up to the next
    multiple of ``multiple_of``.
    """
    two_thirds = int(2 * hidden_dim / 3)
    scaled = int(ffn_dim_multiplier * two_thirds)
    # Ceil-divide to count how many ``multiple_of`` chunks are needed.
    n_chunks = (scaled + multiple_of - 1) // multiple_of
    return multiple_of * n_chunks
class SmolVLMWithExpertModel(nn.Module):
# config_class = PaliGemmaWithExpertConfig
def __init__(self, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True, train_expert_only: bool = True, freeze_vision_encoder: bool = False,
attention_implementation: str = "eager", attention_mode: str = "self_attn", num_expert_layers: int = -1,
num_vlm_layers: int = -1, self_attn_every_n_layers: int = -1, expert_width_multiplier: float = 0.5, self_attn_only_actions: bool = False):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_implementation: str = "eager",
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
self_attn_only_actions: bool = False,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
@@ -227,15 +235,17 @@ class SmolVLMWithExpertModel(nn.Module):
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size*expert_width_multiplier) #hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size*expert_width_multiplier))
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0 :
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
self.self_attn_only_actions = self_attn_only_actions
@@ -245,10 +255,14 @@ class SmolVLMWithExpertModel(nn.Module):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
@@ -308,8 +322,10 @@ class SmolVLMWithExpertModel(nn.Module):
else:
self.vlm = self.vlm.merge_and_unload()
def get_vlm_model(self,):
if hasattr(self.vlm.model, "model"): # When using peft
def get_vlm_model(
self,
):
if hasattr(self.vlm.model, "model"): # When using peft
return self.vlm.model.model
else:
return self.vlm.model
@@ -326,22 +342,20 @@ class SmolVLMWithExpertModel(nn.Module):
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if self.num_vlm_layers != self.num_expert_layers and self.num_vlm_layers % self.num_expert_layers == 0:
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(
[
k in name
for k in frozen_layers
]
):
if any([k in name for k in frozen_layers]):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
@@ -410,19 +424,34 @@ class SmolVLMWithExpertModel(nn.Module):
# FIXME(mshukor): add special image tokens specific to smolvlm
# Get sequence from the vision encoder
image_hidden_states = self.get_vlm_model().vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
).last_hidden_state
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
    """Pass ``tokens`` through the input-embedding module of the VLM text model."""
    text_model = self.get_vlm_model().text_model
    embedding_layer = text_model.get_input_embeddings()
    return embedding_layer(tokens)
def forward_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values=None) -> list[torch.Tensor]:
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
@@ -430,7 +459,7 @@ class SmolVLMWithExpertModel(nn.Module):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
@@ -468,12 +497,16 @@ class SmolVLMWithExpertModel(nn.Module):
if inputs_embeds[1] is not None:
suffix_len = inputs_embeds[1].shape[1]
attention_mask_[:, -suffix_len:, :-suffix_len] = False
position_ids_[:, -suffix_len:] = _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
position_ids_[:, -suffix_len:] = (
_position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
)
else:
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_) # FIXME(mshukor): this assumes we have always the vlm features?
query_states = apply_rope(
query_states, position_ids_
) # FIXME(mshukor): this assumes we have always the vlm features?
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
@@ -491,9 +524,7 @@ class SmolVLMWithExpertModel(nn.Module):
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
@@ -504,20 +535,32 @@ class SmolVLMWithExpertModel(nn.Module):
return [att_output], past_key_values
def forward_cross_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values = None) -> list[torch.Tensor]:
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
@@ -558,7 +601,6 @@ class SmolVLMWithExpertModel(nn.Module):
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
@@ -570,21 +612,37 @@ class SmolVLMWithExpertModel(nn.Module):
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(*key_states.shape[:2], -1)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(*value_states.shape[:2], -1)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim)
expert_position_id = expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values # start from 0
expert_attention_mask = attention_mask[:, -inputs_embeds[1].shape[1]:, :expert_key_states.shape[1]:] # take into account kv
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
# expert_key_states = apply_rope(expert_key_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask, batch_size, head_dim, expert_query_states, expert_key_states, expert_value_states
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
@@ -592,8 +650,8 @@ class SmolVLMWithExpertModel(nn.Module):
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
@@ -606,15 +664,16 @@ class SmolVLMWithExpertModel(nn.Module):
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
inputs_embeds: list[torch.FloatTensor] = None,
use_cache: bool | None = None,
fill_kv_cache: bool | None = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
@@ -626,26 +685,58 @@ class SmolVLMWithExpertModel(nn.Module):
continue
batch_size = hidden_states.shape[0]
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
if self.attention_implementation == "flex":
if inputs_embeds[0] is not None and inputs_embeds[1] is not None and attention_mask.shape[-1] == attention_mask.shape[-2] and past_key_values is None: # Now only during training
if (
inputs_embeds[0] is not None
and inputs_embeds[1] is not None
and attention_mask.shape[-1] == attention_mask.shape[-2]
and past_key_values is None
): # Now only during training
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
padded_seq_len = _round_up_to_multiple(seq_len, 128) # FIXME(mshukor): more efficient to have a fixed seq len?
padded_seq_len = _round_up_to_multiple(
seq_len, 128
) # FIXME(mshukor): more efficient to have a fixed seq len?
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
pad = padded_seq_len - q_len
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
position_ids = F.pad(position_ids, (0, pad), value=0)
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if fill_kv_cache or "cross" not in self.attention_mode or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0):
att_outputs, past_key_values = self.forward_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values)
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
# query_states = []
# key_states = []
# value_states = []
@@ -703,7 +794,6 @@ class SmolVLMWithExpertModel(nn.Module):
# attention_mask, batch_size, head_dim, query_states, key_states, value_states
# )
# att_output = att_output.to(dtype=models[i].dtype)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
@@ -712,7 +802,9 @@ class SmolVLMWithExpertModel(nn.Module):
for i, hidden_states in enumerate(inputs_embeds):
# layer = models[i].layers[layer_idx]
layer = model_layers[i][layer_idx]
att_output = att_outputs[i] if i < len(att_outputs) else att_outputs[0] # in case of self_attn
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
@@ -759,7 +851,11 @@ class SmolVLMWithExpertModel(nn.Module):
if self.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.attention_implementation == "flex":
attention_interface = partial(flex_attention_forward, num_att_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads)
attention_interface = partial(
flex_attention_forward,
num_att_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
)
else:
attention_interface = self.eager_attention_forward
return attention_interface
@@ -807,7 +903,7 @@ class SmolVLMWithExpertModel(nn.Module):
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min #-2.3819763e38 # See gemma/modules.py
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
@@ -1027,6 +1027,7 @@ from lerobot.policies.utils import (
populate_queues,
)
from lerobot.utils.utils import get_safe_dtype
# OBS_STATE = 'state'
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
@@ -1891,4 +1892,4 @@ class VLAFlowMatching(nn.Module):
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t
return v_t
+1 -1
View File
@@ -1 +1 @@
c
c
@@ -547,4 +547,4 @@ class SmolVLMWithExpertModel(nn.Module):
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
return att_output
+60 -35
View File
@@ -45,17 +45,21 @@ Note that in both examples, the repo/folder should contain at least `config.json
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
"""
import concurrent
import concurrent.futures as cf
import json
import logging
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Tuple, TypedDict
from collections.abc import Iterator
import einops
import gymnasium as gym
@@ -68,7 +72,11 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
preprocess_observation,
)
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
@@ -79,9 +87,6 @@ from lerobot.utils.utils import (
init_logging,
inside_slurm,
)
from typing import TypedDict, Dict, List, Tuple, Iterator
from collections import defaultdict
import concurrent.futures as cf
def rollout(
@@ -485,8 +490,12 @@ def _compile_episode_data(
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
return data_dict
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
"""Recreate normalization layers with proper stats from the dataset."""
from lerobot.policies.normalize import Normalize, Unnormalize
@@ -518,7 +527,8 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
def load_smolvla(cfg, dataset_repo: str, policy):
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/")
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
return policy.to("cuda"), dataset
@@ -529,8 +539,8 @@ def eval_main(cfg: EvalPipelineConfig):
# Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True)
#login to hf
from huggingface_hub import login
# login to hf
# login()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -549,7 +559,7 @@ def eval_main(cfg: EvalPipelineConfig):
breakpoint()
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
# rename "image" -> "observation.image"
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
@@ -584,10 +594,11 @@ def eval_main(cfg: EvalPipelineConfig):
# ---- typed payload returned by one task eval ----
class TaskMetrics(TypedDict):
sum_rewards: List[float]
max_rewards: List[float]
successes: List[bool]
video_paths: List[str]
sum_rewards: list[float]
max_rewards: list[float]
successes: list[bool]
video_paths: list[str]
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
@@ -610,7 +621,7 @@ def eval_policy_all(
"""
global_start = time.time()
# inner: evaluate a single (suite, task)
# inner: evaluate a single (suite, task)
def eval_one(
task_group: str,
task_id: int,
@@ -650,27 +661,36 @@ def eval_policy_all(
video_paths=task_result.get("video_paths", []),
)
# result producer: sequential or threaded, same consumer
def iter_task_results() -> Iterator[Tuple[str, int, TaskMetrics]]:
# result producer: sequential or threaded, same consumer
def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]:
if max_parallel_tasks == 1:
for task_group, tasks in envs.items():
for task_id, vec in tasks.items():
yield task_group, task_id, eval_one(
task_group, task_id, vec,
policy=policy,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
yield (
task_group,
task_id,
eval_one(
task_group,
task_id,
vec,
policy=policy,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
),
)
else:
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
fut2key: Dict[cf.Future, Tuple[str, int]] = {}
fut2key: dict[cf.Future, tuple[str, int]] = {}
for task_group, tasks in envs.items():
for task_id, vec in tasks.items():
fut = executor.submit(
eval_one, task_group, task_id, vec,
eval_one,
task_group,
task_id,
vec,
policy=policy,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
@@ -683,9 +703,9 @@ def eval_policy_all(
task_group, task_id = fut2key[fut]
yield task_group, task_id, fut.result()
# single accumulator path on the main thread
group_acc: Dict[str, Dict[str, List]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
overall: Dict[str, List] = {k: [] for k in ACC_KEYS}
# single accumulator path on the main thread
group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
overall: dict[str, list] = {k: [] for k in ACC_KEYS}
for task_group, task_id, metrics in iter_task_results():
acc = group_acc[task_group]
@@ -694,7 +714,7 @@ def eval_policy_all(
overall[k].extend(metrics[k])
# build outputs
results: Dict[str, dict] = {}
results: dict[str, dict] = {}
for task_group, data in group_acc.items():
suite_rewards = data["sum_rewards"]
suite_max = data["max_rewards"]
@@ -720,9 +740,15 @@ def eval_policy_all(
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
results["overall"] = {
"aggregated": {
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) if overall["sum_rewards"] else float("nan"),
"avg_max_reward": float(np.nanmean(overall["max_rewards"])) if overall["max_rewards"] else float("nan"),
"pc_success": float(np.nanmean(overall["successes"]) * 100) if overall["successes"] else float("nan"),
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"]))
if overall["sum_rewards"]
else float("nan"),
"avg_max_reward": float(np.nanmean(overall["max_rewards"]))
if overall["max_rewards"]
else float("nan"),
"pc_success": float(np.nanmean(overall["successes"]) * 100)
if overall["successes"]
else float("nan"),
"eval_s": global_eval_s,
"eval_ep_s": global_eval_ep_s,
},
@@ -732,7 +758,6 @@ def eval_policy_all(
return results
if __name__ == "__main__":
init_logging()
eval_main()
+11 -6
View File
@@ -105,6 +105,7 @@ def update_policy(
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
# from lerobot.policies.normalize import Normalize, Unnormalize
@@ -132,6 +133,7 @@ def update_policy(
# print("✅ Normalization layers injected with dataset stats.")
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
@@ -271,9 +273,12 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), (torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()):
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy_all(
eval_env, # dict[suite][task_id] -> vec_env
eval_env, # dict[suite][task_id] -> vec_env
policy,
cfg.eval.n_episodes,
videos_dir=videos_dir,
@@ -295,15 +300,15 @@ def train(cfg: TrainPipelineConfig):
# meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
+6 -4
View File
@@ -104,6 +104,7 @@ def update_policy(
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
from lerobot.policies.normalize import Normalize, Unnormalize
@@ -115,15 +116,15 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
stats = {}
for key, stat_dict in dataset_meta.stats.items():
stats[key] = {
stat_type: torch.as_tensor(stat_array)
if isinstance(stat_array, np.ndarray)
else stat_array
stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
for stat_type, stat_array in stat_dict.items()
}
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets
@@ -131,6 +132,7 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
print("✅ Normalization layers injected with dataset stats.")
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
+5 -4
View File
@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
from termcolor import colored
from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
has_method,
init_logging,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
# Initialize accelerator
from accelerate.utils import DistributedDataParallelKwargs
# added by jade 2 lines
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
if accelerator.is_main_process:
logging.info("End of training")
accelerator.end_training() # added by jade
accelerator.end_training() # added by jade
if __name__ == "__main__":