fix(smolvla2): drive select_message through SmolVLM.generate

The hand-rolled AR loop in ``select_message`` was fighting the
underlying ``vlm_with_expert.forward`` design, which assumes the
"prefix-once + suffix-always-via-expert" pattern that ``denoise_step``
uses for action chunks. Cross-attn layers (every other layer with
``attention_mode='cross_attn'`` + ``self_attn_every_n_layers=2``)
hard-require an expert input on every call: passing
``inputs_embeds=[current_embs, None]`` crashed at
``expert_layer.input_layernorm(None)`` with ``'NoneType' object has
no attribute 'dtype'``. Earlier KV-cache attempts ran into the
matching ``[15, 139] vs [15, 1]`` shape mismatch because the cache
gets *overwritten*, not appended, on each ``fill_kv_cache=True`` call
— there's just no AR-text-decode mode in this forward.

Stop fighting it: drive AR text generation through the underlying
SmolVLM via ``vlm.generate(input_ids=..., attention_mask=...,
pixel_values=...)``. KV caching, sampling/greedy, EOS handling all
come from HF's standard implementation. Trade-off: ``state`` drops
out of the prefix at inference (no slot for it on the standard
SmolVLM path), so high-level generations may drift from training
distribution slightly. That's acceptable for the dry-run REPL — the
high-level branches (subtask / plan / memory / vqa) are mostly
vision+language conditioned anyway, and the action expert (where
state actually matters) goes through the unchanged ``select_action``
path.

Image features the runtime merged in (``observation.images.*``) are
stacked into the ``[B, num_images, C, H, W]`` shape SmolVLM expects.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 12:39:34 +02:00
parent 3ff6c6860e
commit fa8ae1e89b
@@ -303,79 +303,65 @@ class SmolVLA2Policy(SmolVLAPolicy):
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
# AR text generation through the underlying SmolVLM rather
# than ``vlm_with_expert.forward``. The latter is built around
# the action-expert decode pattern (one-shot suffix forward
# against a cached prefix, with cross-attn layers that
# *require* an expert input on every call) — it isn't a
# general-purpose AR text decoder. ``vlm.generate`` runs the
# SmolVLM exactly the way HuggingFace inference does it
# everywhere else, so KV caching, beam/greedy/sampling logic,
# and EOS handling all just work.
#
# Trade-off: ``state`` is dropped from the prefix at inference
# time (no slot for it on the standard SmolVLM path), so
# generations may drift from training distribution slightly.
# That's acceptable for the dry-run REPL. The high-level
# branches (subtask / plan / memory / vqa) are mostly
# vision+language conditioned anyway; the action expert is
# where state really matters.
vlm = self.model.vlm_with_expert.vlm
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
# Embed the (images + lang + state) prefix once. Image
# encoding is the expensive part of ``embed_prefix``, so doing
# it here and concatenating new-token embeddings into the same
# ``current_embs`` buffer lets us avoid re-running SigLIP on
# every decode step.
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
# Collect any image features the runtime merged in. SmolVLM
# expects ``pixel_values`` shape ``[B, num_images, C, H, W]``;
# individual ``observation.images.*`` features are typically
# ``[B, C, H, W]`` after the preprocessor, so stack them on a
# new ``num_images`` axis.
image_tensors: list[Tensor] = []
for k, v in batch.items():
if (
isinstance(k, str)
and k.startswith("observation.images.")
and isinstance(v, Tensor)
):
image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0))
pixel_values = (
torch.stack(image_tensors, dim=1) if image_tensors else None
)
device = prefix_embs.device
bsize = prefix_embs.shape[0]
vlm = self.model.vlm_with_expert.vlm
emb_dim = prefix_embs.shape[-1]
text_emb_scale = math.sqrt(emb_dim)
gen_kwargs: dict[str, Any] = {
"input_ids": lang_tokens,
"attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks,
"max_new_tokens": max_new_tokens,
"do_sample": temperature > 0,
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
}
if temperature > 0:
gen_kwargs["temperature"] = temperature
gen_kwargs["top_p"] = top_p
if eos_token_id is not None:
gen_kwargs["eos_token_id"] = eos_token_id
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
# Cumulative buffers — the prefix at first, then grown by one
# token per decode step. The attention layer's only supported
# multi-step pattern is "pass the full embedded sequence each
# call with no KV cache" (the underlying
# ``vlm_with_expert.forward`` overwrites the cache instead of
# appending, so true incremental decoding isn't supported).
# This is O(n²) in the text length but matches the pattern
# ``denoise_step`` already uses successfully.
current_embs = prefix_embs
current_pad = prefix_pad_masks
current_att = prefix_att_masks
# Pre-build a one-step mask append (a generated token always
# has ``pad=1`` and ``att=1`` — fully causal among generated
# tokens, attends back to the entire prefix).
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
generated: list[int] = []
for step in range(max_new_tokens):
full_2d = make_att_2d_masks(current_pad, current_att)
full_pos = torch.cumsum(current_pad, dim=1) - 1
out_pair, _ = self.model.vlm_with_expert.forward(
attention_mask=full_2d,
position_ids=full_pos,
past_key_values=None,
inputs_embeds=[current_embs, None],
use_cache=False,
fill_kv_cache=False,
)
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
if prefix_out is None:
raise RuntimeError(
"select_message: vlm_with_expert.forward returned no hidden states."
)
last_hidden = prefix_out[:, -1:]
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
if eos_token_id is not None and tok_id == eos_token_id:
break
new_emb = self.model.vlm_with_expert.embed_language_tokens(
next_ids.unsqueeze(0)
)
new_emb = new_emb * text_emb_scale
current_embs = torch.cat([current_embs, new_emb], dim=1)
current_pad = torch.cat([current_pad, ones_step], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)
return tokenizer.decode(generated, skip_special_tokens=True).strip()
gen_ids = vlm.generate(**gen_kwargs)
# ``vlm.generate`` returns the prompt + new tokens; slice off
# the prompt so the caller only sees the model's continuation.
prompt_len = lang_tokens.shape[1]
new_token_ids = gen_ids[0, prompt_len:].tolist()
return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
@staticmethod
def _sample_next_token(