diff --git a/examples/6_evaluate_libero.sh b/examples/6_evaluate_libero.sh index 46355dfa1..a7e235fcd 100644 --- a/examples/6_evaluate_libero.sh +++ b/examples/6_evaluate_libero.sh @@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \ # --num_trials_per_task 10 \ # --video_out_path "data/libero/videos" \ # --device "cuda" \ -# --seed 7 \ No newline at end of file +# --seed 7 diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index d38b2eed3..b74af7276 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -61,9 +61,9 @@ def make_env( env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv - if "libero" in cfg.type: from lerobot.envs.libero import create_libero_envs + return create_libero_envs( task=cfg.task, n_envs=n_envs, @@ -74,17 +74,16 @@ def make_env( multitask_eval=cfg.multitask_eval, ) - package_name = f"gym_{cfg.type}" try: importlib.import_module(package_name) except ModuleNotFoundError as e: raise ModuleNotFoundError( - f"{package_name} is not installed. Install with: pip install \"lerobot[{cfg.type}]\"" + f'{package_name} is not installed. Install with: pip install "lerobot[{cfg.type}]"' ) from e gym_handle = f"{package_name}/{cfg.task}" - + def _make_one(): return gym.make(gym_handle, disable_env_checker=True, **(cfg.gym_kwargs or {})) @@ -93,4 +92,3 @@ def make_env( # normalize to {suite: {task_id: vec_env}} for consistency suite_name = cfg.type # e.g., "pusht", "aloha" return {suite_name: {0: vec}} - diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index ff1574416..0672f583f 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -1,11 +1,13 @@ from __future__ import annotations +import logging import math import os from collections import defaultdict from collections.abc import Callable from itertools import chain -from typing import Any +from typing import Any, Dict, List +from collections.abc import Callable, Iterable, Mapping, Sequence import gymnasium as gym import numpy as np @@ -14,16 +16,12 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv -import logging -from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence - - logger = logging.getLogger(__name__) # ---- Helpers ----------------------------------------------------------------- -def _parse_camera_names(camera_name: str | Sequence[str]) -> List[str]: + +def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" if isinstance(camera_name, str): cams = [c.strip() for c in camera_name.split(",") if c.strip()] @@ -47,14 +45,14 @@ def _get_suite(name: str): return suite -def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> List[int]: +def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]: """Validate/normalize task ids. If None → all tasks.""" if task_ids is None: return list(range(total_tasks)) - ids = sorted(set(int(t) for t in task_ids)) + ids = sorted({int(t) for t in task_ids}) for t in ids: if t < 0 or t >= total_tasks: - raise ValueError(f"task_id {t} out of range [0, {total_tasks-1}].") + raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].") return ids @@ -64,16 +62,25 @@ def _make_env_fns( suite_name: str, task_id: int, n_envs: int, - camera_names: List[str], + camera_names: list[str], init_states: bool, gym_kwargs: Mapping[str, Any], LiberoEnv: type, # injected to avoid forward ref issues if needed -) -> List[Callable[[], "LiberoEnv"]]: +) -> list[Callable[[], LiberoEnv]]: """Build n_envs factory callables for a single (suite, task_id).""" joined_cams = ",".join(camera_names) # keep backward-compat: downstream expects a string - fns: List[Callable[[], "LiberoEnv"]] = [] + fns: list[Callable[[], LiberoEnv]] = [] for i in range(n_envs): - def _mk(i=i, suite=suite, task_id=task_id, suite_name=suite_name, joined_cams=joined_cams, init_states=init_states, gym_kwargs=dict(gym_kwargs)): + + def _mk( + i=i, + suite=suite, + task_id=task_id, + suite_name=suite_name, + joined_cams=joined_cams, + init_states=init_states, + gym_kwargs=dict(gym_kwargs), + ): return LiberoEnv( task_suite=suite, task_id=task_id, @@ -83,11 +90,14 @@ def _make_env_fns( episode_index=i, **gym_kwargs, ) + fns.append(_mk) return fns + # ---- Main API ---------------------------------------------------------------- + def create_libero_envs( task: str, n_envs: int, @@ -130,12 +140,15 @@ def create_libero_envs( logger.info( "Creating LIBERO envs | suites=%s | n_envs(per task)=%d | init_states=%s | multitask_eval=%s", - suite_names, n_envs, init_states, bool(multitask_eval) + suite_names, + n_envs, + init_states, + bool(multitask_eval), ) if task_ids_filter is not None: logger.info("Restricting to task_ids=%s", task_ids_filter) - out: Dict[str, Dict[int, Any]] = defaultdict(dict) + out: dict[str, dict[int, Any]] = defaultdict(dict) for suite_name in suite_names: suite = _get_suite(suite_name) @@ -161,6 +174,8 @@ def create_libero_envs( # return plain dicts for predictability return {suite: dict(task_map) for suite, task_map in out.items()} + + def quat2axisangle(quat): """ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 @@ -256,10 +271,10 @@ class LiberoEnv(gym.Env): self._env = self._make_envs_task(task_suite, self.task_id) TASK_SUITE_MAX_STEPS: dict[str, int] = { "libero_spatial": 220, # longest training demo has 193 steps - "libero_object": 280, # longest training demo has 254 steps - "libero_goal": 300, # longest training demo has 270 steps - "libero_10": 520, # longest training demo has 505 steps - "libero_90": 400, # longest training demo has 373 steps + "libero_object": 280, # longest training demo has 254 steps + "libero_goal": 300, # longest training demo has 270 steps + "libero_10": 520, # longest training demo has 505 steps + "libero_90": 400, # longest training demo has 373 steps } default_steps = 500 self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 9490f670e..d65f7f29f 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -80,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return return_observations + def preprocess_observation1( observations: dict[str, np.ndarray], cfg: dict[str, Any] = None ) -> dict[str, Tensor]: @@ -130,6 +131,8 @@ def preprocess_observation1( if "task" in observations: return_observations["task"] = observations["task"] return return_observations + + def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # (need to also refactor preprocess_observation and externalize normalization from policies) @@ -183,6 +186,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic observation["task"] = ["" for _ in range(num_envs)] return observation + def _close_single_env(env: Any) -> None: """Try to close a single env object if it exposes .close().""" try: @@ -193,6 +197,7 @@ def _close_single_env(env: Any) -> None: # Best-effort close: log but don't raise LOG.debug("Exception while closing env %s: %s", env, exc) + def close_envs(env_or_collection: Any) -> None: """ Close a single env or any nested structure of envs. @@ -225,4 +230,4 @@ def close_envs(env_or_collection: Any) -> None: # Fallback: try to close if possible if hasattr(env_or_collection, "close"): - _close_single_env(env_or_collection) \ No newline at end of file + _close_single_env(env_or_collection) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 03fa44a2a..cc1b0480d 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,10 +31,10 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config def get_policy_class(name: str) -> PreTrainedPolicy: diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py index 043265b1b..646c330cb 100644 --- a/src/lerobot/policies/normalize.py +++ b/src/lerobot/policies/normalize.py @@ -309,8 +309,8 @@ class NormalizePerRobotType(nn.Module): getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types ] if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))],dim=0) - std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))],dim=0) + mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0) + std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0) if batch[key].ndim == 3: mean = mean.unsqueeze(1) std = std.unsqueeze(1) @@ -332,6 +332,8 @@ class NormalizePerRobotType(nn.Module): else: raise ValueError(norm_mode) return batch + + # TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization # and remove the `Normalize` and `Unnormalize` classes. def _initialize_stats_buffers( diff --git a/src/lerobot/policies/smolpi0/configuration_smolpi0.py b/src/lerobot/policies/smolpi0/configuration_smolpi0.py index c3605cd82..e39d17f15 100644 --- a/src/lerobot/policies/smolpi0/configuration_smolpi0.py +++ b/src/lerobot/policies/smolpi0/configuration_smolpi0.py @@ -14,12 +14,12 @@ from dataclasses import dataclass, field +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @dataclass @@ -52,7 +52,7 @@ class SMOLPI0Config(PreTrainedConfig): max_action_dim: int = 32 # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (512, 512) #(224, 224) + resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224) # Add empty images. Used by pi0_aloha_sim which adds the empty # left and right wrist cameras in addition to the top camera. @@ -107,14 +107,14 @@ class SMOLPI0Config(PreTrainedConfig): add_image_special_tokens: bool = False add_prompt_template: bool = False - prefix_prompt_template: str = f"<|im_start|>User: What action should the robot take to" - suffix_prompt_template: str = f"?\nAssistant:" + prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to" + suffix_prompt_template: str = "?\nAssistant:" attention_mode: str = "self_attn" - prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length + prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length - past_obs_keys: str = f"image" + past_obs_keys: str = "image" add_local_special_image_tokens: bool = False @@ -122,7 +122,7 @@ class SMOLPI0Config(PreTrainedConfig): state_to_prefix: bool = False - pad_language_to: str = "longest" # "max_length" + pad_language_to: str = "longest" # "max_length" num_expert_layers: int = -1 num_vlm_layers: int = -1 @@ -144,9 +144,9 @@ class SMOLPI0Config(PreTrainedConfig): shuffle_camera_positions: bool = False vlm_img_size: int = -1 - + regression_loss: bool = False - + def __post_init__(self): super().__post_init__() if self.vlm_img_size > 0: @@ -198,7 +198,7 @@ class SMOLPI0Config(PreTrainedConfig): ) @property - def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations + def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1] @property diff --git a/src/lerobot/policies/smolpi0/flex_attention.py b/src/lerobot/policies/smolpi0/flex_attention.py index 13950f743..732920af2 100644 --- a/src/lerobot/policies/smolpi0/flex_attention.py +++ b/src/lerobot/policies/smolpi0/flex_attention.py @@ -85,7 +85,7 @@ def flex_attention_forward( b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask - block_size = 128 # limitation of flex attention + block_size = 128 # limitation of flex attention q_len_rounded = _round_up_to_multiple(q_len, block_size) kv_len_rounded = _round_up_to_multiple(kv_len, block_size) @@ -95,7 +95,7 @@ def flex_attention_forward( pad_k = kv_len_rounded - kv_len if pad_q > 0 or pad_k > 0: padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) - else: + else: padded_causal_mask = causal_mask mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) @@ -107,21 +107,21 @@ def flex_attention_forward( KV_LEN=kv_len_rounded, device=causal_mask.device, ) - + mask_mod_fn_padded = precomputed_mask_factory(mask_4d) # FIXME(mshukor): compile mask torch.compile(create_block_mask) create_block_mask_compiled = torch.compile(create_block_mask) block_mask = create_block_mask_compiled( mask_mod=mask_mod_fn_padded, B=b_mask, - H=None, # + H=None, # Q_LEN=q_len_rounded, KV_LEN=kv_len_rounded, BLOCK_SIZE=block_size, device=causal_mask.device, _compile=False, ) - padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states + padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states # mask is applied inside the kernel, ideally more efficiently than score_mod. diff --git a/src/lerobot/policies/smolpi0/modeling_smolpi0.py b/src/lerobot/policies/smolpi0/modeling_smolpi0.py index 9a128f7b6..765a5901a 100644 --- a/src/lerobot/policies/smolpi0/modeling_smolpi0.py +++ b/src/lerobot/policies/smolpi0/modeling_smolpi0.py @@ -50,9 +50,10 @@ policy = Pi0Policy.from_pretrained("lerobot/pi0") """ import math -from collections import deque import os import re +from collections import deque + import safetensors import torch import torch.nn.functional as F # noqa: N812 @@ -66,12 +67,11 @@ from lerobot.policies.normalize import ( Unnormalize, UnnormalizePerRobotType, ) -from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config -from lerobot.policies.smolpi0.smolvlm_with_expert import ( - SmolVLMWithExpertModel -) from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config +from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel from lerobot.utils.utils import get_safe_dtype + OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" ACTION = "action" @@ -86,10 +86,13 @@ IMAGES_ORDER = { OBS_IMAGE_3: 2, OBS_IMAGE_4: 3, } +import random + from lerobot.policies.utils import ( populate_queues, ) -import random + + def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: @@ -171,7 +174,10 @@ def resize_with_pad(img, width, height, pad_value=-1): padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) return padded_img + _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") + + def canonicalise(k: str) -> str: """ Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a @@ -179,6 +185,7 @@ def canonicalise(k: str) -> str: """ return _VARIANT_RE.sub(".buffer_", k) + def standardise_state_dict( checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True ) -> tuple[dict[str, torch.Tensor], list[str]]: @@ -209,6 +216,7 @@ def standardise_state_dict( out.update({k: checkpoint[k] for k in unmatched}) return out, unmatched + def load_smolvla( model: torch.nn.Module, filename: str | os.PathLike, @@ -237,6 +245,8 @@ def load_smolvla( ) return model + + def pad_vector(vector, new_dim): """Can be (batch_size x sequence_length x features_dimension) or (batch_size x features_dimension) @@ -286,6 +296,7 @@ def aloha_gripper_to_angular(value): # The values 0.4 and 1.5 were measured on an actual Trossen robot. return normalize(value, min_val=0.4, max_val=1.5) + def rename_checkpoint_keys(checkpoint: dict, rename_str: str): """ Renames keys in a checkpoint dictionary based on the given rename string. @@ -307,6 +318,8 @@ def rename_checkpoint_keys(checkpoint: dict, rename_str: str): k = k.replace(old_key, new_key) new_checkpoint[k] = v return new_checkpoint + + def aloha_gripper_from_angular(value): # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. # Note that the units are still angular but the range is different. @@ -324,6 +337,7 @@ def aloha_gripper_from_angular_inv(value): value = unnormalize(value, min_val=-0.6213, max_val=1.4910) return normalize(value, min_val=0.4, max_val=1.5) + class SMOLPI0Policy(PreTrainedPolicy): """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" @@ -374,7 +388,9 @@ class SMOLPI0Policy(PreTrainedPolicy): self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer self.model = VLAFlowMatching(config) - self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split(",") + self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split( + "," + ) self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",") self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 self.reset() @@ -389,31 +405,20 @@ class SMOLPI0Policy(PreTrainedPolicy): for k in self.config.input_features: if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]): self._queues[k] = deque(maxlen=self.config.n_obs_steps) - + def get_optim_params(self) -> dict: if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr: params = [ + {"params": [p for n, p in self.named_parameters() if ".vlm." not in n and p.requires_grad]}, { - "params": [ - p - for n, p in self.named_parameters() - if not ".vlm." in n and p.requires_grad - ] - }, - { - "params": [ - p - for n, p in self.named_parameters() - if ".vlm." in n and p.requires_grad - ], + "params": [p for n, p in self.named_parameters() if ".vlm." in n and p.requires_grad], "lr": self.config.optimizer_lr_vlm, }, ] return params - + else: return self.parameters() - def merge_peft_model_weights(self) -> None: if "lora" in self.config.peft_method: @@ -438,9 +443,7 @@ class SMOLPI0Policy(PreTrainedPolicy): state = self.prepare_state(batch) lang_tokens, lang_masks = self.prepare_language(batch) - actions = self.model.sample_actions( - images, img_masks, lang_tokens, lang_masks, state, noise=noise - ) + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) # Unpad actions original_action_dim = self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] @@ -469,6 +472,7 @@ class SMOLPI0Policy(PreTrainedPolicy): device=map_location, checkpoint_keys_mapping="model._orig_mod.//model.", ) + @torch.no_grad def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. @@ -564,8 +568,12 @@ class SMOLPI0Policy(PreTrainedPolicy): present_img_keys = [key for key in self.config.image_features if key in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch] - present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")), reverse=self.config.reverse_images_order) - if self.config.shuffle_camera_positions and ACTION in batch: # only during training + present_img_keys = sorted( + present_img_keys, + key=lambda k: IMAGES_ORDER.get(k, float("inf")), + reverse=self.config.reverse_images_order, + ) + if self.config.shuffle_camera_positions and ACTION in batch: # only during training present_img_keys = random.sample(present_img_keys, len(present_img_keys)) if len(present_img_keys) == 0: raise ValueError( @@ -609,7 +617,10 @@ class SMOLPI0Policy(PreTrainedPolicy): tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] if self.config.add_prompt_template: - tasks = [f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" for task in tasks] + tasks = [ + f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" + for task in tasks + ] else: tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] tokenized_prompt = self.language_tokenizer.__call__( @@ -618,7 +629,7 @@ class SMOLPI0Policy(PreTrainedPolicy): padding_side="right", max_length=self.config.tokenizer_max_length, return_tensors="pt", - truncation=True, # FIXME(mshukor) + truncation=True, # FIXME(mshukor) ) lang_tokens = tokenized_prompt["input_ids"].to(device=device) @@ -655,7 +666,11 @@ class SMOLPI0Policy(PreTrainedPolicy): def prepare_state(self, batch): """Pad state""" - state = batch[OBS_STATE][:, -1, :] if (batch[OBS_STATE].ndim > 2 and not self.include_past_states) else batch[OBS_STATE] # FIXME(mshukor): no state history for now + state = ( + batch[OBS_STATE][:, -1, :] + if (batch[OBS_STATE].ndim > 2 and not self.include_past_states) + else batch[OBS_STATE] + ) # FIXME(mshukor): no state history for now state = pad_vector(state, self.config.max_state_dim) return state @@ -666,7 +681,9 @@ class SMOLPI0Policy(PreTrainedPolicy): if self.config.relative_actions_mode == "first": actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1) elif self.config.relative_actions_mode == "state": - assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], "Relative action mode 'state' requires the action and state to have the same dimension." + assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], ( + "Relative action mode 'state' requires the action and state to have the same dimension." + ) if state.ndim == 2: state = state.unsqueeze(1) actions = actions - state @@ -674,6 +691,7 @@ class SMOLPI0Policy(PreTrainedPolicy): actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1) return actions + def pad_tensor(tensor, max_len, pad_value=0): """ Efficiently pads a tensor along sequence dimension to match max_len. @@ -687,13 +705,16 @@ def pad_tensor(tensor, max_len, pad_value=0): torch.Tensor: Shape (B, max_len, ...) or (B, max_len). """ B, L = tensor.shape[:2] - + # Create a padded tensor of max_len and copy the existing values - padded_tensor = torch.full((B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device) + padded_tensor = torch.full( + (B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device + ) padded_tensor[:, :L] = tensor # Efficient in-place copy return padded_tensor + class VLAFlowMatching(nn.Module): """ π0: A Vision-Language-Action Flow Model for General Robot Control @@ -725,7 +746,8 @@ class VLAFlowMatching(nn.Module): super().__init__() self.config = config - self.vlm_with_expert = SmolVLMWithExpertModel(model_id=self.config.vlm_model_name, + self.vlm_with_expert = SmolVLMWithExpertModel( + model_id=self.config.vlm_model_name, freeze_vision_encoder=self.config.freeze_vision_encoder, train_expert_only=self.config.train_expert_only, attention_implementation=self.config.attention_implementation, @@ -736,50 +758,64 @@ class VLAFlowMatching(nn.Module): self_attn_every_n_layers=self.config.self_attn_every_n_layers, expert_width_multiplier=self.config.expert_width_multiplier, self_attn_only_actions=self.config.self_attn_only_actions, - ) + ) # self.paligemma_with_expert = self.configure_peft(paligemma_with_expert) self.vlm_with_expert.configure_peft(config=self.config) # Projections are float32 self.state_to_prefix = self.config.state_to_prefix if self.state_to_prefix: - self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size) + self.state_proj = nn.Linear( + self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size + ) else: self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size) self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) - self.action_time_mlp_in = nn.Linear(self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size) - self.action_time_mlp_out = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size) + self.action_time_mlp_in = nn.Linear( + self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size + ) + self.action_time_mlp_out = nn.Linear( + self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size + ) self.set_requires_grad() # SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ... - if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]): + if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]): if "SmolVLM-Instruct" in self.config.vlm_model_name: - self.fake_image_token = 49152 - self.global_image_token = [44, 13906, 29, 6266, 46] - self.global_image_start_token = torch.tensor([self.fake_image_token] + self.global_image_token, dtype=torch.long) + self.fake_image_token = 49152 + self.global_image_token = [44, 13906, 29, 6266, 46] + self.global_image_start_token = torch.tensor( + [self.fake_image_token] + self.global_image_token, dtype=torch.long + ) else: - self.fake_image_token = 49189 - self.global_image_token = 49152 - self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long) + self.fake_image_token = 49189 + self.global_image_token = 49152 + self.global_image_start_token = torch.tensor( + [self.fake_image_token, self.global_image_token], dtype=torch.long + ) else: self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id - self.global_image_start_token = torch.tensor([self.fake_image_token, self.global_image_token], dtype=torch.long) + self.global_image_start_token = torch.tensor( + [self.fake_image_token, self.global_image_token], dtype=torch.long + ) self.add_image_special_tokens = self.config.add_image_special_tokens self.add_local_special_image_tokens = self.config.add_local_special_image_tokens - self.local_image_tokens = [torch.tensor([self.fake_image_token, tok], dtype=torch.long) for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167]] # assume 3 x 3 grid - + self.local_image_tokens = [ + torch.tensor([self.fake_image_token, tok], dtype=torch.long) + for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167] + ] # assume 3 x 3 grid + self.local_image_start_token = self.global_image_start_token self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) self.prefix_length = self.config.prefix_length - self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",") + self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split( + "," + ) self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 self.causal_attention_on_history = self.config.causal_attention_on_history - - - # def configure_peft(self, model): # # return model @@ -845,14 +881,30 @@ class VLAFlowMatching(nn.Module): img, img_mask, ) in enumerate(zip(images, img_masks, strict=False)): - # FIXME(mshukor): add special tokens for the history each history_steps or not + # FIXME(mshukor): add special tokens for the history each history_steps or not if self.add_image_special_tokens: if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1: local_token_idx = img_idx % num_images - image_start_token = self.vlm_with_expert.embed_language_tokens(self.local_image_tokens[local_token_idx].to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) + image_start_token = ( + self.vlm_with_expert.embed_language_tokens( + self.local_image_tokens[local_token_idx].to( + device=self.vlm_with_expert.vlm.device + ) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) else: - image_start_token = self.vlm_with_expert.embed_language_tokens(self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) - image_start_mask = torch.ones_like(image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device) + image_start_token = ( + self.vlm_with_expert.embed_language_tokens( + self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_start_mask = torch.ones_like( + image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device + ) if self.causal_attention_on_history and img_idx % num_images == 0: att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1) else: @@ -861,7 +913,7 @@ class VLAFlowMatching(nn.Module): pad_masks.append(image_start_mask) img_emb = self.vlm_with_expert.embed_image(img) - img_emb = img_emb #.to(dtype=self.vlm_with_expert.type) + img_emb = img_emb # .to(dtype=self.vlm_with_expert.type) # Normalize image embeddings img_emb_dim = img_emb.shape[-1] @@ -880,16 +932,26 @@ class VLAFlowMatching(nn.Module): att_masks += [0] * (num_img_embs) if self.add_image_special_tokens: - if not self.add_local_special_image_tokens or (self.add_local_special_image_tokens and img_idx % num_images == num_images - 1): - image_end_token = self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=self.vlm_with_expert.vlm.device)).unsqueeze(0).expand(img.shape[0], -1, -1) - image_end_mask = torch.ones_like(image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device) + if not self.add_local_special_image_tokens or ( + self.add_local_special_image_tokens and img_idx % num_images == num_images - 1 + ): + image_end_token = ( + self.vlm_with_expert.embed_language_tokens( + self.image_end_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_end_mask = torch.ones_like( + image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device + ) embs.append(image_end_token) pad_masks.append(image_end_mask) att_masks += [0] * (image_end_mask.shape[1]) lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) # Normalize language embeddings lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm? + lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm? embs.append(lang_emb) pad_masks.append(lang_masks) @@ -900,7 +962,9 @@ class VLAFlowMatching(nn.Module): if state is not None and self.state_to_prefix: state_emb = self.state_proj(state) - state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type) + state_emb = ( + state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + ) # .to(dtype=self.vlm_with_expert.type) embs.append(state_emb) bsize = state_emb.shape[0] dtype = state_emb.dtype @@ -912,7 +976,7 @@ class VLAFlowMatching(nn.Module): # Set attention masks so that image and language inputs do not attend to state or actions # att_masks += [1] + [0]*(states_seq_len - 1) - att_masks += [1]*(states_seq_len) + att_masks += [1] * (states_seq_len) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) @@ -937,7 +1001,9 @@ class VLAFlowMatching(nn.Module): # Embed state if not self.state_to_prefix: state_emb = self.state_proj(state) - state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb #.to(dtype=self.vlm_with_expert.type) + state_emb = ( + state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + ) # .to(dtype=self.vlm_with_expert.type) embs.append(state_emb) bsize = state_emb.shape[0] dtype = state_emb.dtype @@ -948,8 +1014,7 @@ class VLAFlowMatching(nn.Module): pad_masks.append(state_mask) # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] + [0]*(states_seq_len - 1) - + att_masks += [1] + [0] * (states_seq_len - 1) # Fuse timestep + action information using an MLP action_emb = self.action_in_proj(noisy_actions) @@ -1010,7 +1075,7 @@ class VLAFlowMatching(nn.Module): images, img_masks, lang_tokens, lang_masks, state=state ) suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) - + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -1061,12 +1126,12 @@ class VLAFlowMatching(nn.Module): x_t = torch.zeros_like(noise, dtype=torch.float32, device=device) expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device) x_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) else: dt = -1.0 / self.config.num_steps dt = torch.tensor(dt, dtype=torch.float32, device=device) diff --git a/src/lerobot/policies/smolpi0/smolvlm_with_expert.py b/src/lerobot/policies/smolpi0/smolvlm_with_expert.py index b910679f1..0ccdcccc8 100644 --- a/src/lerobot/policies/smolpi0/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolpi0/smolvlm_with_expert.py @@ -12,31 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union -from functools import partial import copy +from functools import partial +from typing import List, Optional, Union import torch -import torch.version import torch.nn.functional as F # noqa: N812 +import torch.version from peft import LoraConfig, TaskType, get_peft_model from pytest import Cache from torch import nn from transformers import ( AutoConfig, - GemmaForCausalLM, - AutoModelForImageTextToText, - AutoProcessor, - PretrainedConfig, - PreTrainedModel, - SmolVLMForConditionalGeneration, AutoModel, + AutoModelForImageTextToText, AutoModelForVision2Seq, + AutoProcessor, + SmolVLMForConditionalGeneration, ) -from transformers.models.auto import CONFIG_MAPPING -from transformers import SmolVLMModel, SmolVLMConfig + from lerobot.policies.smolpi0.flex_attention import flex_attention_forward + def _round_up_to_multiple(x, multiple): return (x + multiple - 1) // multiple * multiple @@ -180,20 +177,31 @@ def apply_rope(x, positions, max_wavelength=10_000): # f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." # ) + def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): hidden_dim = int(2 * hidden_dim / 3) hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) return hidden_dim - + class SmolVLMWithExpertModel(nn.Module): # config_class = PaliGemmaWithExpertConfig - def __init__(self, model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", - load_vlm_weights: bool = True, train_expert_only: bool = True, freeze_vision_encoder: bool = False, - attention_implementation: str = "eager", attention_mode: str = "self_attn", num_expert_layers: int = -1, - num_vlm_layers: int = -1, self_attn_every_n_layers: int = -1, expert_width_multiplier: float = 0.5, self_attn_only_actions: bool = False): + def __init__( + self, + model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", + load_vlm_weights: bool = True, + train_expert_only: bool = True, + freeze_vision_encoder: bool = False, + attention_implementation: str = "eager", + attention_mode: str = "self_attn", + num_expert_layers: int = -1, + num_vlm_layers: int = -1, + self_attn_every_n_layers: int = -1, + expert_width_multiplier: float = 0.5, + self_attn_only_actions: bool = False, + ): super().__init__() if load_vlm_weights: print(f"Loading {model_id} weights ...") @@ -227,15 +235,17 @@ class SmolVLMWithExpertModel(nn.Module): # Smaller lm expert lm_expert_config = copy.deepcopy(config.text_config) hidden_size = lm_expert_config.hidden_size - lm_expert_config.hidden_size = int(hidden_size*expert_width_multiplier) #hidden_size // 2 - lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size*expert_width_multiplier)) + lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 + lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) lm_expert_config.num_hidden_layers = self.num_vlm_layers - if num_expert_layers > 0 : - assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" + if num_expert_layers > 0: + assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( + f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" + ) lm_expert_config.num_hidden_layers = num_expert_layers # lm_expert_config.head_dim = lm_expert_config.head_dim * 2 self.lm_expert = AutoModel.from_config(lm_expert_config) - + self.num_expert_layers = len(self.lm_expert.layers) self.self_attn_every_n_layers = self_attn_every_n_layers self.self_attn_only_actions = self_attn_only_actions @@ -245,10 +255,14 @@ class SmolVLMWithExpertModel(nn.Module): if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: continue self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, ) self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, bias=lm_expert_config.attention_bias + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, ) # Remove unused embed_tokens self.lm_expert.embed_tokens = None @@ -308,8 +322,10 @@ class SmolVLMWithExpertModel(nn.Module): else: self.vlm = self.vlm.merge_and_unload() - def get_vlm_model(self,): - if hasattr(self.vlm.model, "model"): # When using peft + def get_vlm_model( + self, + ): + if hasattr(self.vlm.model, "model"): # When using peft return self.vlm.model.model else: return self.vlm.model @@ -326,22 +342,20 @@ class SmolVLMWithExpertModel(nn.Module): else: # To avoid unused params issue with distributed training last_layers = [self.num_vlm_layers - 1] - if self.num_vlm_layers != self.num_expert_layers and self.num_vlm_layers % self.num_expert_layers == 0: + if ( + self.num_vlm_layers != self.num_expert_layers + and self.num_vlm_layers % self.num_expert_layers == 0 + ): last_layers.append(self.num_vlm_layers - 2) frozen_layers = [ - "lm_head", - "text_model.model.norm.weight", - ] + "lm_head", + "text_model.model.norm.weight", + ] for layer in last_layers: frozen_layers.append(f"text_model.model.layers.{layer}.") for name, params in self.vlm.named_parameters(): - if any( - [ - k in name - for k in frozen_layers - ] - ): + if any([k in name for k in frozen_layers]): params.requires_grad = False # To avoid unused params issue with distributed training for name, params in self.lm_expert.named_parameters(): @@ -410,19 +424,34 @@ class SmolVLMWithExpertModel(nn.Module): # FIXME(mshukor): add special image tokens specific to smolvlm # Get sequence from the vision encoder - image_hidden_states = self.get_vlm_model().vision_model( - pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), - patch_attention_mask=patch_attention_mask, - ).last_hidden_state + image_hidden_states = ( + self.get_vlm_model() + .vision_model( + pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), + patch_attention_mask=patch_attention_mask, + ) + .last_hidden_state + ) # Modality projection & resampling image_hidden_states = self.get_vlm_model().connector(image_hidden_states) return image_hidden_states - + def embed_language_tokens(self, tokens: torch.Tensor): return self.get_vlm_model().text_model.get_input_embeddings()(tokens) - def forward_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values=None) -> list[torch.Tensor]: - + def forward_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: query_states = [] key_states = [] value_states = [] @@ -430,7 +459,7 @@ class SmolVLMWithExpertModel(nn.Module): layer = model_layers[i][layer_idx] if hidden_states is None or layer is None: continue - + # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer hidden_states = layer.input_layernorm(hidden_states) @@ -468,12 +497,16 @@ class SmolVLMWithExpertModel(nn.Module): if inputs_embeds[1] is not None: suffix_len = inputs_embeds[1].shape[1] attention_mask_[:, -suffix_len:, :-suffix_len] = False - position_ids_[:, -suffix_len:] = _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None] + position_ids_[:, -suffix_len:] = ( + _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None] + ) else: attention_mask_ = _attention_mask position_ids_ = _position_ids - query_states = apply_rope(query_states, position_ids_) # FIXME(mshukor): this assumes we have always the vlm features? + query_states = apply_rope( + query_states, position_ids_ + ) # FIXME(mshukor): this assumes we have always the vlm features? key_states = apply_rope(key_states, position_ids_) if use_cache and past_key_values is None: @@ -491,9 +524,7 @@ class SmolVLMWithExpertModel(nn.Module): # the max len, then we (for instance) double the cache size. This implementation already exists # in `transformers`. (molbap) key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) - value_states = torch.cat( - [past_key_values[layer_idx]["value_states"], value_states], dim=1 - ) + value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) attention_interface = self.get_attention_interface() @@ -504,20 +535,32 @@ class SmolVLMWithExpertModel(nn.Module): return [att_output], past_key_values - - def forward_cross_attn_layer(self, model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache: bool = True, fill_kv_cache: bool = True, past_key_values = None) -> list[torch.Tensor]: - + def forward_cross_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: attention_interface = self.get_attention_interface() att_outputs = [] - assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" + assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( + f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" + ) if len(inputs_embeds) == 2 and not past_key_values: # Prefix attention seq_len = inputs_embeds[0].shape[1] - position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] + position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] - + layer = model_layers[0][layer_idx] hidden_states = layer.input_layernorm(inputs_embeds[0]) @@ -558,7 +601,6 @@ class SmolVLMWithExpertModel(nn.Module): key_states = past_key_values[layer_idx]["key_states"] value_states = past_key_values[layer_idx]["value_states"] - # Expert expert_layer = model_layers[1][layer_idx] if expert_layer is not None: @@ -570,21 +612,37 @@ class SmolVLMWithExpertModel(nn.Module): expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) + _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( + *key_states.shape[:2], -1 + ) + expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( + *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) # k_proj should have same dim as kv - _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(*key_states.shape[:2], -1) - expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim) # k_proj should have same dim as kv + _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( + *value_states.shape[:2], -1 + ) + expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( + *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) - _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(*value_states.shape[:2], -1) - expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim) - - expert_position_id = expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values # start from 0 - expert_attention_mask = attention_mask[:, -inputs_embeds[1].shape[1]:, :expert_key_states.shape[1]:] # take into account kv + expert_position_id = ( + expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values + ) # start from 0 + expert_attention_mask = attention_mask[ + :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : + ] # take into account kv expert_query_states = apply_rope(expert_query_state, expert_position_id) # expert_key_states = apply_rope(expert_key_state, expert_position_id) - + att_output = attention_interface( - expert_attention_mask, batch_size, head_dim, expert_query_states, expert_key_states, expert_value_states + expert_attention_mask, + batch_size, + head_dim, + expert_query_states, + expert_key_states, + expert_value_states, ) att_outputs.append(att_output) else: @@ -592,8 +650,8 @@ class SmolVLMWithExpertModel(nn.Module): # att_output = att_output.to(dtype=models[i].dtype) return att_outputs, past_key_values - - def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient? + + def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient? vlm_layers = [] expert_layers = [] multiple_of = self.num_vlm_layers // self.num_expert_layers @@ -606,15 +664,16 @@ class SmolVLMWithExpertModel(nn.Module): vlm_layers.append(models[0].layers[i]) expert_layers.append(expert_layer) return [vlm_layers, expert_layers] + # TODO: break down this huge forward into modules or functions def forward( self, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, - inputs_embeds: List[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - fill_kv_cache: Optional[bool] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | Cache | None = None, + inputs_embeds: list[torch.FloatTensor] = None, + use_cache: bool | None = None, + fill_kv_cache: bool | None = None, ): models = [self.get_vlm_model().text_model, self.lm_expert] model_layers = self.get_model_layers(models) @@ -626,26 +685,58 @@ class SmolVLMWithExpertModel(nn.Module): continue batch_size = hidden_states.shape[0] - # # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train + # # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train if self.attention_implementation == "flex": - if inputs_embeds[0] is not None and inputs_embeds[1] is not None and attention_mask.shape[-1] == attention_mask.shape[-2] and past_key_values is None: # Now only during training + if ( + inputs_embeds[0] is not None + and inputs_embeds[1] is not None + and attention_mask.shape[-1] == attention_mask.shape[-2] + and past_key_values is None + ): # Now only during training seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1] - padded_seq_len = _round_up_to_multiple(seq_len, 128) # FIXME(mshukor): more efficient to have a fixed seq len? + padded_seq_len = _round_up_to_multiple( + seq_len, 128 + ) # FIXME(mshukor): more efficient to have a fixed seq len? b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask pad = padded_seq_len - q_len attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True) inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0) position_ids = F.pad(position_ids, (0, pad), value=0) - # RMSNorm num_layers = self.num_vlm_layers head_dim = self.vlm.config.text_config.head_dim for layer_idx in range(num_layers): - if fill_kv_cache or "cross" not in self.attention_mode or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0): - att_outputs, past_key_values = self.forward_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values) + if ( + fill_kv_cache + or "cross" not in self.attention_mode + or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) + ): + att_outputs, past_key_values = self.forward_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) else: - att_outputs, past_key_values = self.forward_cross_attn_layer(model_layers, inputs_embeds, layer_idx, position_ids, attention_mask, batch_size, head_dim, use_cache=use_cache, fill_kv_cache=fill_kv_cache, past_key_values=past_key_values) + att_outputs, past_key_values = self.forward_cross_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) # query_states = [] # key_states = [] # value_states = [] @@ -703,7 +794,6 @@ class SmolVLMWithExpertModel(nn.Module): # attention_mask, batch_size, head_dim, query_states, key_states, value_states # ) - # att_output = att_output.to(dtype=models[i].dtype) # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) @@ -712,7 +802,9 @@ class SmolVLMWithExpertModel(nn.Module): for i, hidden_states in enumerate(inputs_embeds): # layer = models[i].layers[layer_idx] layer = model_layers[i][layer_idx] - att_output = att_outputs[i] if i < len(att_outputs) else att_outputs[0] # in case of self_attn + att_output = ( + att_outputs[i] if i < len(att_outputs) else att_outputs[0] + ) # in case of self_attn if hidden_states is not None: if layer is None: outputs_embeds.append(hidden_states) @@ -759,7 +851,11 @@ class SmolVLMWithExpertModel(nn.Module): if self.attention_implementation == "fa2": attention_interface = self.flash_attention_forward elif self.attention_implementation == "flex": - attention_interface = partial(flex_attention_forward, num_att_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads) + attention_interface = partial( + flex_attention_forward, + num_att_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + ) else: attention_interface = self.eager_attention_forward return attention_interface @@ -807,7 +903,7 @@ class SmolVLMWithExpertModel(nn.Module): att_weights *= head_dim**-0.5 att_weights = att_weights.to(dtype=torch.float32) - big_neg = torch.finfo(att_weights.dtype).min #-2.3819763e38 # See gemma/modules.py + big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) probs = nn.functional.softmax(masked_att_weights, dim=-1) probs = probs.to(dtype=value_states.dtype) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 1c07d98e5..9bb22d7f7 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -1027,6 +1027,7 @@ from lerobot.policies.utils import ( populate_queues, ) from lerobot.utils.utils import get_safe_dtype + # OBS_STATE = 'state' # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") @@ -1891,4 +1892,4 @@ class VLAFlowMatching(nn.Module): suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) v_t = self.action_out_proj(suffix_out) - return v_t \ No newline at end of file + return v_t diff --git a/src/lerobot/policies/smolvla/saver.txt b/src/lerobot/policies/smolvla/saver.txt index 3410062ba..f2ad6c76f 100644 --- a/src/lerobot/policies/smolvla/saver.txt +++ b/src/lerobot/policies/smolvla/saver.txt @@ -1 +1 @@ -c \ No newline at end of file +c diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index e4cd7acac..3b78c99e6 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -547,4 +547,4 @@ class SmolVLMWithExpertModel(nn.Module): # we use -1 because sequence length can change att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) - return att_output \ No newline at end of file + return att_output diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 795ed2b3c..b96fbb8a3 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -45,17 +45,21 @@ Note that in both examples, the repo/folder should contain at least `config.json You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py """ -import concurrent + +import concurrent.futures as cf import json import logging import threading import time +from collections import defaultdict from collections.abc import Callable from contextlib import nullcontext from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat +from typing import Dict, List, Tuple, TypedDict +from collections.abc import Iterator import einops import gymnasium as gym @@ -68,7 +72,11 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env -from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1 +from lerobot.envs.utils import ( + add_envs_task, + check_env_attributes_and_types, + preprocess_observation, +) from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters @@ -79,9 +87,6 @@ from lerobot.utils.utils import ( init_logging, inside_slurm, ) -from typing import TypedDict, Dict, List, Tuple, Iterator -from collections import defaultdict -import concurrent.futures as cf def rollout( @@ -485,8 +490,12 @@ def _compile_episode_data( data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) return data_dict -from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): """Recreate normalization layers with proper stats from the dataset.""" from lerobot.policies.normalize import Normalize, Unnormalize @@ -518,7 +527,8 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData def load_smolvla(cfg, dataset_repo: str, policy): from lerobot.datasets.lerobot_dataset import LeRobotDataset - dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/') + + dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/") _inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing return policy.to("cuda"), dataset @@ -529,8 +539,8 @@ def eval_main(cfg: EvalPipelineConfig): # Check device is available device = get_safe_torch_device(cfg.policy.device, log=True) - #login to hf - from huggingface_hub import login + # login to hf + # login() torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -549,7 +559,7 @@ def eval_main(cfg: EvalPipelineConfig): breakpoint() # policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy) # rename "image" -> "observation.image" - + policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy_all( @@ -584,10 +594,11 @@ def eval_main(cfg: EvalPipelineConfig): # ---- typed payload returned by one task eval ---- class TaskMetrics(TypedDict): - sum_rewards: List[float] - max_rewards: List[float] - successes: List[bool] - video_paths: List[str] + sum_rewards: list[float] + max_rewards: list[float] + successes: list[bool] + video_paths: list[str] + ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths") @@ -610,7 +621,7 @@ def eval_policy_all( """ global_start = time.time() - # inner: evaluate a single (suite, task) + # inner: evaluate a single (suite, task) def eval_one( task_group: str, task_id: int, @@ -650,27 +661,36 @@ def eval_policy_all( video_paths=task_result.get("video_paths", []), ) - # result producer: sequential or threaded, same consumer - def iter_task_results() -> Iterator[Tuple[str, int, TaskMetrics]]: + # result producer: sequential or threaded, same consumer + def iter_task_results() -> Iterator[tuple[str, int, TaskMetrics]]: if max_parallel_tasks == 1: for task_group, tasks in envs.items(): for task_id, vec in tasks.items(): - yield task_group, task_id, eval_one( - task_group, task_id, vec, - policy=policy, - n_episodes=n_episodes, - max_episodes_rendered=max_episodes_rendered, - videos_dir=videos_dir, - return_episode_data=return_episode_data, - start_seed=start_seed, + yield ( + task_group, + task_id, + eval_one( + task_group, + task_id, + vec, + policy=policy, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ), ) else: with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: - fut2key: Dict[cf.Future, Tuple[str, int]] = {} + fut2key: dict[cf.Future, tuple[str, int]] = {} for task_group, tasks in envs.items(): for task_id, vec in tasks.items(): fut = executor.submit( - eval_one, task_group, task_id, vec, + eval_one, + task_group, + task_id, + vec, policy=policy, n_episodes=n_episodes, max_episodes_rendered=max_episodes_rendered, @@ -683,9 +703,9 @@ def eval_policy_all( task_group, task_id = fut2key[fut] yield task_group, task_id, fut.result() - # single accumulator path on the main thread - group_acc: Dict[str, Dict[str, List]] = defaultdict(lambda: {k: [] for k in ACC_KEYS}) - overall: Dict[str, List] = {k: [] for k in ACC_KEYS} + # single accumulator path on the main thread + group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS}) + overall: dict[str, list] = {k: [] for k in ACC_KEYS} for task_group, task_id, metrics in iter_task_results(): acc = group_acc[task_group] @@ -694,7 +714,7 @@ def eval_policy_all( overall[k].extend(metrics[k]) # build outputs - results: Dict[str, dict] = {} + results: dict[str, dict] = {} for task_group, data in group_acc.items(): suite_rewards = data["sum_rewards"] suite_max = data["max_rewards"] @@ -720,9 +740,15 @@ def eval_policy_all( global_eval_ep_s = global_eval_s / max(1, len(overall["sum_rewards"])) results["overall"] = { "aggregated": { - "avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) if overall["sum_rewards"] else float("nan"), - "avg_max_reward": float(np.nanmean(overall["max_rewards"])) if overall["max_rewards"] else float("nan"), - "pc_success": float(np.nanmean(overall["successes"]) * 100) if overall["successes"] else float("nan"), + "avg_sum_reward": float(np.nanmean(overall["sum_rewards"])) + if overall["sum_rewards"] + else float("nan"), + "avg_max_reward": float(np.nanmean(overall["max_rewards"])) + if overall["max_rewards"] + else float("nan"), + "pc_success": float(np.nanmean(overall["successes"]) * 100) + if overall["successes"] + else float("nan"), "eval_s": global_eval_s, "eval_ep_s": global_eval_ep_s, }, @@ -732,7 +758,6 @@ def eval_policy_all( return results - if __name__ == "__main__": init_logging() eval_main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 3feeb0512..fc8be4ebc 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -105,6 +105,7 @@ def update_policy( train_metrics.update_s = time.perf_counter() - start_time return train_metrics, output_dict + # def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): # """Recreate normalization layers with dataset stats if missing (Adil's workaround).""" # from lerobot.policies.normalize import Normalize, Unnormalize @@ -132,6 +133,7 @@ def update_policy( # print("✅ Normalization layers injected with dataset stats.") + @parser.wrap() def train(cfg: TrainPipelineConfig): cfg.validate() @@ -271,9 +273,12 @@ def train(cfg: TrainPipelineConfig): if cfg.env and is_eval_step: step_id = get_step_identifier(step, cfg.steps) logging.info(f"Eval policy at step {step}") - with torch.no_grad(), (torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext()): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), + ): eval_info = eval_policy_all( - eval_env, # dict[suite][task_id] -> vec_env + eval_env, # dict[suite][task_id] -> vec_env policy, cfg.eval.n_episodes, videos_dir=videos_dir, @@ -295,15 +300,15 @@ def train(cfg: TrainPipelineConfig): # meters/tracker eval_metrics = { "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), } eval_tracker = MetricsTracker( cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step ) - eval_tracker.eval_s = aggregated.get("eval_s", 0.0) + eval_tracker.eval_s = aggregated.get("eval_s", 0.0) eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan")) - eval_tracker.pc_success = aggregated.get("pc_success", float("nan")) + eval_tracker.pc_success = aggregated.get("pc_success", float("nan")) if wandb_logger: wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} wandb_logger.log_dict(wandb_log_dict, step, mode="eval") diff --git a/src/lerobot/scripts/train_2.py b/src/lerobot/scripts/train_2.py index 26a9e7aea..5b82ef044 100644 --- a/src/lerobot/scripts/train_2.py +++ b/src/lerobot/scripts/train_2.py @@ -104,6 +104,7 @@ def update_policy( train_metrics.update_s = time.perf_counter() - start_time return train_metrics, output_dict + def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): """Recreate normalization layers with dataset stats if missing (Adil's workaround).""" from lerobot.policies.normalize import Normalize, Unnormalize @@ -115,15 +116,15 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData stats = {} for key, stat_dict in dataset_meta.stats.items(): stats[key] = { - stat_type: torch.as_tensor(stat_array) - if isinstance(stat_array, np.ndarray) - else stat_array + stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array for stat_type, stat_array in stat_dict.items() } normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) - unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats) + unnormalize_outputs = Unnormalize( + policy.config.output_features, policy.config.normalization_mapping, stats + ) policy.normalize_inputs = normalize_inputs policy.normalize_targets = normalize_targets @@ -131,6 +132,7 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData print("✅ Normalization layers injected with dataset stats.") + @parser.wrap() def train(cfg: TrainPipelineConfig): cfg.validate() diff --git a/src/lerobot/scripts/train_accelerate.py b/src/lerobot/scripts/train_accelerate.py index e205f138f..1e8a59a64 100644 --- a/src/lerobot/scripts/train_accelerate.py +++ b/src/lerobot/scripts/train_accelerate.py @@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed from termcolor import colored from torch.optim import Optimizer +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.scripts.eval import eval_policy from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.train_utils import ( get_step_checkpoint_dir, @@ -43,9 +46,6 @@ from lerobot.utils.utils import ( has_method, init_logging, ) -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.scripts.eval import eval_policy def update_policy( @@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig): # Initialize accelerator from accelerate.utils import DistributedDataParallelKwargs + # added by jade 2 lines ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs]) @@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig): if accelerator.is_main_process: logging.info("End of training") - accelerator.end_training() # added by jade + accelerator.end_training() # added by jade if __name__ == "__main__":