From 1642a14ae5f95745a4d4217a5ef792b38a40d1f9 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Mon, 18 May 2026 13:54:07 +0200 Subject: [PATCH] trying out to re-init the action head to avoid pretraining dimension mismatch --- .../vla_jepa/configuration_vla_jepa.py | 1 + .../policies/vla_jepa/modeling_vla_jepa.py | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index d2092fe83..c3c2cd2f0 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -27,6 +27,7 @@ class VLAJEPAConfig(PreTrainedConfig): jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" freeze_qwen: bool = False enable_world_model: bool = True + reinit_action_head: bool = False tokenizer_padding_side: str = "left" prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}." diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 7d6590bfd..1cf7b551f 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from collections import deque from pathlib import Path from typing import TYPE_CHECKING @@ -554,3 +555,35 @@ class VLAJEPAPolicy(PreTrainedPolicy): **kwargs, ): return super().from_pretrained(pretrained_name_or_path, **kwargs) + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if not model.config.reinit_action_head: + return super()._load_as_safetensor(model, model_file, map_location, strict) + + from safetensors.torch import load_file + + state_dict = load_file(model_file, device=map_location) + current = model.state_dict() + + mismatched: list[str] = [] + filtered: dict = {} + for key, value in state_dict.items(): + if key in current and value.shape != current[key].shape: + mismatched.append( + f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}" + ) + else: + filtered[key] = value + + if mismatched: + logging.warning( + f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes " + f"(randomly re-initialised):\n " + "\n ".join(mismatched) + ) + + from lerobot.policies.utils import log_model_loading_keys + + missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False) + log_model_loading_keys(missing_keys, unexpected_keys) + return model