From fd89efb545b1f7dedfd6ccfca8432ecea74a5812 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 5 May 2026 12:08:52 +0200 Subject: [PATCH] fix(smolvla2): 3D attention mask in select_message decode loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SmolVLA's ``eager_attention_forward`` does ``masked = torch.where(attention_mask[:, None, :, :], ...)``, which requires a 3D ``[B, query_len, key_len]`` bool tensor so the broadcast to 4D works. ``select_message``'s prefix forward got this right (passes ``prefix_2d`` from ``make_att_2d_masks``), but the KV-cache decoding loop built ``new_attn = torch.ones((bsize, cur_pos + 1))`` — 2D — and the very first decode step blew up with ``IndexError: too many indices for tensor of dimension 2``. During KV-cache decoding ``query_len = 1`` and ``key_len = cur_pos + 1`` (prefix + every token already generated), so the right shape is ``[B, 1, cur_pos + 1]``. Match the layout SmolVLA's working ``denoise_step`` uses for the equivalent ``prefix_pad_2d_masks`` build. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/smolvla2/modeling_smolvla2.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 9a54f35b0..f9701632c 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -351,7 +351,15 @@ class SmolVLA2Policy(SmolVLAPolicy): new_emb = new_emb * math.sqrt(new_emb.shape[-1]) new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long) - new_attn = torch.ones((bsize, cur_pos + 1), device=device, dtype=torch.bool) + # SmolVLA's attention layer expects ``attention_mask`` shape + # ``[B, query_len, key_len]`` (3D bool) so it can broadcast to + # ``[B, 1, query_len, key_len]`` via ``mask[:, None, :, :]``. + # During KV-cache decoding query_len = 1 and key_len = + # ``cur_pos + 1`` (prefix + every token already generated). + # A 2D ``[B, key_len]`` tensor here trips + # ``IndexError: too many indices for tensor of dimension 2`` + # in ``eager_attention_forward``. + new_attn = torch.ones((bsize, 1, cur_pos + 1), device=device, dtype=torch.bool) out_pair, past_kv = self.model.vlm_with_expert.forward( attention_mask=new_attn,