fix(smolvla2): 3D attention mask in select_message decode loop

SmolVLA's ``eager_attention_forward`` does
``masked = torch.where(attention_mask[:, None, :, :], ...)``, which
requires a 3D ``[B, query_len, key_len]`` bool tensor so the
broadcast to 4D works. ``select_message``'s prefix forward got this
right (passes ``prefix_2d`` from ``make_att_2d_masks``), but the
KV-cache decoding loop built ``new_attn = torch.ones((bsize,
cur_pos + 1))`` — 2D — and the very first decode step blew up with
``IndexError: too many indices for tensor of dimension 2``.

During KV-cache decoding ``query_len = 1`` and
``key_len = cur_pos + 1`` (prefix + every token already generated),
so the right shape is ``[B, 1, cur_pos + 1]``. Match the layout
SmolVLA's working ``denoise_step`` uses for the equivalent
``prefix_pad_2d_masks`` build.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 12:08:52 +02:00
parent 2776b57c9e
commit fd89efb545
@@ -351,7 +351,15 @@ class SmolVLA2Policy(SmolVLAPolicy):
new_emb = new_emb * math.sqrt(new_emb.shape[-1])
new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long)
new_attn = torch.ones((bsize, cur_pos + 1), device=device, dtype=torch.bool)
# SmolVLA's attention layer expects ``attention_mask`` shape
# ``[B, query_len, key_len]`` (3D bool) so it can broadcast to
# ``[B, 1, query_len, key_len]`` via ``mask[:, None, :, :]``.
# During KV-cache decoding query_len = 1 and key_len =
# ``cur_pos + 1`` (prefix + every token already generated).
# A 2D ``[B, key_len]`` tensor here trips
# ``IndexError: too many indices for tensor of dimension 2``
# in ``eager_attention_forward``.
new_attn = torch.ones((bsize, 1, cur_pos + 1), device=device, dtype=torch.bool)
out_pair, past_kv = self.model.vlm_with_expert.forward(
attention_mask=new_attn,