feat(smolvla2): dual-head forward — flow loss + lm_head text loss
The third and final commit of PR 3's SmolVLA2 work. Wires the actual
training signal through:
* ``predict_actions[i] = True`` → sample i contributes to flow loss
* ``text_labels[i, t] != -100`` → token t of sample i contributes to
LM-head cross-entropy
Both routing knobs come from ``SmolVLA2ChatTokenizerStep`` (previous
commit on this branch), which builds them from the recipe's
``message_streams`` / ``target_message_indices``. The per-sample
``predict_actions`` mask preserves the Pi0.5 convention from the
plan's Section I.7: "True iff any low_level target exists".
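For a concrete picture, a toy two-sample batch might carry routing
tensors like the following (shapes, token ids, and values are purely
illustrative, not the recipe format)::

    import torch

    # Sample 0 has a low_level (action) target; sample 1 is a text-only
    # sub-recipe (e.g. a VQA answer).
    predict_actions = torch.tensor([True, False])   # (B,) bool

    # (B, T) labels for the LM head: -100 everywhere except the target
    # tokens of sample 1 (the ids here are made up).
    text_labels = torch.tensor([
        [-100, -100, -100, -100],
        [-100, -100,  412, 9734],
    ])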
Implementation:
- ``forward`` reads ``text_labels`` and ``predict_actions`` from the
batch. When neither is present (vanilla SmolVLA usage with no
recipe), delegates to ``SmolVLAPolicy.forward`` so unannotated
datasets keep training as before — full backward compatibility.
- ``flow_loss``: super().forward(reduction="none") returns the
per-sample (B,) flow loss; we mask non-action samples with the
``predict_actions`` bool and renormalize by the count of action
samples. ``flow_loss_weight = 0`` in the config disables this
branch entirely (text-only training).
- ``text_loss``: a prefix-only forward through the VLM (no action
expert / suffix), then slice the lang-token range out of the
resulting hidden states (``embed_prefix`` orders the prefix as
``[image_blocks..., lang, state]``, so the slice is unambiguous),
apply ``vlm.lm_head`` to those hidden states, and take the
cross-entropy against ``text_labels`` (ignore_index=-100).
``text_loss_weight = 0`` disables this branch (reverts to
flow-only behaviour, matching SmolVLA exactly).
- The two losses are summed with the config-supplied weights.
Mixed-stream samples (one batch containing both action targets and
text-only sub-recipes) are handled correctly: each sample contributes
where its labels are valid and is masked elsewhere, as the sketch
below illustrates.
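A minimal sketch of that combination, assuming the per-sample flow
loss and the lang-token logits have already been computed (the helper
name and signature are illustrative only, not the policy's API; the
real implementation is in the diff below)::

    import torch
    import torch.nn.functional as F

    def dual_head_loss(per_sample_flow, predict_actions, logits, text_labels,
                       flow_loss_weight=1.0, text_loss_weight=1.0):
        # per_sample_flow: (B,) flow loss from reduction="none"
        # predict_actions: (B,) bool, True for samples with action targets
        # logits:          (B, T, V) lm_head outputs over the lang tokens
        # text_labels:     (B, T) target ids, -100 on ignored positions
        total = torch.zeros(())

        # Flow head: average only over samples that have action targets.
        if flow_loss_weight > 0 and predict_actions.any():
            mask = predict_actions.to(per_sample_flow.dtype)
            flow = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
            total = total + flow_loss_weight * flow

        # Text head: cross-entropy that skips the -100 positions (the
        # policy only runs this branch when text labels are present).
        if text_loss_weight > 0:
            text = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                text_labels.reshape(-1).long(),
                ignore_index=-100,
            )
            total = total + text_loss_weight * text

        return total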
Limitations / known follow-ups:
- Text loss runs an additional prefix-only forward separate from the
flow path's prefix forward. The forwards could share their prefix
computation; for clarity of this first commit they don't.
Optimization is straightforward when needed.
- Per-sample loss for ``reduction="none"`` is not yet meaningfully
defined for the dual path — we broadcast the scalar to (B,) for
caller compatibility (e.g. RA-BC weighting will need follow-up).
- Inference ``select_action`` is unchanged from SmolVLA today —
it predicts actions only. A separate "generate text"
``select_message`` path is the natural next step for runtime
use of the LM head (memory updates, plan refreshes, VQA answers).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -13,26 +13,25 @@
 # limitations under the License.
 """SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
 
-This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
+Adds:
 
 * an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
-* a forward path that routes to the flow head, the text head, or both,
-  driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
+* a forward path that runs the flow head, the text head, or both,
+  driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
+  produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
+  this branch).
 
-The text-head computation itself is NOT wired up in this scaffold commit
-(the processor doesn't yet produce ``text_labels`` either). This file is
-the structural placeholder that:
+Per-sample routing — within one batch:
 
-1. registers the ``SmolVLA2Policy`` class with the right config name so
-   ``policies/factory.py`` can build it,
-2. unfreezes ``lm_head`` at construction time when the config asks for it
-   (otherwise SmolVLA's ``train_expert_only`` freezes it again on every
-   ``train()`` call),
-3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
-   SmolVLA when no text labels are present — i.e. existing SmolVLA
-   training scripts keep working.
+* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
+  loss (action chunk supervision).
+* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
+  flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
+  contribute to the LM-head cross-entropy.
 
-The next commit on this branch fills in the actual text-loss path.
+Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
+``text_labels`` nor ``predict_actions`` is in the batch — unannotated
+datasets keep working unchanged.
 """
 
 from __future__ import annotations
@@ -40,33 +39,35 @@ from __future__ import annotations
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 
-from ..smolvla.modeling_smolvla import SmolVLAPolicy
+from lerobot.utils.constants import (
+    ACTION,
+    OBS_LANGUAGE_ATTENTION_MASK,
+    OBS_LANGUAGE_TOKENS,
+    OBS_STATE,
+)
+
+from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
 from .configuration_smolvla2 import SmolVLA2Config
 
 
 class SmolVLA2Policy(SmolVLAPolicy):
-    """SmolVLA + re-enabled SmolVLM language head.
-
-    Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
-    perspective. Behaviourally identical to SmolVLA until the text-head
-    code path lands in the next commit on this branch.
-    """
+    """SmolVLA + re-enabled SmolVLM language head."""
 
     config_class = SmolVLA2Config
     name = "smolvla2"
 
     def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
         if not isinstance(config, SmolVLA2Config):
             # Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
             # widening the config type — the new fields fall back to their
             # defaults, which preserves the existing SmolVLA behaviour.
-            config = SmolVLA2Config(**{
-                f.name: getattr(config, f.name)
-                for f in config.__dataclass_fields__.values()
-                if hasattr(config, f.name)
-            })
+            config = SmolVLA2Config(
+                **{
+                    f.name: getattr(config, f.name)
+                    for f in config.__dataclass_fields__.values()
+                    if hasattr(config, f.name)
+                }
+            )
         super().__init__(config, dataset_stats=dataset_stats)
         if config.unfreeze_lm_head and config.text_loss_weight > 0:
             self._unfreeze_lm_head()
@@ -76,13 +77,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
     # ------------------------------------------------------------------
 
     def _unfreeze_lm_head(self) -> None:
-        """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
-        the text path SmolVLA freezes) so the text-loss can flow back.
-
-        SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
-        ``lm_head``, ``text_model.model.norm.weight``, and the last
-        ``text_model.layers.<N-1>`` block. We undo that selectively when
-        text training is enabled.
+        """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
+        of the text path SmolVLA freezes) so the text-loss can flow back.
         """
         vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
         if vlm_with_expert is None:
@@ -91,10 +87,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
         if vlm is None:
             return
         for name, param in vlm.named_parameters():
-            if (
-                "lm_head" in name
-                or "text_model.model.norm.weight" in name
-            ):
+            if "lm_head" in name or "text_model.model.norm.weight" in name:
                 param.requires_grad = True
 
     # ------------------------------------------------------------------
@@ -108,12 +101,144 @@ class SmolVLA2Policy(SmolVLAPolicy):
         time: Tensor | None = None,
         reduction: str = "mean",
     ) -> tuple[Tensor, dict[str, Any]]:
-        """Forward pass with optional text-head loss.
+        """Forward pass with optional dual-head loss.
 
-        SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
-        actual text-loss / dual-head routing lands in the next commit on
-        this branch — it will read ``batch["text_labels"]`` and
-        ``batch["predict_actions"]`` (both produced by the SmolVLA2
-        processor) to decide which head(s) to run.
+        Two routing knobs from the batch (produced by
+        :class:`SmolVLA2ChatTokenizerStep`):
+
+        * ``text_labels`` — per-token labels with ``-100`` for non-target
+          positions. Triggers the text-loss path through ``lm_head``.
+        * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
+          include this sample's action chunk in the flow loss.
+
+        When neither is present, delegate to ``SmolVLAPolicy.forward``.
         """
-        return super().forward(batch, noise=noise, time=time, reduction=reduction)
+        text_labels = batch.get("text_labels")
+        predict_actions_t = batch.get("predict_actions")
+
+        has_text_data = (
+            text_labels is not None
+            and isinstance(text_labels, Tensor)
+            and self.config.text_loss_weight > 0
+        )
+        has_per_sample_routing = (
+            predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
+        )
+
+        if not has_text_data and not has_per_sample_routing:
+            return super().forward(batch, noise=noise, time=time, reduction=reduction)
+
+        loss_dict: dict[str, Any] = {}
+        device = batch[OBS_STATE].device
+        total = torch.zeros((), device=device, dtype=torch.float32)
+
+        # ------------------------------------------------------------
+        # Flow loss path — only when at least one sample wants actions.
+        # ------------------------------------------------------------
+        run_flow = self.config.flow_loss_weight > 0 and (
+            not has_per_sample_routing or bool(predict_actions_t.any().item())
+        )
+        if run_flow and ACTION in batch:
+            per_sample_flow, flow_diag = super().forward(
+                batch, noise=noise, time=time, reduction="none"
+            )
+            # ``per_sample_flow`` has shape (B,) from the SmolVLA
+            # reduction="none" branch.
+            if has_per_sample_routing:
+                mask = predict_actions_t.to(per_sample_flow.dtype)
+                masked = per_sample_flow * mask
+                denom = mask.sum().clamp(min=1.0)
+                flow_loss = masked.sum() / denom
+            else:
+                flow_loss = per_sample_flow.mean()
+            total = total + self.config.flow_loss_weight * flow_loss
+            loss_dict["flow_loss"] = float(flow_loss.detach().item())
+            for k, v in flow_diag.items():
+                loss_dict[f"flow_{k}"] = v
+
+        # ------------------------------------------------------------
+        # Text loss path — prefix-only forward → lm_head → CE.
+        # ------------------------------------------------------------
+        if has_text_data:
+            text_loss = self._compute_text_loss(batch, text_labels)
+            total = total + self.config.text_loss_weight * text_loss
+            loss_dict["text_loss"] = float(text_loss.detach().item())
+
+        loss_dict["loss"] = float(total.detach().item())
+
+        if reduction == "none":
+            # Per-sample loss isn't meaningfully defined for the dual
+            # path; broadcast the scalar to (B,) for caller compat.
+            return total.expand(batch[OBS_STATE].shape[0]), loss_dict
+        return total, loss_dict
+
+    # ------------------------------------------------------------------
+    # Text-loss internals
+    # ------------------------------------------------------------------
+
+    def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
+        """Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens = batch[OBS_LANGUAGE_TOKENS]
+        lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
+
+        prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
+            images, img_masks, lang_tokens, lang_masks, state=state
+        )
+        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+        # Prefix-only forward.
+        out_pair, _ = self.model.vlm_with_expert.forward(
+            attention_mask=prefix_att_2d_masks,
+            position_ids=prefix_position_ids,
+            past_key_values=None,
+            inputs_embeds=[prefix_embs, None],
+            use_cache=False,
+            fill_kv_cache=False,
+        )
+        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+        if prefix_out is None:
+            raise RuntimeError(
+                "SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
+                "states — text-loss path needs them."
+            )
+
+        # Lang token positions inside the prefix. ``embed_prefix`` lays
+        # out the prefix as ``[image_blocks..., lang, state]`` so the
+        # lang range is identifiable from the trailing state size and
+        # the known lang length.
+        num_lang = lang_tokens.shape[1]
+        state_for_dim = state if state.ndim >= 2 else state[:, None]
+        num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
+        if num_state < 1:
+            num_state = 1
+        prefix_len = prefix_out.shape[1]
+        lang_end = prefix_len - num_state
+        lang_start = lang_end - num_lang
+        if lang_start < 0 or lang_end > prefix_len:
+            raise RuntimeError(
+                f"SmolVLA2: could not locate lang token range in prefix "
+                f"(prefix_len={prefix_len}, num_lang={num_lang}, "
+                f"num_state={num_state})."
+            )
+
+        lang_hidden = prefix_out[:, lang_start:lang_end]
+        vlm = self.model.vlm_with_expert.vlm
+        logits = vlm.lm_head(lang_hidden)  # (B, num_lang, vocab)
+
+        if text_labels.shape[1] != num_lang:
+            common = min(text_labels.shape[1], num_lang)
+            logits = logits[:, :common]
+            text_labels = text_labels[:, :common]
+
+        loss = F.cross_entropy(
+            logits.reshape(-1, logits.shape[-1]),
+            text_labels.reshape(-1).long(),
+            ignore_index=-100,
+        )
+        return loss