fix(smolvla2,pi052): training-correctness audit fixes

CRITICAL (smolvla2) — text-CE was applied to the wrong prefix slice.
``num_state`` was being read from ``state.shape[1]`` (the raw
max_state_dim, ~14-32) instead of the *number of state tokens*
(always 1). Compounded by the trailing-padding issue (state is
not at the end of the padded prefix when ``seq_len < prefix_length``),
the lang slice was landing on image / padding hidden states.

New ``_locate_lang_range`` finds the state position via
``att_masks.nonzero()`` (the only ``1`` in the mask), making the
slice robust to both bugs. Used by ``_compute_text_loss`` and
``_compute_fused_loss``.

LIKELY-BUG (smolvla2) — ``_unfreeze_lm_head`` only re-enabled
``lm_head`` and ``text_model.model.norm.weight``. SmolVLA's parent
ALSO freezes the last 1-2 transformer layers, so text-loss
gradients died in a frozen final block. Now mirrors the parent's
freeze targets and unfreezes the matching ``layers.{N-1}`` (and
``N-2`` when num_vlm % num_expert == 0).

CRITICAL (pi052) — flow and FAST CE were not per-sample masked
under per-sample-routing. Text-only recipe samples
(``plan_generation``, ``ask_vqa_*``) contributed to flow/FAST
loss with prompts that deliberately omit the subtask, corrupting
the signal. Threaded ``predict_actions_t`` through both
``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``;
flow uses ``(per_sample * mask).sum() / mask.sum()``, FAST uses
``shift_valid & sample_mask`` before ``masked_fill(-100)``.

OTHER
* PI052Policy.forward now falls through to PI05Policy.forward on
  unannotated batches (no text_labels, no predict_actions, no FAST).
* fit_fast_tokenizer cache key now includes ``chunk_size`` — changing
  the chunk size no longer silently loads a wrongly-fit tokenizer.
* Removed dead ``_compute_text_loss`` / ``_compute_fast_action_loss``
  in pi052 (superseded by the fused helpers).
* Fixed stale "no-op stub" docstring on ``knowledge_insulation`` —
  it's been fully wired since the per-layer KI forward port.
* Stripped unused ``copy`` / ``resize_with_pad`` imports.
* Extracted ``_shifted_ce`` / ``_mask_per_sample`` / ``_fast_ce``
  helpers shared between fused and prefix-only paths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 14:08:06 +02:00
parent e3ad1c59fc
commit 129aa207e3
6 changed files with 179 additions and 269 deletions
@@ -152,19 +152,15 @@ class PI052Config(PI05Config):
# Knowledge insulation — paper §III.B --------------------------------
# When enabled, gradients from the action expert's flow loss are
# *blocked* from flowing back into the VLM's K/V projections. This
# blocked from flowing back into the VLM's K/V projections. This
# prevents the action loss from over-fitting the language backbone
# to robot-specific features. Implementation requires a custom
# per-layer attention forward that uses ``.detach()`` on VLM K/V
# when computing attention for action queries — see
# ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
# on branch ``feat/add-pi05`` for the reference implementation.
# to robot-specific features. Implemented in ``modeling_pi052`` as
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
# that splits queries into VLM and action halves and ``.detach()``-s
# the VLM K/V tensors used in the action-half's attention.
knowledge_insulation: bool = False
"""If True, route every transformer layer through the KI
attention path (blocks action→VLM gradient flow on K/V).
Currently a no-op stub in this branch — PR follow-up required to
port the full custom layer from pi05_full. The flag is exposed
here so SLURM commands can be written against the final shape."""
attention path that blocks action→VLM gradient flow on K/V."""
def __post_init__(self) -> None:
super().__post_init__()
@@ -47,12 +47,17 @@ import numpy as np
logger = logging.getLogger(__name__)
def _dataset_signature(dataset_repo_id: str, base_tokenizer_name: str, n_samples: int) -> str:
def _dataset_signature(
dataset_repo_id: str,
base_tokenizer_name: str,
n_samples: int,
chunk_size: int,
) -> str:
"""Deterministic short hash for naming the cache directory.
Keys on (dataset, base tokenizer, sample count) so re-fitting on a
new dataset or a different base doesn't clobber the prior cache,
and so changing n_samples re-runs the fit.
Keys on (dataset, base tokenizer, sample count, chunk size) so any
of those changing re-runs the fit. ``chunk_size`` matters because
the tokenizer is fit on chunks of that length.
"""
h = hashlib.sha256()
h.update(dataset_repo_id.encode("utf-8"))
@@ -60,6 +65,8 @@ def _dataset_signature(dataset_repo_id: str, base_tokenizer_name: str, n_samples
h.update(base_tokenizer_name.encode("utf-8"))
h.update(b"\0")
h.update(str(n_samples).encode("utf-8"))
h.update(b"\0")
h.update(str(chunk_size).encode("utf-8"))
return h.hexdigest()[:16]
@@ -102,7 +109,7 @@ def fit_fast_tokenizer(
FileNotFoundError: If the dataset can't be loaded.
"""
cache_dir = Path(cache_dir)
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples)
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists():
+79 -178
View File
@@ -55,6 +55,65 @@ from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# Loss helpers (shared between fused and prefix-only paths)
# ----------------------------------------------------------------------
def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Tensor:
"""Mean over samples where ``predict_actions_t`` is True, else over all."""
if predict_actions_t is None:
return per_sample.mean()
mask = predict_actions_t.to(per_sample.dtype)
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.
Mean over non-ignored positions across the batch. Returns 0 cleanly
when no positions are supervised (clamp(min=1) on the denominator).
"""
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous().long()
valid = shift_labels != -100
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
return loss / valid.sum().clamp(min=1)
def _fast_ce(
fast_logits: Tensor,
action_tokens: Tensor,
action_mask: Tensor,
predict_actions_t: Tensor | None,
) -> Tensor:
"""FAST-CE with both token-pad masking and per-sample action gating.
``action_mask`` is the FAST tokenizer's padding mask; samples whose
recipe sets ``predict_actions=False`` (e.g. plan_generation,
ask_vqa_*) get *all* their FAST positions masked out via the
per-sample gate.
"""
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
if predict_actions_t is not None:
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
shift_valid = shift_valid & sample_mask
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
return F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
reduction="sum",
) / shift_valid.sum().clamp(min=1)
# ----------------------------------------------------------------------
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
# ----------------------------------------------------------------------
@@ -307,6 +366,16 @@ class PI052Policy(PI05Policy):
text_labels = batch.get("text_labels")
predict_actions_t = batch.get("predict_actions")
# Unannotated datasets: no recipe applied → no text_labels and
# no FAST / predict_actions routing. Defer to PI05Policy so the
# plain flow-only training surface keeps working unchanged.
if (
text_labels is None
and predict_actions_t is None
and not getattr(self.config, "enable_fast_action_loss", False)
):
return super().forward(batch, reduction=reduction)
run_flow = (
self.config.flow_loss_weight > 0
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
@@ -350,6 +419,7 @@ class PI052Policy(PI05Policy):
text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None,
predict_actions_t=predict_actions_t,
)
loss_dict["flow_loss"] = float(flow_loss.detach().item())
total = self.config.flow_loss_weight * flow_loss
@@ -365,6 +435,7 @@ class PI052Policy(PI05Policy):
text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None,
predict_actions_t=predict_actions_t,
)
if text_loss is not None:
loss_dict["text_loss"] = float(text_loss.detach().item())
@@ -399,6 +470,7 @@ class PI052Policy(PI05Policy):
text_labels: Tensor | None,
action_tokens: Tensor | None,
action_mask: Tensor | None,
predict_actions_t: Tensor | None = None,
) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Full fusion: flow + text + FAST in ONE backbone forward.
@@ -510,7 +582,8 @@ class PI052Policy(PI05Policy):
# internally to max_action_dim).
original_action_dim = self.config.output_features[ACTION].shape[0]
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
flow_loss = flow_per_dim.mean()
per_sample_flow = flow_per_dim.mean(dim=(1, 2))
flow_loss = _mask_per_sample(per_sample_flow, predict_actions_t)
# ---- text + FAST CE from prefix_out ------------------------
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
@@ -523,27 +596,13 @@ class PI052Policy(PI05Policy):
else:
text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
shift_logits = text_logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
text_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
)
text_loss = _shifted_ce(text_logits, text_labels)
fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None:
fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
fast_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
)
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
return flow_loss, text_loss, fast_loss
@@ -553,6 +612,7 @@ class PI052Policy(PI05Policy):
text_labels: Tensor | None,
action_tokens: Tensor | None,
action_mask: Tensor | None,
predict_actions_t: Tensor | None = None,
) -> tuple[Tensor | None, Tensor | None]:
"""Single prefix forward → text CE + FAST CE.
@@ -619,180 +679,21 @@ class PI052Policy(PI05Policy):
lang_len = text_labels.shape[1]
# embed_prefix lays out as [images, language]; with FAST
# appended the full sequence is [images, language, FAST].
# Language hidden states are at positions
# ``[-(fast_len + lang_len) : -fast_len]`` when FAST is
# present, or ``[-lang_len:]`` otherwise.
if fast_len > 0:
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
else:
text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
shift_logits = text_logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
text_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
)
text_loss = _shifted_ce(text_logits, text_labels)
fast_loss: Tensor | None = None
if action_tokens is not None and action_mask is not None and fast_len > 0:
fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
fast_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
)
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
return text_loss, fast_loss
def _compute_fast_action_loss(
self,
batch: dict[str, Tensor],
action_tokens: Tensor,
action_mask: Tensor,
) -> Tensor:
"""Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.
Mirrors the paper's §III.C action-token loss: append the FAST
tokens to the language prefix, forward through the backbone,
slice the per-token logits at the action positions, and
compute CE against the FAST targets (shifted for next-token
prediction). Token-mask gates the loss to valid positions.
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
# Re-embed prefix to get [images, language] embeddings, then
# append the FAST action token embeddings as additional
# prefix tokens. PaliGemma's embed_language_tokens is shared
# between text and FAST tokens so we can re-use it directly.
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
emb_dim = prefix_embs.shape[-1]
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
fast_emb = fast_emb * math.sqrt(emb_dim)
# Concat onto the prefix.
# pad masks: language uses ``lang_masks``; FAST uses
# ``action_mask`` (True at valid token positions).
# att masks: prefix is 0 (bidirectional block); FAST is 1
# (each token starts its own causal block). Per
# ``make_att_2d_masks``'s mask_ar convention this
# yields prefix-LM attention: FAST tokens attend
# bidirectionally to images+language and causally
# among themselves, while prefix tokens *cannot*
# see FAST tokens. Matches pi05_full §III.C.
fast_len = action_tokens.shape[1]
device = prefix_embs.device
ones_att = torch.ones((action_tokens.shape[0], fast_len), dtype=torch.bool, device=device)
full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
full_att = torch.cat([prefix_att, ones_att], dim=1)
att_2d = make_att_2d_masks(full_pad, full_att)
position_ids = torch.cumsum(full_pad, dim=1) - 1
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[full_embs, None],
use_cache=False,
)
if vlm_out is None:
raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")
# Slice the last ``fast_len`` positions — those correspond to
# the FAST tokens we just appended.
fast_hidden = vlm_out[:, -fast_len:, :]
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
# Shift for next-token prediction. Replace targets at padded
# positions with -100 so ``ignore_index`` in cross_entropy
# cleanly drops them rather than relying on a post-hoc
# multiply-by-mask (which still computes the CE numerator at
# invalid positions and could crash if a padded target id
# falls outside the vocab).
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
# Mean over valid positions via ``ignore_index``.
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
)
return loss
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
"""Cross-entropy on PaliGemma's LM head over the supervised span.
Embeds images + language, runs the VLM-only forward (the
action expert is skipped via ``inputs_embeds=[..., None]``),
slices the hidden states to the *language* portion so they
align with ``text_labels`` (which covers only the language
tokens, not the image patch tokens), then computes shifted
next-token CE with ``-100`` ignoring padding/non-target
positions.
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
tokens = batch[OBS_LANGUAGE_TOKENS]
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, tokens, masks
)
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
)
if vlm_out is None:
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
# Slice the hidden states to the language portion. embed_prefix
# concatenates [images, language] in that order, so the trailing
# ``text_labels.shape[1]`` positions are the language tokens.
# Without this slice, applying lm_head to the full vlm_out and
# shifting against text_labels[..., 1:] produces a shape
# mismatch in cross_entropy.
lang_len = text_labels.shape[1]
text_hidden = vlm_out[:, -lang_len:, :]
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
logits = lm_head(text_hidden.to(lm_head.weight.dtype))
# Shift for next-token prediction: predict token[i+1] from
# hidden[i] within the language span.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = text_labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
return loss
# ------------------------------------------------------------------
# select_message — AR text generation at inference
# ------------------------------------------------------------------
@@ -36,7 +36,6 @@ Outputs:
from __future__ import annotations
import copy
import logging
from dataclasses import dataclass
from typing import Any
@@ -38,7 +38,6 @@ matching the chat-template-stripped text order).
from __future__ import annotations
import copy
import logging
from dataclasses import dataclass
from typing import Any
@@ -54,6 +54,56 @@ from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
from .configuration_smolvla2 import SmolVLA2Config
def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, int]:
"""Find ``[lang_start, lang_end)`` inside the SmolVLA prefix.
``embed_prefix`` lays out the prefix as
``[image_blocks..., lang, state, padding]`` with the att-mask
convention ``[0]*image, [0]*lang, [1]*state, [0]*padding`` (see
``modeling_smolvla.SmolVLAModel.embed_prefix``). State is exactly
one token, and it's the *only* position with ``att_mask == 1``,
so we use the first ``1`` to anchor lang_end. Computing it this
way is robust to (a) state being projected to one embedding token
regardless of its raw feature dim, and (b) the trailing padding
added when ``seq_len < prefix_length``.
"""
row = prefix_att_masks[0]
ones = row.nonzero(as_tuple=False)
if ones.numel() == 0:
raise RuntimeError(
"SmolVLA2: state token not found in prefix att_masks — "
"can't locate language range."
)
state_start = int(ones[0, 0].item())
lang_end = state_start
lang_start = lang_end - num_lang
if lang_start < 0:
raise RuntimeError(
f"SmolVLA2: lang range underflows prefix "
f"(state_start={state_start}, num_lang={num_lang})."
)
return lang_start, lang_end
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
num_lang = logits.shape[1]
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid = shift_labels != -100
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
return loss / valid.sum().clamp(min=1)
class SmolVLA2Policy(SmolVLAPolicy):
"""SmolVLA + re-enabled SmolVLM language head."""
@@ -78,8 +128,15 @@ class SmolVLA2Policy(SmolVLAPolicy):
# ------------------------------------------------------------------
def _unfreeze_lm_head(self) -> None:
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
of the text path SmolVLA freezes) so the text-loss can flow back.
"""Re-enable gradients on the text-output path so the LM head
loss can flow back.
SmolVLA's ``set_requires_grad`` freezes three things when
``train_expert_only=False``: ``lm_head``,
``text_model.model.norm.weight``, and the last 1-2 text-model
transformer layers (see ``smolvlm_with_expert.py:167-176``).
We must unfreeze *all three* otherwise gradients still die
in the frozen final block and the lm_head learns nothing.
"""
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
if vlm_with_expert is None:
@@ -87,8 +144,26 @@ class SmolVLA2Policy(SmolVLAPolicy):
vlm = getattr(vlm_with_expert, "vlm", None)
if vlm is None:
return
# Mirror the freeze targets from ``smolvlm_with_expert.set_requires_grad``.
num_vlm = getattr(vlm_with_expert, "num_vlm_layers", None)
num_expert = getattr(vlm_with_expert, "num_expert_layers", None)
last_layers = []
if num_vlm is not None:
last_layers.append(num_vlm - 1)
if (
num_expert is not None
and num_vlm != num_expert
and num_vlm % num_expert == 0
):
last_layers.append(num_vlm - 2)
unfreeze_prefixes = [
"lm_head",
"text_model.model.norm.weight",
*[f"text_model.model.layers.{layer}." for layer in last_layers],
]
for name, param in vlm.named_parameters():
if "lm_head" in name or "text_model.model.norm.weight" in name:
if any(k in name for k in unfreeze_prefixes):
param.requires_grad = True
# ------------------------------------------------------------------
@@ -216,49 +291,12 @@ class SmolVLA2Policy(SmolVLAPolicy):
"states — text-loss path needs them."
)
# Lang token positions inside the prefix. ``embed_prefix`` lays
# out the prefix as ``[image_blocks..., lang, state]`` so the
# lang range is identifiable from the trailing state size and
# the known lang length.
num_lang = lang_tokens.shape[1]
state_for_dim = state if state.ndim >= 2 else state[:, None]
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
if num_state < 1:
num_state = 1
prefix_len = prefix_out.shape[1]
lang_end = prefix_len - num_state
lang_start = lang_end - num_lang
if lang_start < 0 or lang_end > prefix_len:
raise RuntimeError(
f"SmolVLA2: could not locate lang token range in prefix "
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
f"num_state={num_state})."
)
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
vlm = self.model.vlm_with_expert.vlm
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
# Standard next-token CE: hidden state at position t predicts
# token at position t+1. Shift logits left, labels right by 1.
# Without this, the loss is identity-mapped and the LM head
# learns nothing useful — see HuggingFace ``LlamaForCausalLM``
# for the same convention.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid_labels = shift_labels != -100
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
return loss / valid_labels.sum().clamp(min=1)
return _shifted_ce(logits, text_labels)
# ------------------------------------------------------------------
# Fused flow + text loss (single backbone forward)
@@ -286,8 +324,6 @@ class SmolVLA2Policy(SmolVLAPolicy):
and text paths separately same trick PI052Policy uses in
``_compute_all_losses_fused``.
"""
from ..smolvla.modeling_smolvla import resize_with_pad # noqa: F401 (kept for parity)
cfg = self.config
if cfg.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
@@ -361,39 +397,11 @@ class SmolVLA2Policy(SmolVLAPolicy):
flow_loss = per_sample_flow.mean()
# ---------------- text loss (lang slice of prefix) ---------------
num_lang = lang_tokens.shape[1]
state_for_dim = state if state.ndim >= 2 else state[:, None]
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
if num_state < 1:
num_state = 1
prefix_len = prefix_out.shape[1]
lang_end = prefix_len - num_state
lang_start = lang_end - num_lang
if lang_start < 0 or lang_end > prefix_len:
raise RuntimeError(
f"SmolVLA2: fused forward could not locate lang range "
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
f"num_state={num_state})."
)
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
vlm = inner.vlm_with_expert.vlm
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
logits = vlm.lm_head(lang_hidden)
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid_labels = shift_labels != -100
ce = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
text_loss = ce / valid_labels.sum().clamp(min=1)
text_loss = _shifted_ce(logits, text_labels)
return flow_loss, text_loss, flow_diag