mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
trying out to re-init the action head to avoid pretraining dimension mismatch
This commit is contained in:
@@ -27,6 +27,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
|||||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||||
freeze_qwen: bool = False
|
freeze_qwen: bool = False
|
||||||
enable_world_model: bool = True
|
enable_world_model: bool = True
|
||||||
|
reinit_action_head: bool = False
|
||||||
|
|
||||||
tokenizer_padding_side: str = "left"
|
tokenizer_padding_side: str = "left"
|
||||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@@ -554,3 +555,35 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||||
|
if not model.config.reinit_action_head:
|
||||||
|
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||||
|
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
state_dict = load_file(model_file, device=map_location)
|
||||||
|
current = model.state_dict()
|
||||||
|
|
||||||
|
mismatched: list[str] = []
|
||||||
|
filtered: dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key in current and value.shape != current[key].shape:
|
||||||
|
mismatched.append(
|
||||||
|
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filtered[key] = value
|
||||||
|
|
||||||
|
if mismatched:
|
||||||
|
logging.warning(
|
||||||
|
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
|
||||||
|
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.utils import log_model_loading_keys
|
||||||
|
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||||
|
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||||
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user