feat(smolvla2): dual-head forward — flow loss + lm_head text loss

The third and final commit of PR 3's SmolVLA2 work. Wires the actual
training signal through:

* ``predict_actions[i] = True``  → sample i contributes to flow loss
* ``text_labels[i, t] != -100``  → token t of sample i contributes to
                                    LM-head cross-entropy

Both routing knobs come from ``SmolVLA2ChatTokenizerStep`` (previous
commit on this branch), which builds them from the recipe's
``message_streams`` / ``target_message_indices``. The per-sample
``predict_actions`` mask preserves the Pi0.5 convention from the
plan's Section I.7: "True iff any low_level target exists".
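
For concreteness, a two-sample batch mixing one action sample with one
text-only sample might carry routing tensors like the following (the key
names are the ones this commit reads; token ids and shapes are purely
illustrative, not output of the real processor):

    import torch

    # Hypothetical 2-sample batch: sample 0 has low_level action targets,
    # sample 1 is a text-only sub-recipe (e.g. a VQA answer).
    routing = {
        # Per-sample: True iff any low_level target exists (Pi0.5 convention).
        "predict_actions": torch.tensor([True, False]),
        # Per-token: -100 wherever the LM head should not be supervised.
        "text_labels": torch.tensor(
            [
                [-100, -100, -100, -100],  # action sample: no supervised text
                [-100, -100, 1037, 2158],  # text sample: last two tokens supervised
            ]
        ),
    }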

Implementation:

- ``forward`` reads ``text_labels`` and ``predict_actions`` from the
  batch. When neither is present (vanilla SmolVLA usage with no
  recipe), delegates to ``SmolVLAPolicy.forward`` so unannotated
  datasets keep training as before — full backward compatibility.
- ``flow_loss``: ``super().forward(reduction="none")`` returns the
  per-sample (B,) flow loss; we mask non-action samples with the
  ``predict_actions`` bool and renormalize by the count of action
  samples. ``flow_loss_weight = 0`` in the config disables this
  branch entirely (text-only training). See the sketch after this
  list.
- ``text_loss``: a prefix-only forward through the VLM (no action
  expert / suffix), slicing the lang-token range out of the
  resulting hidden states (``embed_prefix`` orders the prefix as
  ``[image_blocks..., lang, state]`` so the slice is unambiguous).
  Apply ``vlm.lm_head`` to those hidden states, then cross-entropy
  with ``text_labels`` (``ignore_index=-100``). ``text_loss_weight = 0``
  disables this branch (reverts to flow-only behaviour, matching
  SmolVLA exactly); sketched below, after the mixed-stream note.
- The two losses are summed with the config-supplied weights.
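
A minimal sketch of that masking / renormalization and the weighted sum,
on standalone tensors rather than the policy's actual forward (``flow_w``
and ``text_w`` stand in for the config's ``flow_loss_weight`` and
``text_loss_weight``):

    import torch

    def weighted_dual_loss(per_sample_flow, predict_actions, text_loss, flow_w=1.0, text_w=1.0):
        """Sketch: mask non-action samples out of the flow loss, renormalize by
        the count of action samples, then add the text loss with config weights."""
        total = torch.zeros((), dtype=per_sample_flow.dtype)
        if flow_w > 0 and bool(predict_actions.any()):
            mask = predict_actions.to(per_sample_flow.dtype)  # (B,) of 0./1.
            total = total + flow_w * (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
        if text_w > 0:
            total = total + text_w * text_loss
        return total

    # e.g. per_sample_flow = torch.tensor([0.8, 0.3]), predict_actions = torch.tensor([True, False])
    # -> the flow term uses only sample 0 and divides by 1, not by the batch size of 2.

Renormalizing by the mask count keeps the flow term on the same scale
whether a batch holds one action sample or a full batch of them.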

Mixed-stream samples (one batch containing both action targets and
text-only sub-recipes) are handled correctly: each sample contributes
where its labels are valid and is masked elsewhere.
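
On the text side, "valid" is enforced by ``ignore_index=-100``; the
slice-and-CE step from the ``text_loss`` bullet above reduces to roughly
the sketch below (the single trailing state slot, hidden size, and vocab
size are illustrative assumptions, not the real SmolVLM dimensions):

    import torch
    import torch.nn.functional as F

    def sketch_text_loss(prefix_hidden, lm_head, text_labels, num_lang, num_state=1):
        """Sketch: slice the lang range out of a prefix laid out as
        [image_blocks..., lang, state], apply the LM head, cross-entropy."""
        prefix_len = prefix_hidden.shape[1]
        lang_end = prefix_len - num_state                    # state token(s) sit at the tail
        lang_start = lang_end - num_lang                     # lang tokens sit just before them
        lang_hidden = prefix_hidden[:, lang_start:lang_end]  # (B, num_lang, H)
        logits = lm_head(lang_hidden)                        # (B, num_lang, vocab)
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            text_labels.reshape(-1),
            ignore_index=-100,                               # masked positions contribute nothing
        )

    # Illustrative shapes: 128 image tokens + 48 lang tokens + 1 state token per sample.
    hidden = torch.randn(2, 128 + 48 + 1, 960)
    head = torch.nn.Linear(960, 49_152, bias=False)          # stand-in for vlm.lm_head
    labels = torch.full((2, 48), -100)
    labels[1, -2:] = torch.tensor([1037, 2158])              # only sample 1 has text targets
    loss = sketch_text_loss(hidden, head, labels, num_lang=48)

The real path gets the prefix hidden states from a prefix-only
``vlm_with_expert`` forward and the state slot count from the prepared
state tensor, as the diff below shows.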

Limitations / known follow-ups:

- Text loss runs an additional prefix-only forward separate from the
  flow path's prefix forward. The two could share the prefix
  computation, but for clarity this first commit keeps them separate;
  the optimization is straightforward when needed.
- Per-sample loss for ``reduction="none"`` is not yet meaningfully
  defined for the dual path — we broadcast the scalar to (B,) for
  caller compatibility (e.g. RA-BC weighting will need follow-up).
- Inference ``select_action`` is unchanged from SmolVLA today —
  it predicts actions only. A separate "generate text"
  ``select_message`` path is the natural next step for runtime
  use of the LM head (memory updates, plan refreshes, VQA answers).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author: Pepijn
Date:   2026-04-30 19:54:57 +02:00
parent 37b1eb218a
commit af6d8ebd5b
@@ -13,26 +13,25 @@
 # limitations under the License.
 
 """SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
 
-This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
+Adds:
 
 * an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
-* a forward path that routes to the flow head, the text head, or both,
-  driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
+* a forward path that runs the flow head, the text head, or both,
+  driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
+  produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
+  this branch).
 
-The text-head computation itself is NOT wired up in this scaffold commit
-(the processor doesn't yet produce ``text_labels`` either). This file is
-the structural placeholder that:
-
-1. registers the ``SmolVLA2Policy`` class with the right config name so
-   ``policies/factory.py`` can build it,
-2. unfreezes ``lm_head`` at construction time when the config asks for it
-   (otherwise SmolVLA's ``train_expert_only`` freezes it again on every
-   ``train()`` call),
-3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
-   SmolVLA when no text labels are present — i.e. existing SmolVLA
-   training scripts keep working.
+Per-sample routing — within one batch:
 
-The next commit on this branch fills in the actual text-loss path.
+* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
+  loss (action chunk supervision).
+* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
+  flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
+  contribute to the LM-head cross-entropy.
+
+Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
+``text_labels`` nor ``predict_actions`` is in the batch — unannotated
+datasets keep working unchanged.
 """
 
 from __future__ import annotations
@@ -40,33 +39,35 @@ from __future__ import annotations
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 
-from ..smolvla.modeling_smolvla import SmolVLAPolicy
+from lerobot.utils.constants import (
+    ACTION,
+    OBS_LANGUAGE_ATTENTION_MASK,
+    OBS_LANGUAGE_TOKENS,
+    OBS_STATE,
+)
+
+from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
 from .configuration_smolvla2 import SmolVLA2Config
 
 
 class SmolVLA2Policy(SmolVLAPolicy):
-    """SmolVLA + re-enabled SmolVLM language head.
-
-    Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
-    perspective. Behaviourally identical to SmolVLA until the text-head
-    code path lands in the next commit on this branch.
-    """
+    """SmolVLA + re-enabled SmolVLM language head."""
 
     config_class = SmolVLA2Config
     name = "smolvla2"
 
     def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
         if not isinstance(config, SmolVLA2Config):
             # Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
             # widening the config type — the new fields fall back to their
             # defaults, which preserves the existing SmolVLA behaviour.
-            config = SmolVLA2Config(**{
-                f.name: getattr(config, f.name)
-                for f in config.__dataclass_fields__.values()
-                if hasattr(config, f.name)
-            })
+            config = SmolVLA2Config(
+                **{
+                    f.name: getattr(config, f.name)
+                    for f in config.__dataclass_fields__.values()
+                    if hasattr(config, f.name)
+                }
+            )
         super().__init__(config, dataset_stats=dataset_stats)
 
         if config.unfreeze_lm_head and config.text_loss_weight > 0:
             self._unfreeze_lm_head()
@@ -76,13 +77,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
 
     # ------------------------------------------------------------------
    def _unfreeze_lm_head(self) -> None:
-        """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
-        the text path SmolVLA freezes) so the text-loss can flow back.
-
-        SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
-        ``lm_head``, ``text_model.model.norm.weight``, and the last
-        ``text_model.layers.<N-1>`` block. We undo that selectively when
-        text training is enabled.
+        """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
+        of the text path SmolVLA freezes) so the text-loss can flow back.
         """
         vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
         if vlm_with_expert is None:
@@ -91,10 +87,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
         if vlm is None:
             return
         for name, param in vlm.named_parameters():
-            if (
-                "lm_head" in name
-                or "text_model.model.norm.weight" in name
-            ):
+            if "lm_head" in name or "text_model.model.norm.weight" in name:
                 param.requires_grad = True
 
     # ------------------------------------------------------------------
@@ -108,12 +101,144 @@ class SmolVLA2Policy(SmolVLAPolicy):
         time: Tensor | None = None,
         reduction: str = "mean",
     ) -> tuple[Tensor, dict[str, Any]]:
-        """Forward pass with optional text-head loss.
-
-        SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
-        actual text-loss / dual-head routing lands in the next commit on
-        this branch — it will read ``batch["text_labels"]`` and
-        ``batch["predict_actions"]`` (both produced by the SmolVLA2
-        processor) to decide which head(s) to run.
+        """Forward pass with optional dual-head loss.
+
+        Two routing knobs from the batch (produced by
+        :class:`SmolVLA2ChatTokenizerStep`):
+
+        * ``text_labels`` — per-token labels with ``-100`` for non-target
+          positions. Triggers the text-loss path through ``lm_head``.
+        * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
+          include this sample's action chunk in the flow loss.
+
+        When neither is present, delegate to ``SmolVLAPolicy.forward``.
         """
-        return super().forward(batch, noise=noise, time=time, reduction=reduction)
+        text_labels = batch.get("text_labels")
+        predict_actions_t = batch.get("predict_actions")
+
+        has_text_data = (
+            text_labels is not None
+            and isinstance(text_labels, Tensor)
+            and self.config.text_loss_weight > 0
+        )
+        has_per_sample_routing = (
+            predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
+        )
+
+        if not has_text_data and not has_per_sample_routing:
+            return super().forward(batch, noise=noise, time=time, reduction=reduction)
+
+        loss_dict: dict[str, Any] = {}
+        device = batch[OBS_STATE].device
+        total = torch.zeros((), device=device, dtype=torch.float32)
+
+        # ------------------------------------------------------------
+        # Flow loss path — only when at least one sample wants actions.
+        # ------------------------------------------------------------
+        run_flow = self.config.flow_loss_weight > 0 and (
+            not has_per_sample_routing or bool(predict_actions_t.any().item())
+        )
+        if run_flow and ACTION in batch:
+            per_sample_flow, flow_diag = super().forward(
+                batch, noise=noise, time=time, reduction="none"
+            )
+            # ``per_sample_flow`` has shape (B,) from the SmolVLA
+            # reduction="none" branch.
+            if has_per_sample_routing:
+                mask = predict_actions_t.to(per_sample_flow.dtype)
+                masked = per_sample_flow * mask
+                denom = mask.sum().clamp(min=1.0)
+                flow_loss = masked.sum() / denom
+            else:
+                flow_loss = per_sample_flow.mean()
+            total = total + self.config.flow_loss_weight * flow_loss
+            loss_dict["flow_loss"] = float(flow_loss.detach().item())
+            for k, v in flow_diag.items():
+                loss_dict[f"flow_{k}"] = v
+
+        # ------------------------------------------------------------
+        # Text loss path — prefix-only forward → lm_head → CE.
+        # ------------------------------------------------------------
+        if has_text_data:
+            text_loss = self._compute_text_loss(batch, text_labels)
+            total = total + self.config.text_loss_weight * text_loss
+            loss_dict["text_loss"] = float(text_loss.detach().item())
+
+        loss_dict["loss"] = float(total.detach().item())
+
+        if reduction == "none":
+            # Per-sample loss isn't meaningfully defined for the dual
+            # path; broadcast the scalar to (B,) for caller compat.
+            return total.expand(batch[OBS_STATE].shape[0]), loss_dict
+        return total, loss_dict
+
+    # ------------------------------------------------------------------
+    # Text-loss internals
+    # ------------------------------------------------------------------
+    def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
+        """Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens = batch[OBS_LANGUAGE_TOKENS]
+        lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
+
+        prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
+            images, img_masks, lang_tokens, lang_masks, state=state
+        )
+        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+        # Prefix-only forward.
+        out_pair, _ = self.model.vlm_with_expert.forward(
+            attention_mask=prefix_att_2d_masks,
+            position_ids=prefix_position_ids,
+            past_key_values=None,
+            inputs_embeds=[prefix_embs, None],
+            use_cache=False,
+            fill_kv_cache=False,
+        )
+        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+        if prefix_out is None:
+            raise RuntimeError(
+                "SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
+                "states — text-loss path needs them."
+            )
+
+        # Lang token positions inside the prefix. ``embed_prefix`` lays
+        # out the prefix as ``[image_blocks..., lang, state]`` so the
+        # lang range is identifiable from the trailing state size and
+        # the known lang length.
+        num_lang = lang_tokens.shape[1]
+        state_for_dim = state if state.ndim >= 2 else state[:, None]
+        num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
+        if num_state < 1:
+            num_state = 1
+        prefix_len = prefix_out.shape[1]
+        lang_end = prefix_len - num_state
+        lang_start = lang_end - num_lang
+        if lang_start < 0 or lang_end > prefix_len:
+            raise RuntimeError(
+                f"SmolVLA2: could not locate lang token range in prefix "
+                f"(prefix_len={prefix_len}, num_lang={num_lang}, "
+                f"num_state={num_state})."
+            )
+        lang_hidden = prefix_out[:, lang_start:lang_end]
+
+        vlm = self.model.vlm_with_expert.vlm
+        logits = vlm.lm_head(lang_hidden)  # (B, num_lang, vocab)
+
+        if text_labels.shape[1] != num_lang:
+            common = min(text_labels.shape[1], num_lang)
+            logits = logits[:, :common]
+            text_labels = text_labels[:, :common]
+
+        loss = F.cross_entropy(
+            logits.reshape(-1, logits.shape[-1]),
+            text_labels.reshape(-1).long(),
+            ignore_index=-100,
+        )
+        return loss