mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
fix(smolvla2,pi052): training-correctness audit fixes
CRITICAL (smolvla2) — text-CE was applied to the wrong prefix slice.
``num_state`` was being read from ``state.shape[1]`` (the raw
max_state_dim, ~14-32) instead of the *number of state tokens*
(always 1). Compounded by the trailing-padding issue (state is
not at the end of the padded prefix when ``seq_len < prefix_length``),
the lang slice was landing on image / padding hidden states.
New ``_locate_lang_range`` finds the state position via
``att_masks.nonzero()`` (the only ``1`` in the mask), making the
slice robust to both bugs. Used by ``_compute_text_loss`` and
``_compute_fused_loss``.
LIKELY-BUG (smolvla2) — ``_unfreeze_lm_head`` only re-enabled
``lm_head`` and ``text_model.model.norm.weight``. SmolVLA's parent
ALSO freezes the last 1-2 transformer layers, so text-loss
gradients died in a frozen final block. Now mirrors the parent's
freeze targets and unfreezes the matching ``layers.{N-1}`` (and
``N-2`` when num_vlm != num_expert and num_vlm % num_expert == 0).
CRITICAL (pi052) — flow and FAST CE were not per-sample masked
under per-sample-routing. Text-only recipe samples
(``plan_generation``, ``ask_vqa_*``) contributed to flow/FAST
loss with prompts that deliberately omit the subtask, corrupting
the signal. Threaded ``predict_actions_t`` through both
``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``;
flow uses ``(per_sample * mask).sum() / mask.sum()``, FAST uses
``shift_valid & sample_mask`` before ``masked_fill(-100)``.
OTHER
* PI052Policy.forward now falls through to PI05Policy.forward on
unannotated batches (no text_labels, no predict_actions, no FAST).
* fit_fast_tokenizer cache key now includes ``chunk_size`` — changing
the chunk size no longer silently loads a wrongly-fit tokenizer.
* Removed dead ``_compute_text_loss`` / ``_compute_fast_action_loss``
in pi052 (superseded by the fused helpers).
* Fixed stale "no-op stub" docstring on ``knowledge_insulation`` —
it's been fully wired since the per-layer KI forward port.
* Stripped unused ``copy`` / ``resize_with_pad`` imports.
* Extracted ``_shifted_ce`` / ``_mask_per_sample`` / ``_fast_ce``
helpers shared between fused and prefix-only paths.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -152,19 +152,15 @@ class PI052Config(PI05Config):
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# *blocked* from flowing back into the VLM's K/V projections. This
|
||||
# blocked from flowing back into the VLM's K/V projections. This
|
||||
# prevents the action loss from over-fitting the language backbone
|
||||
# to robot-specific features. Implementation requires a custom
|
||||
# per-layer attention forward that uses ``.detach()`` on VLM K/V
|
||||
# when computing attention for action queries — see
|
||||
# ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
|
||||
# on branch ``feat/add-pi05`` for the reference implementation.
|
||||
# to robot-specific features. Implemented in ``modeling_pi052`` as
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path (blocks action→VLM gradient flow on K/V).
|
||||
Currently a no-op stub in this branch — PR follow-up required to
|
||||
port the full custom layer from pi05_full. The flag is exposed
|
||||
here so SLURM commands can be written against the final shape."""
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
@@ -47,12 +47,17 @@ import numpy as np
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _dataset_signature(dataset_repo_id: str, base_tokenizer_name: str, n_samples: int) -> str:
|
||||
def _dataset_signature(
|
||||
dataset_repo_id: str,
|
||||
base_tokenizer_name: str,
|
||||
n_samples: int,
|
||||
chunk_size: int,
|
||||
) -> str:
|
||||
"""Deterministic short hash for naming the cache directory.
|
||||
|
||||
Keys on (dataset, base tokenizer, sample count) so re-fitting on a
|
||||
new dataset or a different base doesn't clobber the prior cache,
|
||||
and so changing n_samples re-runs the fit.
|
||||
Keys on (dataset, base tokenizer, sample count, chunk size) so any
|
||||
of those changing re-runs the fit. ``chunk_size`` matters because
|
||||
the tokenizer is fit on chunks of that length.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
h.update(dataset_repo_id.encode("utf-8"))
|
||||
@@ -60,6 +65,8 @@ def _dataset_signature(dataset_repo_id: str, base_tokenizer_name: str, n_samples
|
||||
h.update(base_tokenizer_name.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(n_samples).encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(chunk_size).encode("utf-8"))
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
@@ -102,7 +109,7 @@ def fit_fast_tokenizer(
|
||||
FileNotFoundError: If the dataset can't be loaded.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||
out_dir = cache_dir / sig
|
||||
|
||||
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists():
|
||||
|
||||
@@ -55,6 +55,65 @@ from .configuration_pi052 import PI052Config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Loss helpers (shared between fused and prefix-only paths)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Tensor:
|
||||
"""Mean over samples where ``predict_actions_t`` is True, else over all."""
|
||||
if predict_actions_t is None:
|
||||
return per_sample.mean()
|
||||
mask = predict_actions_t.to(per_sample.dtype)
|
||||
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
|
||||
|
||||
|
||||
def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor:
    """Next-token cross-entropy averaged over the supervised positions.

    The logits at position ``t`` are scored against the label at ``t + 1``;
    labels equal to ``-100`` are ignored. The summed loss is divided by the
    number of non-ignored targets, with that count clamped to at least 1 so
    a batch without any supervised position returns 0.
    """
    # Drop the last logit step (it has no target) and the first label
    # (nothing predicts it) so the two line up for next-token prediction.
    pred = logits[:, :-1, :].contiguous()
    target = labels[:, 1:].contiguous().long()

    vocab_size = pred.shape[-1]
    n_supervised = (target != -100).sum().clamp(min=1)
    summed = F.cross_entropy(
        pred.reshape(-1, vocab_size),
        target.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )
    return summed / n_supervised
|
||||
|
||||
|
||||
def _fast_ce(
|
||||
fast_logits: Tensor,
|
||||
action_tokens: Tensor,
|
||||
action_mask: Tensor,
|
||||
predict_actions_t: Tensor | None,
|
||||
) -> Tensor:
|
||||
"""FAST-CE with both token-pad masking and per-sample action gating.
|
||||
|
||||
``action_mask`` is the FAST tokenizer's padding mask; samples whose
|
||||
recipe sets ``predict_actions=False`` (e.g. plan_generation,
|
||||
ask_vqa_*) get *all* their FAST positions masked out via the
|
||||
per-sample gate.
|
||||
"""
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
if predict_actions_t is not None:
|
||||
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
||||
shift_valid = shift_valid & sample_mask
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
return F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
) / shift_valid.sum().clamp(min=1)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
|
||||
# ----------------------------------------------------------------------
|
||||
@@ -307,6 +366,16 @@ class PI052Policy(PI05Policy):
|
||||
text_labels = batch.get("text_labels")
|
||||
predict_actions_t = batch.get("predict_actions")
|
||||
|
||||
# Unannotated datasets: no recipe applied → no text_labels and
|
||||
# no FAST / predict_actions routing. Defer to PI05Policy so the
|
||||
# plain flow-only training surface keeps working unchanged.
|
||||
if (
|
||||
text_labels is None
|
||||
and predict_actions_t is None
|
||||
and not getattr(self.config, "enable_fast_action_loss", False)
|
||||
):
|
||||
return super().forward(batch, reduction=reduction)
|
||||
|
||||
run_flow = (
|
||||
self.config.flow_loss_weight > 0
|
||||
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
||||
@@ -350,6 +419,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
predict_actions_t=predict_actions_t,
|
||||
)
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
total = self.config.flow_loss_weight * flow_loss
|
||||
@@ -365,6 +435,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
predict_actions_t=predict_actions_t,
|
||||
)
|
||||
if text_loss is not None:
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
@@ -399,6 +470,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
predict_actions_t: Tensor | None = None,
|
||||
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
||||
|
||||
@@ -510,7 +582,8 @@ class PI052Policy(PI05Policy):
|
||||
# internally to max_action_dim).
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
|
||||
flow_loss = flow_per_dim.mean()
|
||||
per_sample_flow = flow_per_dim.mean(dim=(1, 2))
|
||||
flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
|
||||
|
||||
# ---- text + FAST CE from prefix_out ------------------------
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
@@ -523,27 +596,13 @@ class PI052Policy(PI05Policy):
|
||||
else:
|
||||
text_hidden = prefix_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = text_logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
text_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if fast_len > 0 and prefix_out is not None:
|
||||
fast_hidden = prefix_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
fast_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
||||
|
||||
return flow_loss, text_loss, fast_loss
|
||||
|
||||
@@ -553,6 +612,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
predict_actions_t: Tensor | None = None,
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Single prefix forward → text CE + FAST CE.
|
||||
|
||||
@@ -619,180 +679,21 @@ class PI052Policy(PI05Policy):
|
||||
lang_len = text_labels.shape[1]
|
||||
# embed_prefix lays out as [images, language]; with FAST
|
||||
# appended the full sequence is [images, language, FAST].
|
||||
# Language hidden states are at positions
|
||||
# ``[-(fast_len + lang_len) : -fast_len]`` when FAST is
|
||||
# present, or ``[-lang_len:]`` otherwise.
|
||||
if fast_len > 0:
|
||||
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
|
||||
else:
|
||||
text_hidden = vlm_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = text_logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
text_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if action_tokens is not None and action_mask is not None and fast_len > 0:
|
||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
fast_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
||||
|
||||
return text_loss, fast_loss
|
||||
|
||||
def _compute_fast_action_loss(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
action_tokens: Tensor,
|
||||
action_mask: Tensor,
|
||||
) -> Tensor:
|
||||
"""Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.
|
||||
|
||||
Mirrors the paper's §III.C action-token loss: append the FAST
|
||||
tokens to the language prefix, forward through the backbone,
|
||||
slice the per-token logits at the action positions, and
|
||||
compute CE against the FAST targets (shifted for next-token
|
||||
prediction). Token-mask gates the loss to valid positions.
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
# Re-embed prefix to get [images, language] embeddings, then
|
||||
# append the FAST action token embeddings as additional
|
||||
# prefix tokens. PaliGemma's embed_language_tokens is shared
|
||||
# between text and FAST tokens so we can re-use it directly.
|
||||
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
|
||||
fast_emb = fast_emb * math.sqrt(emb_dim)
|
||||
|
||||
# Concat onto the prefix.
|
||||
# pad masks: language uses ``lang_masks``; FAST uses
|
||||
# ``action_mask`` (True at valid token positions).
|
||||
# att masks: prefix is 0 (bidirectional block); FAST is 1
|
||||
# (each token starts its own causal block). Per
|
||||
# ``make_att_2d_masks``'s mask_ar convention this
|
||||
# yields prefix-LM attention: FAST tokens attend
|
||||
# bidirectionally to images+language and causally
|
||||
# among themselves, while prefix tokens *cannot*
|
||||
# see FAST tokens. Matches pi05_full §III.C.
|
||||
fast_len = action_tokens.shape[1]
|
||||
device = prefix_embs.device
|
||||
ones_att = torch.ones((action_tokens.shape[0], fast_len), dtype=torch.bool, device=device)
|
||||
full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
|
||||
full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
|
||||
full_att = torch.cat([prefix_att, ones_att], dim=1)
|
||||
|
||||
att_2d = make_att_2d_masks(full_pad, full_att)
|
||||
position_ids = torch.cumsum(full_pad, dim=1) - 1
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[full_embs, None],
|
||||
use_cache=False,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")
|
||||
|
||||
# Slice the last ``fast_len`` positions — those correspond to
|
||||
# the FAST tokens we just appended.
|
||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
|
||||
# Shift for next-token prediction. Replace targets at padded
|
||||
# positions with -100 so ``ignore_index`` in cross_entropy
|
||||
# cleanly drops them rather than relying on a post-hoc
|
||||
# multiply-by-mask (which still computes the CE numerator at
|
||||
# invalid positions and could crash if a padded target id
|
||||
# falls outside the vocab).
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
|
||||
# Mean over valid positions via ``ignore_index``.
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
return loss
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
||||
|
||||
Embeds images + language, runs the VLM-only forward (the
|
||||
action expert is skipped via ``inputs_embeds=[..., None]``),
|
||||
slices the hidden states to the *language* portion so they
|
||||
align with ``text_labels`` (which covers only the language
|
||||
tokens, not the image patch tokens), then computes shifted
|
||||
next-token CE with ``-100`` ignoring padding/non-target
|
||||
positions.
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, tokens, masks
|
||||
)
|
||||
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
|
||||
|
||||
# Slice the hidden states to the language portion. embed_prefix
|
||||
# concatenates [images, language] in that order, so the trailing
|
||||
# ``text_labels.shape[1]`` positions are the language tokens.
|
||||
# Without this slice, applying lm_head to the full vlm_out and
|
||||
# shifting against text_labels[..., 1:] produces a shape
|
||||
# mismatch in cross_entropy.
|
||||
lang_len = text_labels.shape[1]
|
||||
text_hidden = vlm_out[:, -lang_len:, :]
|
||||
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
|
||||
# Shift for next-token prediction: predict token[i+1] from
|
||||
# hidden[i] within the language span.
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = text_labels[..., 1:].contiguous()
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
return loss
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# select_message — AR text generation at inference
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -36,7 +36,6 @@ Outputs:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -38,7 +38,6 @@ matching the chat-template-stripped text order).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -54,6 +54,56 @@ from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, int]:
    """Return ``(lang_start, lang_end)`` for the language span of the prefix.

    The span is anchored on the first non-zero entry in row 0 of
    ``prefix_att_masks``: that index is taken as the state-token position
    and becomes the exclusive ``lang_end``; ``lang_start`` sits ``num_lang``
    positions before it. Only the first batch row is inspected.

    NOTE(review): this relies on the prefix being laid out as
    ``[image_blocks..., lang, state, padding]`` with the state token as the
    first position whose att-mask is 1 — confirm against SmolVLA's
    ``embed_prefix``.

    Raises:
        RuntimeError: if row 0 has no non-zero entry, or if ``num_lang``
            positions do not fit before the anchor.
    """
    hits = prefix_att_masks[0].nonzero()
    if hits.numel() == 0:
        raise RuntimeError(
            "SmolVLA2: state token not found in prefix att_masks — "
            "can't locate language range."
        )

    # The first non-zero position marks where the state token begins, which
    # is also where the language span ends.
    anchor = int(hits[0, 0].item())
    first_lang = anchor - num_lang
    if first_lang < 0:
        raise RuntimeError(
            f"SmolVLA2: lang range underflows prefix "
            f"(state_start={anchor}, num_lang={num_lang})."
        )
    return first_lang, anchor
|
||||
|
||||
|
||||
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
    """Next-token cross-entropy averaged over the supervised label positions.

    If ``logits`` and ``text_labels`` disagree on sequence length, both are
    cut down to their shared leading span first. The logits at step ``t``
    are then scored against the label at ``t + 1``, ignoring ``-100``
    labels. The summed loss is divided by the number of non-ignored targets
    (clamped to at least 1).
    """
    span = logits.shape[1]
    if text_labels.shape[1] != span:
        # Keep only the leading positions both tensors have.
        span = min(text_labels.shape[1], span)
        logits, text_labels = logits[:, :span], text_labels[:, :span]

    pred = logits[:, :-1, :].contiguous()
    target = text_labels[:, 1:].contiguous().long()
    n_supervised = (target != -100).sum().clamp(min=1)
    summed = F.cross_entropy(
        pred.reshape(-1, pred.shape[-1]),
        target.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )
    return summed / n_supervised
|
||||
|
||||
|
||||
class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"""SmolVLA + re-enabled SmolVLM language head."""
|
||||
|
||||
@@ -78,8 +128,15 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
|
||||
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
|
||||
of the text path SmolVLA freezes) so the text-loss can flow back.
|
||||
"""Re-enable gradients on the text-output path so the LM head
|
||||
loss can flow back.
|
||||
|
||||
SmolVLA's ``set_requires_grad`` freezes three things when
|
||||
``train_expert_only=False``: ``lm_head``,
|
||||
``text_model.model.norm.weight``, and the last 1-2 text-model
|
||||
transformer layers (see ``smolvlm_with_expert.py:167-176``).
|
||||
We must unfreeze *all three* — otherwise gradients still die
|
||||
in the frozen final block and the lm_head learns nothing.
|
||||
"""
|
||||
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
||||
if vlm_with_expert is None:
|
||||
@@ -87,8 +144,26 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
vlm = getattr(vlm_with_expert, "vlm", None)
|
||||
if vlm is None:
|
||||
return
|
||||
|
||||
# Mirror the freeze targets from ``smolvlm_with_expert.set_requires_grad``.
|
||||
num_vlm = getattr(vlm_with_expert, "num_vlm_layers", None)
|
||||
num_expert = getattr(vlm_with_expert, "num_expert_layers", None)
|
||||
last_layers = []
|
||||
if num_vlm is not None:
|
||||
last_layers.append(num_vlm - 1)
|
||||
if (
|
||||
num_expert is not None
|
||||
and num_vlm != num_expert
|
||||
and num_vlm % num_expert == 0
|
||||
):
|
||||
last_layers.append(num_vlm - 2)
|
||||
unfreeze_prefixes = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
*[f"text_model.model.layers.{layer}." for layer in last_layers],
|
||||
]
|
||||
for name, param in vlm.named_parameters():
|
||||
if "lm_head" in name or "text_model.model.norm.weight" in name:
|
||||
if any(k in name for k in unfreeze_prefixes):
|
||||
param.requires_grad = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -216,49 +291,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"states — text-loss path needs them."
|
||||
)
|
||||
|
||||
# Lang token positions inside the prefix. ``embed_prefix`` lays
|
||||
# out the prefix as ``[image_blocks..., lang, state]`` so the
|
||||
# lang range is identifiable from the trailing state size and
|
||||
# the known lang length.
|
||||
num_lang = lang_tokens.shape[1]
|
||||
state_for_dim = state if state.ndim >= 2 else state[:, None]
|
||||
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
|
||||
if num_state < 1:
|
||||
num_state = 1
|
||||
prefix_len = prefix_out.shape[1]
|
||||
lang_end = prefix_len - num_state
|
||||
lang_start = lang_end - num_lang
|
||||
if lang_start < 0 or lang_end > prefix_len:
|
||||
raise RuntimeError(
|
||||
f"SmolVLA2: could not locate lang token range in prefix "
|
||||
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
|
||||
f"num_state={num_state})."
|
||||
)
|
||||
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||
|
||||
if text_labels.shape[1] != num_lang:
|
||||
common = min(text_labels.shape[1], num_lang)
|
||||
logits = logits[:, :common]
|
||||
text_labels = text_labels[:, :common]
|
||||
|
||||
# Standard next-token CE: hidden state at position t predicts
|
||||
# token at position t+1. Shift logits left, labels right by 1.
|
||||
# Without this, the loss is identity-mapped and the LM head
|
||||
# learns nothing useful — see HuggingFace ``LlamaForCausalLM``
|
||||
# for the same convention.
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
valid_labels = shift_labels != -100
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
)
|
||||
return loss / valid_labels.sum().clamp(min=1)
|
||||
return _shifted_ce(logits, text_labels)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fused flow + text loss (single backbone forward)
|
||||
@@ -286,8 +324,6 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
and text paths separately — same trick PI052Policy uses in
|
||||
``_compute_all_losses_fused``.
|
||||
"""
|
||||
from ..smolvla.modeling_smolvla import resize_with_pad # noqa: F401 (kept for parity)
|
||||
|
||||
cfg = self.config
|
||||
if cfg.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
@@ -361,39 +397,11 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
flow_loss = per_sample_flow.mean()
|
||||
|
||||
# ---------------- text loss (lang slice of prefix) ---------------
|
||||
num_lang = lang_tokens.shape[1]
|
||||
state_for_dim = state if state.ndim >= 2 else state[:, None]
|
||||
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
|
||||
if num_state < 1:
|
||||
num_state = 1
|
||||
prefix_len = prefix_out.shape[1]
|
||||
lang_end = prefix_len - num_state
|
||||
lang_start = lang_end - num_lang
|
||||
if lang_start < 0 or lang_end > prefix_len:
|
||||
raise RuntimeError(
|
||||
f"SmolVLA2: fused forward could not locate lang range "
|
||||
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
|
||||
f"num_state={num_state})."
|
||||
)
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
vlm = inner.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden)
|
||||
|
||||
if text_labels.shape[1] != num_lang:
|
||||
common = min(text_labels.shape[1], num_lang)
|
||||
logits = logits[:, :common]
|
||||
text_labels = text_labels[:, :common]
|
||||
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
valid_labels = shift_labels != -100
|
||||
ce = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
)
|
||||
text_loss = ce / valid_labels.sum().clamp(min=1)
|
||||
text_loss = _shifted_ce(logits, text_labels)
|
||||
|
||||
return flow_loss, text_loss, flow_diag
|
||||
|
||||
|
||||
Reference in New Issue
Block a user