[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-09-11 11:51:53 +00:00
parent 565c992589
commit a19d7fb6bf
17 changed files with 469 additions and 254 deletions
+1 -1
View File
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
# --num_trials_per_task 10 \ # --num_trials_per_task 10 \
# --video_out_path "data/libero/videos" \ # --video_out_path "data/libero/videos" \
# --device "cuda" \ # --device "cuda" \
# --seed 7 # --seed 7
+3 -5
View File
@@ -61,9 +61,9 @@ def make_env(
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
if "libero" in cfg.type: if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs from lerobot.envs.libero import create_libero_envs
return create_libero_envs( return create_libero_envs(
task=cfg.task, task=cfg.task,
n_envs=n_envs, n_envs=n_envs,
@@ -74,17 +74,16 @@ def make_env(
multitask_eval=cfg.multitask_eval, multitask_eval=cfg.multitask_eval,
) )
package_name = f"gym_{cfg.type}" package_name = f"gym_{cfg.type}"
try: try:
importlib.import_module(package_name) importlib.import_module(package_name)
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise ModuleNotFoundError( raise ModuleNotFoundError(
f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\"" f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"'
) from e ) from e
gym_handle = f"{package_name}/{cfg.task}" gym_handle = f"{package_name}/{cfg.task}"
def _make_one(): def _make_one():
return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {})) return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {}))
@@ -93,4 +92,3 @@ def make_env(
# normalize to {suite: {task_id: vec_env}} for consistency # normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha" suite_name = cfg.type # e.g., "pusht", "aloha"
return {suite_name: {0: vec}} return {suite_name: {0: vec}}
+35 -20
View File
@@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
import logging
import math import math
import os import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from itertools import chain from itertools import chain
from typing import Any from typing import Any, Dict, List
from collections.abc import Callable, Iterable, Mapping, Sequence
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@@ -14,16 +16,12 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv from libero.libero.envs import OffScreenRenderEnv
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---- Helpers ----------------------------------------------------------------- # ---- Helpers -----------------------------------------------------------------
def _parse_camera_names(camera_name: str | Sequence[str]) -> List[str]:
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings.""" """Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str): if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()] cams = [c.strip() for c in camera_name.split(",") if c.strip()]
@@ -47,14 +45,14 @@ def _get_suite(name: str):
return suite return suite
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> List[int]: def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
"""Validate/normalize task ids. If None → all tasks.""" """Validate/normalize task ids. If None → all tasks."""
if task_ids is None: if task_ids is None:
return list(range(total_tasks)) return list(range(total_tasks))
ids = sorted(set(int(t) for t in task_ids)) ids = sorted({int(t) for t in task_ids})
for t in ids: for t in ids:
if t < 0 or t >= total_tasks: if t < 0 or t >= total_tasks:
raise ValueError(f"task_id {t} out of range [0, {total_tasks-1}].") raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
return ids return ids
@@ -64,16 +62,25 @@ def _make_env_fns(
suite_name: str, suite_name: str,
task_id: int, task_id: int,
n_envs: int, n_envs: int,
camera_names: List[str], camera_names: list[str],
init_states: bool, init_states: bool,
gym_kwargs: Mapping[str, Any], gym_kwargs: Mapping[str, Any],
LiberoEnv: type, # injected to avoid forward ref issues if needed LiberoEnv: type, # injected to avoid forward ref issues if needed
) -> List[Callable[[], "LiberoEnv"]]: ) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id).""" """Build n_envs factory callables for a single (suite, task_id)."""
joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string
fns: List[Callable[[], "LiberoEnv"]] = [] fns: list[Callable[[], LiberoEnv]] = []
for i in range(n_envs): for i in range(n_envs):
def _mk(i=i, suite=suite, task_id=task_id, suite_name=suite_name, joined_cams=joined_cams, init_states=init_states, gym_kwargs=dict(gym_kwargs)):
def _mk(
i=i,
suite=suite,
task_id=task_id,
suite_name=suite_name,
joined_cams=joined_cams,
init_states=init_states,
gym_kwargs=dict(gym_kwargs),
):
return LiberoEnv( return LiberoEnv(
task_suite=suite, task_suite=suite,
task_id=task_id, task_id=task_id,
@@ -83,11 +90,14 @@ def _make_env_fns(
episode_index=i, episode_index=i,
**gym_kwargs, **gym_kwargs,
) )
fns.append(_mk) fns.append(_mk)
return fns return fns
# ---- Main API ---------------------------------------------------------------- # ---- Main API ----------------------------------------------------------------
def create_libero_envs( def create_libero_envs(
task: str, task: str,
n_envs: int, n_envs: int,
@@ -130,12 +140,15 @@ def create_libero_envs(
logger.info( logger.info(
"Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s", "Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s",
suite_names, n_envs, init_states, bool(multitask_eval) suite_names,
n_envs,
init_states,
bool(multitask_eval),
) )
if task_ids_filter is not None: if task_ids_filter is not None:
logger.info("Restricting to task_ids=%s", task_ids_filter) logger.info("Restricting to task_ids=%s", task_ids_filter)
out: Dict[str, Dict[int, Any]] = defaultdict(dict) out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names: for suite_name in suite_names:
suite = _get_suite(suite_name) suite = _get_suite(suite_name)
@@ -161,6 +174,8 @@ def create_libero_envs(
# return plain dicts for predictability # return plain dicts for predictability
return {suite: dict(task_map) for suite, task_map in out.items()} return {suite: dict(task_map) for suite, task_map in out.items()}
def quat2axisangle(quat): def quat2axisangle(quat):
""" """
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
@@ -256,10 +271,10 @@ class LiberoEnv(gym.Env):
self._env = self._make_envs_task(task_suite, self.task_id) self._env = self._make_envs_task(task_suite, self.task_id)
TASK_SUITE_MAX_STEPS: dict[str, int] = { TASK_SUITE_MAX_STEPS: dict[str, int] = {
"libero_spatial": 220, # longest training demo has 193 steps "libero_spatial": 220, # longest training demo has 193 steps
"libero_object": 280, # longest training demo has 254 steps "libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps "libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps "libero_10": 520, # longest training demo has 505 steps
"libero_90": 400, # longest training demo has 373 steps "libero_90": 400, # longest training demo has 373 steps
} }
default_steps = 500 default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
+6 -1
View File
@@ -80,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return return_observations return return_observations
def preprocess_observation1( def preprocess_observation1(
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
) -> dict[str, Tensor]: ) -> dict[str, Tensor]:
@@ -130,6 +131,8 @@ def preprocess_observation1(
if "task" in observations: if "task" in observations:
return_observations["task"] = observations["task"] return_observations["task"] = observations["task"]
return return_observations return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies) # (need to also refactor preprocess_observation and externalize normalization from policies)
@@ -183,6 +186,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic
observation["task"] = ["" for _ in range(num_envs)] observation["task"] = ["" for _ in range(num_envs)]
return observation return observation
def _close_single_env(env: Any) -> None: def _close_single_env(env: Any) -> None:
"""Try to close a single env object if it exposes .close().""" """Try to close a single env object if it exposes .close()."""
try: try:
@@ -193,6 +197,7 @@ def _close_single_env(env: Any) -> None:
# Best-effort close: log but don't raise # Best-effort close: log but don't raise
LOG.debug("Exception while closing env %s: %s", env, exc) LOG.debug("Exception while closing env %s: %s", env, exc)
def close_envs(env_or_collection: Any) -> None: def close_envs(env_or_collection: Any) -> None:
""" """
Close a single env or any nested structure of envs. Close a single env or any nested structure of envs.
@@ -225,4 +230,4 @@ def close_envs(env_or_collection: Any) -> None:
# Fallback: try to close if possible # Fallback: try to close if possible
if hasattr(env_or_collection, "close"): if hasattr(env_or_collection, "close"):
_close_single_env(env_or_collection) _close_single_env(env_or_collection)
+1 -1
View File
@@ -31,10 +31,10 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
def get_policy_class(name: str) -> PreTrainedPolicy: def get_policy_class(name: str) -> PreTrainedPolicy:
+4 -2
View File
@@ -309,8 +309,8 @@ class NormalizePerRobotType(nn.Module):
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
] ]
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))],dim=0) mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))],dim=0) std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
if batch[key].ndim == 3: if batch[key].ndim == 3:
mean = mean.unsqueeze(1) mean = mean.unsqueeze(1)
std = std.unsqueeze(1) std = std.unsqueeze(1)
@@ -332,6 +332,8 @@ class NormalizePerRobotType(nn.Module):
else: else:
raise ValueError(norm_mode) raise ValueError(norm_mode)
return batch return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization # TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes. # and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers( def _initialize_stats_buffers(
@@ -14,12 +14,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import ( from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@dataclass @dataclass
@@ -52,7 +52,7 @@ class SMOLPI0Config(PreTrainedConfig):
max_action_dim: int = 32 max_action_dim: int = 32
# Image preprocessing # Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512) #(224, 224) resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty # Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera. # left and right wrist cameras in addition to the top camera.
@@ -107,14 +107,14 @@ class SMOLPI0Config(PreTrainedConfig):
add_image_special_tokens: bool = False add_image_special_tokens: bool = False
add_prompt_template: bool = False add_prompt_template: bool = False
prefix_prompt_template: str = f"<|im_start|>User: What action should the robot take to" prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
suffix_prompt_template: str = f"?\nAssistant:" suffix_prompt_template: str = "?\nAssistant:"
attention_mode: str = "self_attn" attention_mode: str = "self_attn"
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
past_obs_keys: str = f"image" past_obs_keys: str = "image"
add_local_special_image_tokens: bool = False add_local_special_image_tokens: bool = False
@@ -122,7 +122,7 @@ class SMOLPI0Config(PreTrainedConfig):
state_to_prefix: bool = False state_to_prefix: bool = False
pad_language_to: str = "longest" # "max_length" pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 num_expert_layers: int = -1
num_vlm_layers: int = -1 num_vlm_layers: int = -1
@@ -144,9 +144,9 @@ class SMOLPI0Config(PreTrainedConfig):
shuffle_camera_positions: bool = False shuffle_camera_positions: bool = False
vlm_img_size: int = -1 vlm_img_size: int = -1
regression_loss: bool = False regression_loss: bool = False
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
if self.vlm_img_size > 0: if self.vlm_img_size > 0:
@@ -198,7 +198,7 @@ class SMOLPI0Config(PreTrainedConfig):
) )
@property @property
def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations
return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1] return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1]
@property @property
@@ -85,7 +85,7 @@ def flex_attention_forward(
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128 # limitation of flex attention block_size = 128 # limitation of flex attention
q_len_rounded = _round_up_to_multiple(q_len, block_size) q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size) kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
@@ -95,7 +95,7 @@ def flex_attention_forward(
pad_k = kv_len_rounded - kv_len pad_k = kv_len_rounded - kv_len
if pad_q > 0 or pad_k > 0: if pad_q > 0 or pad_k > 0:
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
else: else:
padded_causal_mask = causal_mask padded_causal_mask = causal_mask
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
@@ -107,21 +107,21 @@ def flex_attention_forward(
KV_LEN=kv_len_rounded, KV_LEN=kv_len_rounded,
device=causal_mask.device, device=causal_mask.device,
) )
mask_mod_fn_padded = precomputed_mask_factory(mask_4d) mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
# FIXME(mshukor): compile mask torch.compile(create_block_mask) # FIXME(mshukor): compile mask torch.compile(create_block_mask)
create_block_mask_compiled = torch.compile(create_block_mask) create_block_mask_compiled = torch.compile(create_block_mask)
block_mask = create_block_mask_compiled( block_mask = create_block_mask_compiled(
mask_mod=mask_mod_fn_padded, mask_mod=mask_mod_fn_padded,
B=b_mask, B=b_mask,
H=None, # H=None, #
Q_LEN=q_len_rounded, Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded, KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size, BLOCK_SIZE=block_size,
device=causal_mask.device, device=causal_mask.device,
_compile=False, _compile=False,
) )
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
# mask is applied inside the kernel, ideally more efficiently than score_mod. # mask is applied inside the kernel, ideally more efficiently than score_mod.
+138 -73
View File
@@ -50,9 +50,10 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0")
""" """
import math import math
from collections import deque
import os import os
import re import re
from collections import deque
import safetensors import safetensors
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
@@ -66,12 +67,11 @@ from lerobot.policies.normalize import (
Unnormalize, Unnormalize,
UnnormalizePerRobotType, UnnormalizePerRobotType,
) )
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolpi0.smolvlm_with_expert import (
SmolVLMWithExpertModel
)
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.utils.utils import get_safe_dtype from lerobot.utils.utils import get_safe_dtype
OBS_IMAGE = "observation.image" OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images" OBS_IMAGES = "observation.images"
ACTION = "action" ACTION = "action"
@@ -86,10 +86,13 @@ IMAGES_ORDER = {
OBS_IMAGE_3: 2, OBS_IMAGE_3: 2,
OBS_IMAGE_4: 3, OBS_IMAGE_4: 3,
} }
import random
from lerobot.policies.utils import ( from lerobot.policies.utils import (
populate_queues, populate_queues,
) )
import random
def create_sinusoidal_pos_embedding( def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor: ) -> Tensor:
@@ -171,7 +174,10 @@ def resize_with_pad(img, width, height, pad_value=-1):
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img return padded_img
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str: def canonicalise(k: str) -> str:
""" """
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
@@ -179,6 +185,7 @@ def canonicalise(k: str) -> str:
""" """
return _VARIANT_RE.sub(".buffer_", k) return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict( def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]: ) -> tuple[dict[str, torch.Tensor], list[str]]:
@@ -209,6 +216,7 @@ def standardise_state_dict(
out.update({k: checkpoint[k] for k in unmatched}) out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched return out, unmatched
def load_smolvla( def load_smolvla(
model: torch.nn.Module, model: torch.nn.Module,
filename: str | os.PathLike, filename: str | os.PathLike,
@@ -237,6 +245,8 @@ def load_smolvla(
) )
return model return model
def pad_vector(vector, new_dim): def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension) """Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension) or (batch_size x features_dimension)
@@ -286,6 +296,7 @@ def aloha_gripper_to_angular(value):
# The values 0.4 and 1.5 were measured on an actual Trossen robot. # The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5) return normalize(value, min_val=0.4, max_val=1.5)
def rename_checkpoint_keys(checkpoint: dict, rename_str: str): def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
""" """
Renames keys in a checkpoint dictionary based on the given rename string. Renames keys in a checkpoint dictionary based on the given rename string.
@@ -307,6 +318,8 @@ def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
k = k.replace(old_key, new_key) k = k.replace(old_key, new_key)
new_checkpoint[k] = v new_checkpoint[k] = v
return new_checkpoint return new_checkpoint
def aloha_gripper_from_angular(value): def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different. # Note that the units are still angular but the range is different.
@@ -324,6 +337,7 @@ def aloha_gripper_from_angular_inv(value):
value = unnormalize(value, min_val=-0.6213, max_val=1.4910) value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5) return normalize(value, min_val=0.4, max_val=1.5)
class SMOLPI0Policy(PreTrainedPolicy): class SMOLPI0Policy(PreTrainedPolicy):
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
@@ -374,7 +388,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config) self.model = VLAFlowMatching(config)
self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(",") self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(
","
)
self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",") self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",")
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
self.reset() self.reset()
@@ -389,31 +405,20 @@ class SMOLPI0Policy(PreTrainedPolicy):
for k in self.config.input_features: for k in self.config.input_features:
if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]): if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]):
self._queues[k] = deque(maxlen=self.config.n_obs_steps) self._queues[k] = deque(maxlen=self.config.n_obs_steps)
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr: if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr:
params = [ params = [
{"params": [p for n, p in self.named_parameters() if ".vlm." not in n and p.requires_grad]},
{ {
"params": [ "params": [p for n, p in self.named_parameters() if ".vlm." in n and p.requires_grad],
p
for n, p in self.named_parameters()
if not ".vlm." in n and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if ".vlm." in n and p.requires_grad
],
"lr": self.config.optimizer_lr_vlm, "lr": self.config.optimizer_lr_vlm,
}, },
] ]
return params return params
else: else:
return self.parameters() return self.parameters()
def merge_peft_model_weights(self) -> None: def merge_peft_model_weights(self) -> None:
if "lora" in self.config.peft_method: if "lora" in self.config.peft_method:
@@ -438,9 +443,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch) state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch) lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions( actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions # Unpad actions
original_action_dim = self.config.action_feature.shape[0] original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim] actions = actions[:, :, :original_action_dim]
@@ -469,6 +472,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
device=map_location, device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.", checkpoint_keys_mapping="model._orig_mod.//model.",
) )
@torch.no_grad @torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations. """Select a single action given environment observations.
@@ -564,8 +568,12 @@ class SMOLPI0Policy(PreTrainedPolicy):
present_img_keys = [key for key in self.config.image_features if key in batch] present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch]
present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")), reverse=self.config.reverse_images_order) present_img_keys = sorted(
if self.config.shuffle_camera_positions and ACTION in batch: # only during training present_img_keys,
key=lambda k: IMAGES_ORDER.get(k, float("inf")),
reverse=self.config.reverse_images_order,
)
if self.config.shuffle_camera_positions and ACTION in batch: # only during training
present_img_keys = random.sample(present_img_keys, len(present_img_keys)) present_img_keys = random.sample(present_img_keys, len(present_img_keys))
if len(present_img_keys) == 0: if len(present_img_keys) == 0:
raise ValueError( raise ValueError(
@@ -609,7 +617,10 @@ class SMOLPI0Policy(PreTrainedPolicy):
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
if self.config.add_prompt_template: if self.config.add_prompt_template:
tasks = [f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" for task in tasks] tasks = [
f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}"
for task in tasks
]
else: else:
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__( tokenized_prompt = self.language_tokenizer.__call__(
@@ -618,7 +629,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
padding_side="right", padding_side="right",
max_length=self.config.tokenizer_max_length, max_length=self.config.tokenizer_max_length,
return_tensors="pt", return_tensors="pt",
truncation=True, # FIXME(mshukor) truncation=True, # FIXME(mshukor)
) )
lang_tokens = tokenized_prompt["input_ids"].to(device=device) lang_tokens = tokenized_prompt["input_ids"].to(device=device)
@@ -655,7 +666,11 @@ class SMOLPI0Policy(PreTrainedPolicy):
def prepare_state(self, batch): def prepare_state(self, batch):
"""Pad state""" """Pad state"""
state = batch[OBS_STATE][:, -1, :] if (batch[OBS_STATE].ndim > 2 and not self.include_past_states) else batch[OBS_STATE] # FIXME(mshukor): no state history for now state = (
batch[OBS_STATE][:, -1, :]
if (batch[OBS_STATE].ndim > 2 and not self.include_past_states)
else batch[OBS_STATE]
) # FIXME(mshukor): no state history for now
state = pad_vector(state, self.config.max_state_dim) state = pad_vector(state, self.config.max_state_dim)
return state return state
@@ -666,7 +681,9 @@ class SMOLPI0Policy(PreTrainedPolicy):
if self.config.relative_actions_mode == "first": if self.config.relative_actions_mode == "first":
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1) actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1)
elif self.config.relative_actions_mode == "state": elif self.config.relative_actions_mode == "state":
assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], "Relative action mode 'state' requires the action and state to have the same dimension." assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], (
"Relative action mode 'state' requires the action and state to have the same dimension."
)
if state.ndim == 2: if state.ndim == 2:
state = state.unsqueeze(1) state = state.unsqueeze(1)
actions = actions - state actions = actions - state
@@ -674,6 +691,7 @@ class SMOLPI0Policy(PreTrainedPolicy):
actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1) actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1)
return actions return actions
def pad_tensor(tensor, max_len, pad_value=0): def pad_tensor(tensor, max_len, pad_value=0):
""" """
Efficiently pads a tensor along sequence dimension to match max_len. Efficiently pads a tensor along sequence dimension to match max_len.
@@ -687,13 +705,16 @@ def pad_tensor(tensor, max_len, pad_value=0):
torch.Tensor: Shape (B, max_len, ...) or (B, max_len). torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
""" """
B, L = tensor.shape[:2] B, L = tensor.shape[:2]
# Create a padded tensor of max_len and copy the existing values # Create a padded tensor of max_len and copy the existing values
padded_tensor = torch.full((B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device) padded_tensor = torch.full(
(B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, :L] = tensor # Efficient in-place copy padded_tensor[:, :L] = tensor # Efficient in-place copy
return padded_tensor return padded_tensor
class VLAFlowMatching(nn.Module): class VLAFlowMatching(nn.Module):
""" """
π0: A Vision-Language-Action Flow Model for General Robot Control π0: A Vision-Language-Action Flow Model for General Robot Control
@@ -725,7 +746,8 @@ class VLAFlowMatching(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.vlm_with_expert = SmolVLMWithExpertModel(model_id=self.config.vlm_model_name, self.vlm_with_expert = SmolVLMWithExpertModel(
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder, freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only, train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation, attention_implementation=self.config.attention_implementation,
@@ -736,50 +758,64 @@ class VLAFlowMatching(nn.Module):
self_attn_every_n_layers=self.config.self_attn_every_n_layers, self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier, expert_width_multiplier=self.config.expert_width_multiplier,
self_attn_only_actions=self.config.self_attn_only_actions, self_attn_only_actions=self.config.self_attn_only_actions,
) )
# self.paligemma_with_expert = self.configure_peft(paligemma_with_expert) # self.paligemma_with_expert = self.configure_peft(paligemma_with_expert)
self.vlm_with_expert.configure_peft(config=self.config) self.vlm_with_expert.configure_peft(config=self.config)
# Projections are float32 # Projections are float32
self.state_to_prefix = self.config.state_to_prefix self.state_to_prefix = self.config.state_to_prefix
if self.state_to_prefix: if self.state_to_prefix:
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size) self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
)
else: else:
self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size) self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size) self.action_time_mlp_in = nn.Linear(
self.action_time_mlp_out = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size) self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
)
self.action_time_mlp_out = nn.Linear(
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
)
self.set_requires_grad() self.set_requires_grad()
# SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ... # SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ...
if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]): if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]):
if "SmolVLM-Instruct" in self.config.vlm_model_name: if "SmolVLM-Instruct" in self.config.vlm_model_name:
self.fake_image_token = 49152 self.fake_image_token = 49152
self.global_image_token = [44, 13906, 29, 6266, 46] self.global_image_token = [44, 13906, 29, 6266, 46]
self.global_image_start_token = torch.tensor([self.fake_image_token] + self.global_image_token, dtype=torch.long) self.global_image_start_token = torch.tensor(
[self.fake_image_token] + self.global_image_token, dtype=torch.long
)
else: else:
self.fake_image_token = 49189 self.fake_image_token = 49189
self.global_image_token = 49152 self.global_image_token = 49152
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long) self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
else: else:
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long) self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
self.add_image_special_tokens = self.config.add_image_special_tokens self.add_image_special_tokens = self.config.add_image_special_tokens
self.add_local_special_image_tokens = self.config.add_local_special_image_tokens self.add_local_special_image_tokens = self.config.add_local_special_image_tokens
self.local_image_tokens = [torch.tensor([self.fake_image_token, tok], dtype=torch.long) for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]] # assume 3 x 3 grid self.local_image_tokens = [
torch.tensor([self.fake_image_token, tok], dtype=torch.long)
for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]
] # assume 3 x 3 grid
self.local_image_start_token = self.global_image_start_token self.local_image_start_token = self.global_image_start_token
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length self.prefix_length = self.config.prefix_length
self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",") self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(
","
)
self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1
self.causal_attention_on_history = self.config.causal_attention_on_history self.causal_attention_on_history = self.config.causal_attention_on_history
# def configure_peft(self, model): # def configure_peft(self, model):
# # return model # # return model
@@ -845,14 +881,30 @@ class VLAFlowMatching(nn.Module):
img, img,
img_mask, img_mask,
) in enumerate(zip(images, img_masks, strict=False)): ) in enumerate(zip(images, img_masks, strict=False)):
# FIXME(mshukor): add special tokens for the history each history_steps or not # FIXME(mshukor): add special tokens for the history each history_steps or not
if self.add_image_special_tokens: if self.add_image_special_tokens:
if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1: if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1:
local_token_idx = img_idx % num_images local_token_idx = img_idx % num_images
image_start_token = self.vlm_with_expert.embed_language_tokens(self.local_image_tokens[local_token_idx].to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.local_image_tokens[local_token_idx].to(
device=self.vlm_with_expert.vlm.device
)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
else: else:
image_start_token = self.vlm_with_expert.embed_language_tokens(self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) image_start_token = (
image_start_mask = torch.ones_like(image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device) self.vlm_with_expert.embed_language_tokens(
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_start_mask = torch.ones_like(
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
)
if self.causal_attention_on_history and img_idx % num_images == 0: if self.causal_attention_on_history and img_idx % num_images == 0:
att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1) att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1)
else: else:
@@ -861,7 +913,7 @@ class VLAFlowMatching(nn.Module):
pad_masks.append(image_start_mask) pad_masks.append(image_start_mask)
img_emb = self.vlm_with_expert.embed_image(img) img_emb = self.vlm_with_expert.embed_image(img)
img_emb = img_emb #.to(dtype=self.vlm_with_expert.type) img_emb = img_emb # .to(dtype=self.vlm_with_expert.type)
# Normalize image embeddings # Normalize image embeddings
img_emb_dim = img_emb.shape[-1] img_emb_dim = img_emb.shape[-1]
@@ -880,16 +932,26 @@ class VLAFlowMatching(nn.Module):
att_masks += [0] * (num_img_embs) att_masks += [0] * (num_img_embs)
if self.add_image_special_tokens: if self.add_image_special_tokens:
if not self.add_local_special_image_tokens or (self.add_local_special_image_tokens and img_idx % num_images == num_images - 1): if not self.add_local_special_image_tokens or (
image_end_token = self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) self.add_local_special_image_tokens and img_idx % num_images == num_images - 1
image_end_mask = torch.ones_like(image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device) ):
image_end_token = (
self.vlm_with_expert.embed_language_tokens(
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_end_mask = torch.ones_like(
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
)
embs.append(image_end_token) embs.append(image_end_token)
pad_masks.append(image_end_mask) pad_masks.append(image_end_mask)
att_masks += [0] * (image_end_mask.shape[1]) att_masks += [0] * (image_end_mask.shape[1])
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings # Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1] lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm? lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm?
embs.append(lang_emb) embs.append(lang_emb)
pad_masks.append(lang_masks) pad_masks.append(lang_masks)
@@ -900,7 +962,9 @@ class VLAFlowMatching(nn.Module):
if state is not None and self.state_to_prefix: if state is not None and self.state_to_prefix:
state_emb = self.state_proj(state) state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type) state_emb = (
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
) # .to(dtype=self.vlm_with_expert.type)
embs.append(state_emb) embs.append(state_emb)
bsize = state_emb.shape[0] bsize = state_emb.shape[0]
dtype = state_emb.dtype dtype = state_emb.dtype
@@ -912,7 +976,7 @@ class VLAFlowMatching(nn.Module):
# Set attention masks so that image and language inputs do not attend to state or actions # Set attention masks so that image and language inputs do not attend to state or actions
# att_masks += [1] + [0]*(states_seq_len - 1) # att_masks += [1] + [0]*(states_seq_len - 1)
att_masks += [1]*(states_seq_len) att_masks += [1] * (states_seq_len)
embs = torch.cat(embs, dim=1) embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1) pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
@@ -937,7 +1001,9 @@ class VLAFlowMatching(nn.Module):
# Embed state # Embed state
if not self.state_to_prefix: if not self.state_to_prefix:
state_emb = self.state_proj(state) state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type) state_emb = (
state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
) # .to(dtype=self.vlm_with_expert.type)
embs.append(state_emb) embs.append(state_emb)
bsize = state_emb.shape[0] bsize = state_emb.shape[0]
dtype = state_emb.dtype dtype = state_emb.dtype
@@ -948,8 +1014,7 @@ class VLAFlowMatching(nn.Module):
pad_masks.append(state_mask) pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions # Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1] + [0]*(states_seq_len - 1) att_masks += [1] + [0] * (states_seq_len - 1)
# Fuse timestep + action information using an MLP # Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions) action_emb = self.action_in_proj(noisy_actions)
@@ -1010,7 +1075,7 @@ class VLAFlowMatching(nn.Module):
images, img_masks, lang_tokens, lang_masks, state=state images, img_masks, lang_tokens, lang_masks, state=state
) )
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -1061,12 +1126,12 @@ class VLAFlowMatching(nn.Module):
x_t = torch.zeros_like(noise, dtype=torch.float32, device=device) x_t = torch.zeros_like(noise, dtype=torch.float32, device=device)
expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device) expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device)
x_t = self.denoise_step( x_t = self.denoise_step(
state, state,
prefix_pad_masks, prefix_pad_masks,
past_key_values, past_key_values,
x_t, x_t,
expanded_time, expanded_time,
) )
else: else:
dt = -1.0 / self.config.num_steps dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device) dt = torch.tensor(dt, dtype=torch.float32, device=device)
@@ -12,31 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Union
from functools import partial
import copy import copy
from functools import partial
from typing import List, Optional, Union
import torch import torch
import torch.version
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torch.version
from peft import LoraConfig, TaskType, get_peft_model from peft import LoraConfig, TaskType, get_peft_model
from pytest import Cache from pytest import Cache
from torch import nn from torch import nn
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
GemmaForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedModel,
SmolVLMForConditionalGeneration,
AutoModel, AutoModel,
AutoModelForImageTextToText,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoProcessor,
SmolVLMForConditionalGeneration,
) )
from transformers.models.auto import CONFIG_MAPPING
from transformers import SmolVLMModel, SmolVLMConfig
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
def _round_up_to_multiple(x, multiple): def _round_up_to_multiple(x, multiple):
return (x + multiple - 1) // multiple * multiple return (x + multiple - 1) // multiple * multiple
@@ -180,20 +177,31 @@ def apply_rope(x, positions, max_wavelength=10_000):
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." # f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
# ) # )
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3) hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim return hidden_dim
class SmolVLMWithExpertModel(nn.Module): class SmolVLMWithExpertModel(nn.Module):
# config_class = PaliGemmaWithExpertConfig # config_class = PaliGemmaWithExpertConfig
def __init__(self, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", def __init__(
load_vlm_weights: bool = True, train_expert_only: bool = True, freeze_vision_encoder: bool = False, self,
attention_implementation: str = "eager", attention_mode: str = "self_attn", num_expert_layers: int = -1, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
num_vlm_layers: int = -1, self_attn_every_n_layers: int = -1, expert_width_multiplier: float = 0.5, self_attn_only_actions: bool = False): load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_implementation: str = "eager",
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
self_attn_only_actions: bool = False,
):
super().__init__() super().__init__()
if load_vlm_weights: if load_vlm_weights:
print(f"Loading {model_id} weights ...") print(f"Loading {model_id} weights ...")
@@ -227,15 +235,17 @@ class SmolVLMWithExpertModel(nn.Module):
# Smaller lm expert # Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config) lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size*expert_width_multiplier) #hidden_size // 2 lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size*expert_width_multiplier)) lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0 : if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers lm_expert_config.num_hidden_layers = num_expert_layers
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2 # lm_expert_config.head_dim = lm_expert_config.head_dim * 2
self.lm_expert = AutoModel.from_config(lm_expert_config) self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers) self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers self.self_attn_every_n_layers = self_attn_every_n_layers
self.self_attn_only_actions = self_attn_only_actions self.self_attn_only_actions = self_attn_only_actions
@@ -245,10 +255,14 @@ class SmolVLMWithExpertModel(nn.Module):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
) )
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
) )
# Remove unused embed_tokens # Remove unused embed_tokens
self.lm_expert.embed_tokens = None self.lm_expert.embed_tokens = None
@@ -308,8 +322,10 @@ class SmolVLMWithExpertModel(nn.Module):
else: else:
self.vlm = self.vlm.merge_and_unload() self.vlm = self.vlm.merge_and_unload()
def get_vlm_model(self,): def get_vlm_model(
if hasattr(self.vlm.model, "model"): # When using peft self,
):
if hasattr(self.vlm.model, "model"): # When using peft
return self.vlm.model.model return self.vlm.model.model
else: else:
return self.vlm.model return self.vlm.model
@@ -326,22 +342,20 @@ class SmolVLMWithExpertModel(nn.Module):
else: else:
# To avoid unused params issue with distributed training # To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1] last_layers = [self.num_vlm_layers - 1]
if self.num_vlm_layers != self.num_expert_layers and self.num_vlm_layers % self.num_expert_layers == 0: if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2) last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [ frozen_layers = [
"lm_head", "lm_head",
"text_model.model.norm.weight", "text_model.model.norm.weight",
] ]
for layer in last_layers: for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.") frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters(): for name, params in self.vlm.named_parameters():
if any( if any([k in name for k in frozen_layers]):
[
k in name
for k in frozen_layers
]
):
params.requires_grad = False params.requires_grad = False
# To avoid unused params issue with distributed training # To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters(): for name, params in self.lm_expert.named_parameters():
@@ -410,19 +424,34 @@ class SmolVLMWithExpertModel(nn.Module):
# FIXME(mshukor): add special image tokens specific to smolvlm # FIXME(mshukor): add special image tokens specific to smolvlm
# Get sequence from the vision encoder # Get sequence from the vision encoder
image_hidden_states = self.get_vlm_model().vision_model( image_hidden_states = (
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), self.get_vlm_model()
patch_attention_mask=patch_attention_mask, .vision_model(
).last_hidden_state pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling # Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states) image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor): def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens) return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values=None) -> list[torch.Tensor]: def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = [] query_states = []
key_states = [] key_states = []
value_states = [] value_states = []
@@ -430,7 +459,7 @@ class SmolVLMWithExpertModel(nn.Module):
layer = model_layers[i][layer_idx] layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None: if hidden_states is None or layer is None:
continue continue
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer # hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states) hidden_states = layer.input_layernorm(hidden_states)
@@ -468,12 +497,16 @@ class SmolVLMWithExpertModel(nn.Module):
if inputs_embeds[1] is not None: if inputs_embeds[1] is not None:
suffix_len = inputs_embeds[1].shape[1] suffix_len = inputs_embeds[1].shape[1]
attention_mask_[:, -suffix_len:, :-suffix_len] = False attention_mask_[:, -suffix_len:, :-suffix_len] = False
position_ids_[:, -suffix_len:] = _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None] position_ids_[:, -suffix_len:] = (
_position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
)
else: else:
attention_mask_ = _attention_mask attention_mask_ = _attention_mask
position_ids_ = _position_ids position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_) # FIXME(mshukor): this assumes we have always the vlm features? query_states = apply_rope(
query_states, position_ids_
) # FIXME(mshukor): this assumes we have always the vlm features?
key_states = apply_rope(key_states, position_ids_) key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None: if use_cache and past_key_values is None:
@@ -491,9 +524,7 @@ class SmolVLMWithExpertModel(nn.Module):
# the max len, then we (for instance) double the cache size. This implementation already exists # the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap) # in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat( value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface() attention_interface = self.get_attention_interface()
@@ -504,20 +535,32 @@ class SmolVLMWithExpertModel(nn.Module):
return [att_output], past_key_values return [att_output], past_key_values
def forward_cross_attn_layer(
def forward_cross_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values = None) -> list[torch.Tensor]: self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface() attention_interface = self.get_attention_interface()
att_outputs = [] att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values: if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention # Prefix attention
seq_len = inputs_embeds[0].shape[1] seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx] layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0]) hidden_states = layer.input_layernorm(inputs_embeds[0])
@@ -558,7 +601,6 @@ class SmolVLMWithExpertModel(nn.Module):
key_states = past_key_values[layer_idx]["key_states"] key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"] value_states = past_key_values[layer_idx]["value_states"]
# Expert # Expert
expert_layer = model_layers[1][layer_idx] expert_layer = model_layers[1][layer_idx]
if expert_layer is not None: if expert_layer is not None:
@@ -570,21 +612,37 @@ class SmolVLMWithExpertModel(nn.Module):
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(*key_states.shape[:2], -1) _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim) # k_proj should have same dim as kv *value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(*value_states.shape[:2], -1) expert_position_id = (
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim) expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_position_id = expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values # start from 0 expert_attention_mask = attention_mask[
expert_attention_mask = attention_mask[:, -inputs_embeds[1].shape[1]:, :expert_key_states.shape[1]:] # take into account kv :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id) expert_query_states = apply_rope(expert_query_state, expert_position_id)
# expert_key_states = apply_rope(expert_key_state, expert_position_id) # expert_key_states = apply_rope(expert_key_state, expert_position_id)
att_output = attention_interface( att_output = attention_interface(
expert_attention_mask, batch_size, head_dim, expert_query_states, expert_key_states, expert_value_states expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
) )
att_outputs.append(att_output) att_outputs.append(att_output)
else: else:
@@ -592,8 +650,8 @@ class SmolVLMWithExpertModel(nn.Module):
# att_output = att_output.to(dtype=models[i].dtype) # att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient? def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
vlm_layers = [] vlm_layers = []
expert_layers = [] expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers multiple_of = self.num_vlm_layers // self.num_expert_layers
@@ -606,15 +664,16 @@ class SmolVLMWithExpertModel(nn.Module):
vlm_layers.append(models[0].layers[i]) vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer) expert_layers.append(expert_layer)
return [vlm_layers, expert_layers] return [vlm_layers, expert_layers]
# TODO: break down this huge forward into modules or functions # TODO: break down this huge forward into modules or functions
def forward( def forward(
self, self,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, past_key_values: list[torch.FloatTensor] | Cache | None = None,
inputs_embeds: List[torch.FloatTensor] = None, inputs_embeds: list[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: bool | None = None,
fill_kv_cache: Optional[bool] = None, fill_kv_cache: bool | None = None,
): ):
models = [self.get_vlm_model().text_model, self.lm_expert] models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models) model_layers = self.get_model_layers(models)
@@ -626,26 +685,58 @@ class SmolVLMWithExpertModel(nn.Module):
continue continue
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train # # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
if self.attention_implementation == "flex": if self.attention_implementation == "flex":
if inputs_embeds[0] is not None and inputs_embeds[1] is not None and attention_mask.shape[-1] == attention_mask.shape[-2] and past_key_values is None: # Now only during training if (
inputs_embeds[0] is not None
and inputs_embeds[1] is not None
and attention_mask.shape[-1] == attention_mask.shape[-2]
and past_key_values is None
): # Now only during training
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1] seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
padded_seq_len = _round_up_to_multiple(seq_len, 128) # FIXME(mshukor): more efficient to have a fixed seq len? padded_seq_len = _round_up_to_multiple(
seq_len, 128
) # FIXME(mshukor): more efficient to have a fixed seq len?
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
pad = padded_seq_len - q_len pad = padded_seq_len - q_len
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True) attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0) inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
position_ids = F.pad(position_ids, (0, pad), value=0) position_ids = F.pad(position_ids, (0, pad), value=0)
# RMSNorm # RMSNorm
num_layers = self.num_vlm_layers num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
if fill_kv_cache or "cross" not in self.attention_mode or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0): if (
att_outputs, past_key_values = self.forward_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values) fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else: else:
att_outputs, past_key_values = self.forward_cross_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values) att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
# query_states = [] # query_states = []
# key_states = [] # key_states = []
# value_states = [] # value_states = []
@@ -703,7 +794,6 @@ class SmolVLMWithExpertModel(nn.Module):
# attention_mask, batch_size, head_dim, query_states, key_states, value_states # attention_mask, batch_size, head_dim, query_states, key_states, value_states
# ) # )
# att_output = att_output.to(dtype=models[i].dtype) # att_output = att_output.to(dtype=models[i].dtype)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
@@ -712,7 +802,9 @@ class SmolVLMWithExpertModel(nn.Module):
for i, hidden_states in enumerate(inputs_embeds): for i, hidden_states in enumerate(inputs_embeds):
# layer = models[i].layers[layer_idx] # layer = models[i].layers[layer_idx]
layer = model_layers[i][layer_idx] layer = model_layers[i][layer_idx]
att_output = att_outputs[i] if i < len(att_outputs) else att_outputs[0] # in case of self_attn att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None: if hidden_states is not None:
if layer is None: if layer is None:
outputs_embeds.append(hidden_states) outputs_embeds.append(hidden_states)
@@ -759,7 +851,11 @@ class SmolVLMWithExpertModel(nn.Module):
if self.attention_implementation == "fa2": if self.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward attention_interface = self.flash_attention_forward
elif self.attention_implementation == "flex": elif self.attention_implementation == "flex":
attention_interface = partial(flex_attention_forward, num_att_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads) attention_interface = partial(
flex_attention_forward,
num_att_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
)
else: else:
attention_interface = self.eager_attention_forward attention_interface = self.eager_attention_forward
return attention_interface return attention_interface
@@ -807,7 +903,7 @@ class SmolVLMWithExpertModel(nn.Module):
att_weights *= head_dim**-0.5 att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32) att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min #-2.3819763e38 # See gemma/modules.py big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1) probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype) probs = probs.to(dtype=value_states.dtype)
@@ -1027,6 +1027,7 @@ from lerobot.policies.utils import (
populate_queues, populate_queues,
) )
from lerobot.utils.utils import get_safe_dtype from lerobot.utils.utils import get_safe_dtype
# OBS_STATE = 'state' # OBS_STATE = 'state'
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
@@ -1891,4 +1892,4 @@ class VLAFlowMatching(nn.Module):
suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32) suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out) v_t = self.action_out_proj(suffix_out)
return v_t return v_t
+1 -1
View File
@@ -1 +1 @@
c c
@@ -547,4 +547,4 @@ class SmolVLMWithExpertModel(nn.Module):
# we use -1 because sequence length can change # we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output return att_output
+60 -35
View File
@@ -45,17 +45,21 @@ Note that in both examples, the repo/folder should contain at least `config.json
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
""" """
import concurrent
import concurrent.futures as cf
import json import json
import logging import logging
import threading import threading
import time import time
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict from dataclasses import asdict
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Dict, List, Tuple, TypedDict
from collections.abc import Iterator
import einops import einops
import gymnasium as gym import gymnasium as gym
@@ -68,7 +72,11 @@ from tqdm import trange
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1 from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
preprocess_observation,
)
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
@@ -79,9 +87,6 @@ from lerobot.utils.utils import (
init_logging, init_logging,
inside_slurm, inside_slurm,
) )
from typing import TypedDict, Dict, List, Tuple, Iterator
from collections import defaultdict
import concurrent.futures as cf
def rollout( def rollout(
@@ -485,8 +490,12 @@ def _compile_episode_data(
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
return data_dict return data_dict
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
"""Recreate normalization layers with proper stats from the dataset.""" """Recreate normalization layers with proper stats from the dataset."""
from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.normalize import Normalize, Unnormalize
@@ -518,7 +527,8 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
def load_smolvla(
    cfg,
    dataset_repo: str,
    policy,
    *,
    root: str | Path = "/raid/jade/.cache/huggingface/datasets/",
    device: str | torch.device = "cuda",
):
    """Load a dataset and use its stats to rebuild the policy's normalization layers.

    Args:
        cfg: Accepted for call-site compatibility; not read by this function.
        dataset_repo: Repository id handed to ``LeRobotDataset``.
        policy: Policy whose normalization layers are re-created from the
            dataset metadata and which is then moved to ``device``.
        root: Local directory passed to ``LeRobotDataset`` as ``root``.
            NOTE(review): the default is a machine-specific path kept only for
            backward compatibility; callers should pass their own location.
        device: Target of ``policy.to(...)``. Defaults to ``"cuda"`` to
            preserve the previous hard-coded behaviour.

    Returns:
        A ``(policy, dataset)`` tuple, where ``policy`` is the result of
        ``policy.to(device)``.
    """
    # Imported lazily, as in the original, so importing this module does not
    # pull in the dataset package.
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(dataset_repo, root=root)
    # Only needed if the checkpoint was saved without normalization stats.
    _inject_normalization_stats(policy=policy, dataset_meta=dataset.meta)
    return policy.to(device), dataset
@@ -529,8 +539,8 @@ def eval_main(cfg: EvalPipelineConfig):
# Check device is available # Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True) device = get_safe_torch_device(cfg.policy.device, log=True)
#login to hf # login to hf
from huggingface_hub import login
# login() # login()
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@@ -549,7 +559,7 @@ def eval_main(cfg: EvalPipelineConfig):
breakpoint() breakpoint()
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy) # policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
# rename "image" -> "observation.image" # rename "image" -> "observation.image"
policy.eval() policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all( info = eval_policy_all(
@@ -584,10 +594,11 @@ def eval_main(cfg: EvalPipelineConfig):
# ---- typed payload returned by one task eval ---- # ---- typed payload returned by one task eval ----
class TaskMetrics(TypedDict):
    """Typed payload returned by the evaluation of a single task.

    The consumer extends its per-group and overall accumulators with each of
    these lists, key by key (see ``ACC_KEYS``).
    """

    # Summed rewards; presumably one entry per evaluated episode (the overall
    # episode count is taken from the length of this list).
    sum_rewards: list[float]
    # Maximum rewards, aligned with ``sum_rewards`` — TODO confirm alignment.
    max_rewards: list[float]
    # Success flags; their nan-mean times 100 is reported as ``pc_success``.
    successes: list[bool]
    # Paths of rendered videos; defaults to an empty list when none are produced.
    video_paths: list[str]
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths") ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
@@ -610,7 +621,7 @@ def eval_policy_all(
""" """
global_start = time.time() global_start = time.time()
# inner: evaluate a single (suite, task) # inner: evaluate a single (suite, task)
def eval_one( def eval_one(
task_group: str, task_group: str,
task_id: int, task_id: int,
@@ -650,27 +661,36 @@ def eval_policy_all(
video_paths=task_result.get("video_paths", []), video_paths=task_result.get("video_paths", []),
) )
# result producer: sequential or threaded, same consumer # result producer: sequential or threaded, same consumer
def iter_task_results() -> Iterator[Tuple[str, int, TaskMetrics]]: def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]:
if max_parallel_tasks == 1: if max_parallel_tasks == 1:
for task_group, tasks in envs.items(): for task_group, tasks in envs.items():
for task_id, vec in tasks.items(): for task_id, vec in tasks.items():
yield task_group, task_id, eval_one( yield (
task_group, task_id, vec, task_group,
policy=policy, task_id,
n_episodes=n_episodes, eval_one(
max_episodes_rendered=max_episodes_rendered, task_group,
videos_dir=videos_dir, task_id,
return_episode_data=return_episode_data, vec,
start_seed=start_seed, policy=policy,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
),
) )
else: else:
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
fut2key: Dict[cf.Future, Tuple[str, int]] = {} fut2key: dict[cf.Future, tuple[str, int]] = {}
for task_group, tasks in envs.items(): for task_group, tasks in envs.items():
for task_id, vec in tasks.items(): for task_id, vec in tasks.items():
fut = executor.submit( fut = executor.submit(
eval_one, task_group, task_id, vec, eval_one,
task_group,
task_id,
vec,
policy=policy, policy=policy,
n_episodes=n_episodes, n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered, max_episodes_rendered=max_episodes_rendered,
@@ -683,9 +703,9 @@ def eval_policy_all(
task_group, task_id = fut2key[fut] task_group, task_id = fut2key[fut]
yield task_group, task_id, fut.result() yield task_group, task_id, fut.result()
# single accumulator path on the main thread # single accumulator path on the main thread
group_acc: Dict[str, Dict[str, List]] = defaultdict(lambda: {k: [] for k in ACC_KEYS}) group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
overall: Dict[str, List] = {k: [] for k in ACC_KEYS} overall: dict[str, list] = {k: [] for k in ACC_KEYS}
for task_group, task_id, metrics in iter_task_results(): for task_group, task_id, metrics in iter_task_results():
acc = group_acc[task_group] acc = group_acc[task_group]
@@ -694,7 +714,7 @@ def eval_policy_all(
overall[k].extend(metrics[k]) overall[k].extend(metrics[k])
# build outputs # build outputs
results: Dict[str, dict] = {} results: dict[str, dict] = {}
for task_group, data in group_acc.items(): for task_group, data in group_acc.items():
suite_rewards = data["sum_rewards"] suite_rewards = data["sum_rewards"]
suite_max = data["max_rewards"] suite_max = data["max_rewards"]
@@ -720,9 +740,15 @@ def eval_policy_all(
global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"])) global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"]))
results["overall"] = { results["overall"] = {
"aggregated": { "aggregated": {
"avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) if overall["sum_rewards"] else float("nan"), "avg_sum_reward": float(np.nanmean(overall["sum_rewards"]))
"avg_max_reward": float(np.nanmean(overall["max_rewards"])) if overall["max_rewards"] else float("nan"), if overall["sum_rewards"]
"pc_success": float(np.nanmean(overall["successes"]) * 100) if overall["successes"] else float("nan"), else float("nan"),
"avg_max_reward": float(np.nanmean(overall["max_rewards"]))
if overall["max_rewards"]
else float("nan"),
"pc_success": float(np.nanmean(overall["successes"]) * 100)
if overall["successes"]
else float("nan"),
"eval_s": global_eval_s, "eval_s": global_eval_s,
"eval_ep_s": global_eval_ep_s, "eval_ep_s": global_eval_ep_s,
}, },
@@ -732,7 +758,6 @@ def eval_policy_all(
return results return results
if __name__ == "__main__": if __name__ == "__main__":
init_logging() init_logging()
eval_main() eval_main()
+11 -6
View File
@@ -105,6 +105,7 @@ def update_policy(
train_metrics.update_s = time.perf_counter() - start_time train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict return train_metrics, output_dict
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): # def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
# """Recreate normalization layers with dataset stats if missing (Adil's workaround).""" # """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
# from lerobot.policies.normalize import Normalize, Unnormalize # from lerobot.policies.normalize import Normalize, Unnormalize
@@ -132,6 +133,7 @@ def update_policy(
# print("✅ Normalization layers injected with dataset stats.") # print("✅ Normalization layers injected with dataset stats.")
@parser.wrap() @parser.wrap()
def train(cfg: TrainPipelineConfig): def train(cfg: TrainPipelineConfig):
cfg.validate() cfg.validate()
@@ -271,9 +273,12 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step: if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps) step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
with torch.no_grad(), (torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()): with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy_all( eval_info = eval_policy_all(
eval_env, # dict[suite][task_id] -> vec_env eval_env, # dict[suite][task_id] -> vec_env
policy, policy,
cfg.eval.n_episodes, cfg.eval.n_episodes,
videos_dir=videos_dir, videos_dir=videos_dir,
@@ -295,15 +300,15 @@ def train(cfg: TrainPipelineConfig):
# meters/tracker # meters/tracker
eval_metrics = { eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"), "pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"), "eval_s": AverageMeter("eval_s", ":.3f"),
} }
eval_tracker = MetricsTracker( eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
) )
eval_tracker.eval_s = aggregated.get("eval_s", 0.0) eval_tracker.eval_s = aggregated.get("eval_s", 0.0)
eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan")) eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
eval_tracker.pc_success = aggregated.get("pc_success", float("nan")) eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
if wandb_logger: if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval") wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
+6 -4
View File
@@ -104,6 +104,7 @@ def update_policy(
train_metrics.update_s = time.perf_counter() - start_time train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict return train_metrics, output_dict
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
"""Recreate normalization layers with dataset stats if missing (Adil's workaround).""" """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.normalize import Normalize, Unnormalize
@@ -115,15 +116,15 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
stats = {} stats = {}
for key, stat_dict in dataset_meta.stats.items(): for key, stat_dict in dataset_meta.stats.items():
stats[key] = { stats[key] = {
stat_type: torch.as_tensor(stat_array) stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
if isinstance(stat_array, np.ndarray)
else stat_array
for stat_type, stat_array in stat_dict.items() for stat_type, stat_array in stat_dict.items()
} }
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats) unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
policy.normalize_inputs = normalize_inputs policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets policy.normalize_targets = normalize_targets
@@ -131,6 +132,7 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
print("✅ Normalization layers injected with dataset stats.") print("✅ Normalization layers injected with dataset stats.")
@parser.wrap() @parser.wrap()
def train(cfg: TrainPipelineConfig): def train(cfg: TrainPipelineConfig):
cfg.validate() cfg.validate()
+5 -4
View File
@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
from termcolor import colored from termcolor import colored
from torch.optim import Optimizer from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.envs.factory import make_env from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.train_utils import ( from lerobot.utils.train_utils import (
get_step_checkpoint_dir, get_step_checkpoint_dir,
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
has_method, has_method,
init_logging, init_logging,
) )
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy( def update_policy(
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
# Initialize accelerator # Initialize accelerator
from accelerate.utils import DistributedDataParallelKwargs from accelerate.utils import DistributedDataParallelKwargs
# added by jade 2 lines # added by jade 2 lines
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs]) accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
if accelerator.is_main_process: if accelerator.is_main_process:
logging.info("End of training") logging.info("End of training")
accelerator.end_training() # added by jade accelerator.end_training() # added by jade
if __name__ == "__main__": if __name__ == "__main__":