Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-14 16:19:45 +00:00)
fix(smolvla2): rewrite select_message decode loop without KV cache
SmolVLA's ``vlm_with_expert.forward`` doesn't actually support incremental KV cache growth: its only ``fill_kv_cache=True`` mode *overwrites* the cache with the latest call's key/value states, and its only ``fill_kv_cache=False`` mode concatenates ``cache + new`` into a local ``key_states`` for one matmul without ever updating the cache itself. The original ``select_message`` decode loop tried to use ``fill_kv_cache=True`` per step, which clobbered the cache down to 1 token after the first decode and threw ``Expected size for first two dimensions of batch2 tensor to be: [15, 139] but got: [15, 1]``: the attention mask still expected 139 keys, but the cached + new key_states only had 1.

Match the pattern ``denoise_step`` already uses successfully: maintain a cumulative ``(embs, pad, att)`` buffer that starts as the prefix and grows by one bool/embedding row per step. Each step forwards the *full* sequence with ``use_cache=False, fill_kv_cache=False, past_key_values=None``, so the matmul shapes always line up. Generated-token rows are tagged ``pad=1, att=1``, which makes them fully causal among themselves while still able to attend back to the entire prefix (per ``make_att_2d_masks`` semantics: a token can attend to any earlier token whose cumulative ``att`` count is ≤ its own).

Image encoding is still done once via the initial ``embed_prefix`` call, so the expensive part doesn't repeat. The remaining cost is O(n²) text-only transformer forwards, which is fine for the dry-run REPL's 50–100 token responses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
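For readers unfamiliar with the mask convention, here is a minimal standalone sketch of the ``make_att_2d_masks`` rule the message relies on. The function name, toy sizes, and layout are illustrative only, not the repository's actual implementation; it just re-implements the stated semantics that a query may attend to any key whose cumulative ``att`` count is less than or equal to its own, with padded positions excluded:

    import torch

    def att_2d_from_1d(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
        # pad_masks, att_masks: (B, L) bool/int tensors. Query i may attend to
        # key j iff cumsum(att)[j] <= cumsum(att)[i] and neither slot is padding.
        cumsum = torch.cumsum(att_masks.long(), dim=1)
        att_2d = cumsum[:, None, :] <= cumsum[:, :, None]       # (B, L, L)
        pad_2d = pad_masks[:, None, :] & pad_masks[:, :, None]  # drop padded keys/queries
        return att_2d & pad_2d

    # Toy sequence: a 4-token prefix (att=0, so mutually visible) followed by
    # three generated tokens appended with pad=1, att=1, as in the new loop.
    pad = torch.ones((1, 7), dtype=torch.bool)
    att = torch.tensor([[0, 0, 0, 0, 1, 1, 1]])
    print(att_2d_from_1d(pad, att)[0].int())
    # Rows 4-6 (the generated tokens) attend to the whole prefix and to every
    # earlier generated token, but not to later ones: causal among themselves.

Under this rule, the ``pad=1, att=1`` tag on each appended row is exactly what produces the behaviour described above.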
@@ -308,35 +308,58 @@ class SmolVLA2Policy(SmolVLAPolicy):
         lang_tokens = batch[OBS_LANGUAGE_TOKENS]
         lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
 
-        # 1) Embed prefix (images + lang + state) and run with KV cache.
+        # Embed the (images + lang + state) prefix once. Image
+        # encoding is the expensive part of ``embed_prefix``, so doing
+        # it here and concatenating new-token embeddings into the same
+        # ``current_embs`` buffer lets us avoid re-running SigLIP on
+        # every decode step.
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
             images, img_masks, lang_tokens, lang_masks, state=state
         )
-        prefix_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
-        prefix_pos = torch.cumsum(prefix_pad_masks, dim=1) - 1
-        out_pair, past_kv = self.model.vlm_with_expert.forward(
-            attention_mask=prefix_2d,
-            position_ids=prefix_pos,
-            past_key_values=None,
-            inputs_embeds=[prefix_embs, None],
-            use_cache=True,
-            fill_kv_cache=True,
-        )
-        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
-        if prefix_out is None:
-            raise RuntimeError("select_message: prefix forward returned no hidden states.")
-
-        vlm = self.model.vlm_with_expert.vlm
-
-        # 2) Initial logits — sample first new token from the last
-        # prefix position.
-        last_hidden = prefix_out[:, -1:]
-        device = last_hidden.device
+        device = prefix_embs.device
         bsize = prefix_embs.shape[0]
         cur_pos = int(prefix_embs.shape[1])
+        vlm = self.model.vlm_with_expert.vlm
+        emb_dim = prefix_embs.shape[-1]
+        text_emb_scale = math.sqrt(emb_dim)
+
+        # Cumulative buffers — the prefix at first, then grown by one
+        # token per decode step. The attention layer's only supported
+        # multi-step pattern is "pass the full embedded sequence each
+        # call with no KV cache" (the underlying
+        # ``vlm_with_expert.forward`` overwrites the cache instead of
+        # appending, so true incremental decoding isn't supported).
+        # This is O(n²) in the text length but matches the pattern
+        # ``denoise_step`` already uses successfully.
+        current_embs = prefix_embs
+        current_pad = prefix_pad_masks
+        current_att = prefix_att_masks
+
+        # Pre-build a one-step mask append (a generated token always
+        # has ``pad=1`` and ``att=1`` — fully causal among generated
+        # tokens, attends back to the entire prefix).
+        ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
 
         generated: list[int] = []
-        for _ in range(max_new_tokens):
+        for step in range(max_new_tokens):
+            full_2d = make_att_2d_masks(current_pad, current_att)
+            full_pos = torch.cumsum(current_pad, dim=1) - 1
+
+            out_pair, _ = self.model.vlm_with_expert.forward(
+                attention_mask=full_2d,
+                position_ids=full_pos,
+                past_key_values=None,
+                inputs_embeds=[current_embs, None],
+                use_cache=False,
+                fill_kv_cache=False,
+            )
+            prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
+            if prefix_out is None:
+                raise RuntimeError(
+                    "select_message: vlm_with_expert.forward returned no hidden states."
+                )
+
+            last_hidden = prefix_out[:, -1:]
             logits_step = vlm.lm_head(last_hidden)[:, -1]  # (B, V)
             next_ids = self._sample_next_token(logits_step, temperature, top_p)
             tok_id = int(next_ids[0].item())
@@ -344,34 +367,13 @@ class SmolVLA2Policy(SmolVLAPolicy):
             generated.append(tok_id)
             if eos_token_id is not None and tok_id == eos_token_id:
                 break
 
-            # 3) Embed the new token and forward with KV cache.
             new_emb = self.model.vlm_with_expert.embed_language_tokens(
                 next_ids.unsqueeze(0)
             )
-            new_emb = new_emb * math.sqrt(new_emb.shape[-1])
-
-            new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long)
-            # SmolVLA's attention layer expects ``attention_mask`` shape
-            # ``[B, query_len, key_len]`` (3D bool) so it can broadcast to
-            # ``[B, 1, query_len, key_len]`` via ``mask[:, None, :, :]``.
-            # During KV-cache decoding query_len = 1 and key_len =
-            # ``cur_pos + 1`` (prefix + every token already generated).
-            # A 2D ``[B, key_len]`` tensor here trips
-            # ``IndexError: too many indices for tensor of dimension 2``
-            # in ``eager_attention_forward``.
-            new_attn = torch.ones((bsize, 1, cur_pos + 1), device=device, dtype=torch.bool)
-
-            out_pair, past_kv = self.model.vlm_with_expert.forward(
-                attention_mask=new_attn,
-                position_ids=new_pos,
-                past_key_values=past_kv,
-                inputs_embeds=[new_emb, None],
-                use_cache=True,
-                fill_kv_cache=True,
-            )
-            new_prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
-            last_hidden = new_prefix_out[:, -1:]
-            cur_pos += 1
+            new_emb = new_emb * text_emb_scale
+            current_embs = torch.cat([current_embs, new_emb], dim=1)
+            current_pad = torch.cat([current_pad, ones_step], dim=1)
+            current_att = torch.cat([current_att, ones_step], dim=1)
 
         return tokenizer.decode(generated, skip_special_tokens=True).strip()
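As an aside, the ``bmm`` size error quoted in the commit message is easy to reproduce in isolation: the attention weights are shaped for the full 139-key prefix, but once ``fill_kv_cache=True`` has overwritten the cache, the value states hold only the single newest token. A toy repro follows; the sizes 15 and 139 are copied from the quoted error, and 64 is an arbitrary stand-in head dimension:

    import torch

    attn_weights = torch.randn(15, 1, 139)  # attention scores for 1 query over 139 keys
    value_states = torch.randn(15, 1, 64)   # only 1 value row survives in the clobbered cache
    try:
        torch.bmm(attn_weights, value_states)
    except RuntimeError as err:
        # Prints the mismatch quoted above:
        # Expected size for first two dimensions of batch2 tensor to be: [15, 139] but got: [15, 1]
        print(err)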