diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
index f9701632c..89d4a1c7a 100644
--- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py
+++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py
@@ -308,35 +308,58 @@ class SmolVLA2Policy(SmolVLAPolicy):
         lang_tokens = batch[OBS_LANGUAGE_TOKENS]
         lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
 
-        # 1) Embed prefix (images + lang + state) and run with KV cache.
+        # Embed the (images + lang + state) prefix once. Image
+        # encoding is the expensive part of ``embed_prefix``, so doing
+        # it here and concatenating new-token embeddings into the same
+        # ``current_embs`` buffer lets us avoid re-running SigLIP on
+        # every decode step.
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
 
-        prefix_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
-        prefix_pos = torch.cumsum(prefix_pad_masks, dim=1) - 1
-        out_pair, past_kv = self.model.vlm_with_expert.forward(
-            attention_mask=prefix_2d,
-            position_ids=prefix_pos,
-            past_key_values=None,
-            inputs_embeds=[prefix_embs, None],
-            use_cache=True,
-            fill_kv_cache=True,
-        )
-        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
-        if prefix_out is None:
-            raise RuntimeError("select_message: prefix forward returned no hidden states.")
-        vlm = self.model.vlm_with_expert.vlm
-
-        # 2) Initial logits — sample first new token from the last
-        # prefix position.
-        last_hidden = prefix_out[:, -1:]
-        device = last_hidden.device
+        device = prefix_embs.device
         bsize = prefix_embs.shape[0]
-        cur_pos = int(prefix_embs.shape[1])
+        vlm = self.model.vlm_with_expert.vlm
+        emb_dim = prefix_embs.shape[-1]
+        text_emb_scale = math.sqrt(emb_dim)
+
+        # Cumulative buffers — the prefix at first, then grown by one
+        # token per decode step. The attention layer's only supported
+        # multi-step pattern is "pass the full embedded sequence each
+        # call with no KV cache" (the underlying
+        # ``vlm_with_expert.forward`` overwrites the cache instead of
+        # appending, so true incremental decoding isn't supported).
+        # This is O(n²) in the text length but matches the pattern
+        # ``denoise_step`` already uses successfully.
+        current_embs = prefix_embs
+        current_pad = prefix_pad_masks
+        current_att = prefix_att_masks
+
+        # Pre-build a one-step mask append (a generated token always
+        # has ``pad=1`` and ``att=1`` — fully causal among generated
+        # tokens, attends back to the entire prefix).
+        ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
 
         generated: list[int] = []
-        for _ in range(max_new_tokens):
+        for step in range(max_new_tokens):
+            full_2d = make_att_2d_masks(current_pad, current_att)
+            full_pos = torch.cumsum(current_pad, dim=1) - 1
+
+            out_pair, _ = self.model.vlm_with_expert.forward(
+                attention_mask=full_2d,
+                position_ids=full_pos,
+                past_key_values=None,
+                inputs_embeds=[current_embs, None],
+                use_cache=False,
+                fill_kv_cache=False,
+            )
+            prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+            if prefix_out is None:
+                raise RuntimeError(
+                    "select_message: vlm_with_expert.forward returned no hidden states."
+                )
+
+            last_hidden = prefix_out[:, -1:]
             logits_step = vlm.lm_head(last_hidden)[:, -1]  # (B, V)
             next_ids = self._sample_next_token(logits_step, temperature, top_p)
            tok_id = int(next_ids[0].item())
@@ -344,34 +367,13 @@ class SmolVLA2Policy(SmolVLAPolicy):
 
             if eos_token_id is not None and tok_id == eos_token_id:
                 break
 
-            # 3) Embed the new token and forward with KV cache.
             new_emb = self.model.vlm_with_expert.embed_language_tokens(
                 next_ids.unsqueeze(0)
             )
-            new_emb = new_emb * math.sqrt(new_emb.shape[-1])
-
-            new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long)
-            # SmolVLA's attention layer expects ``attention_mask`` shape
-            # ``[B, query_len, key_len]`` (3D bool) so it can broadcast to
-            # ``[B, 1, query_len, key_len]`` via ``mask[:, None, :, :]``.
-            # During KV-cache decoding query_len = 1 and key_len =
-            # ``cur_pos + 1`` (prefix + every token already generated).
-            # A 2D ``[B, key_len]`` tensor here trips
-            # ``IndexError: too many indices for tensor of dimension 2``
-            # in ``eager_attention_forward``.
-            new_attn = torch.ones((bsize, 1, cur_pos + 1), device=device, dtype=torch.bool)
-
-            out_pair, past_kv = self.model.vlm_with_expert.forward(
-                attention_mask=new_attn,
-                position_ids=new_pos,
-                past_key_values=past_kv,
-                inputs_embeds=[new_emb, None],
-                use_cache=True,
-                fill_kv_cache=True,
-            )
-            new_prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
-            last_hidden = new_prefix_out[:, -1:]
-            cur_pos += 1
+            new_emb = new_emb * text_emb_scale
+            current_embs = torch.cat([current_embs, new_emb], dim=1)
+            current_pad = torch.cat([current_pad, ones_step], dim=1)
+            current_att = torch.cat([current_att, ones_step], dim=1)
 
         return tokenizer.decode(generated, skip_special_tokens=True).strip()
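For reviewers: the snippet below shows the decoding pattern this change adopts, in isolation. The sequence of embeddings is grown by one token per step and the full forward pass is re-run each time, instead of maintaining a KV cache. This is a minimal sketch only; `ToyLM` and `greedy_decode_full_recompute` are hypothetical stand-ins, not lerobot APIs, and the toy omits the block-causal `att` masks and position ids that `make_att_2d_masks` handles in the real model.

```python
# Hypothetical, self-contained illustration of cumulative-buffer decoding
# without a KV cache. Not lerobot code.
import torch
import torch.nn as nn


class ToyLM(nn.Module):
    """Minimal stand-in for the VLM: embeddings in, logits out."""

    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.layer = nn.TransformerEncoderLayer(dim, nhead=2, batch_first=True)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, embs: torch.Tensor, pad_masks: torch.Tensor) -> torch.Tensor:
        # Re-attend over the whole (prefix + generated) sequence each call.
        hidden = self.layer(embs, src_key_padding_mask=~pad_masks)
        return self.lm_head(hidden)


def greedy_decode_full_recompute(model, prefix_embs, prefix_pad,
                                 eos_token_id=None, max_new_tokens=8):
    bsize = prefix_embs.shape[0]
    ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=prefix_embs.device)
    embs, pad = prefix_embs, prefix_pad
    generated: list[int] = []
    for _ in range(max_new_tokens):
        logits = model(embs, pad)                  # (B, T, V); full recompute, O(T^2) per step
        next_id = int(logits[:, -1].argmax(-1)[0].item())
        if eos_token_id is not None and next_id == eos_token_id:
            break
        generated.append(next_id)
        # Grow the buffers instead of appending to a KV cache.
        new_emb = model.embed(torch.tensor([[next_id]], device=embs.device))
        embs = torch.cat([embs, new_emb], dim=1)
        pad = torch.cat([pad, ones_step], dim=1)
    return generated


if __name__ == "__main__":
    torch.manual_seed(0)
    lm = ToyLM().eval()
    with torch.no_grad():
        prefix = lm.embed(torch.randint(0, 32, (1, 5)))
        print(greedy_decode_full_recompute(lm, prefix, torch.ones(1, 5, dtype=torch.bool)))
```

Each step re-attends over the whole prefix plus every token generated so far, so cost grows quadratically with the number of generated tokens. That is the trade-off accepted here until `vlm_with_expert.forward` supports appending to its KV cache rather than overwriting it.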