mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
fix(smolvla2): drive select_message through SmolVLM.generate
The hand-rolled AR loop in ``select_message`` was fighting the underlying ``vlm_with_expert.forward`` design, which assumes the "prefix-once + suffix-always-via-expert" pattern that ``denoise_step`` uses for action chunks. Cross-attn layers (every other layer with ``attention_mode='cross_attn'`` + ``self_attn_every_n_layers=2``) hard-require an expert input on every call: passing ``inputs_embeds=[current_embs, None]`` crashed at ``expert_layer.input_layernorm(None)`` with ``'NoneType' object has no attribute 'dtype'``. Earlier KV-cache attempts ran into the matching ``[15, 139] vs [15, 1]`` shape mismatch because the cache gets *overwritten*, not appended, on each ``fill_kv_cache=True`` call — there's just no AR-text-decode mode in this forward. Stop fighting it: drive AR text generation through the underlying SmolVLM via ``vlm.generate(input_ids=..., attention_mask=..., pixel_values=...)``. KV caching, sampling/greedy, EOS handling all come from HF's standard implementation. Trade-off: ``state`` drops out of the prefix at inference (no slot for it on the standard SmolVLM path), so high-level generations may drift from training distribution slightly. That's acceptable for the dry-run REPL — the high-level branches (subtask / plan / memory / vqa) are mostly vision+language conditioned anyway, and the action expert (where state actually matters) goes through the unchanged ``select_action`` path. Image features the runtime merged in (``observation.images.*``) are stacked into the ``[B, num_images, C, H, W]`` shape SmolVLM expects. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -303,79 +303,65 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
# AR text generation through the underlying SmolVLM rather
|
||||
# than ``vlm_with_expert.forward``. The latter is built around
|
||||
# the action-expert decode pattern (one-shot suffix forward
|
||||
# against a cached prefix, with cross-attn layers that
|
||||
# *require* an expert input on every call) — it isn't a
|
||||
# general-purpose AR text decoder. ``vlm.generate`` runs the
|
||||
# SmolVLM exactly the way HuggingFace inference does it
|
||||
# everywhere else, so KV caching, beam/greedy/sampling logic,
|
||||
# and EOS handling all just work.
|
||||
#
|
||||
# Trade-off: ``state`` is dropped from the prefix at inference
|
||||
# time (no slot for it on the standard SmolVLM path), so
|
||||
# generations may drift from training distribution slightly.
|
||||
# That's acceptable for the dry-run REPL. The high-level
|
||||
# branches (subtask / plan / memory / vqa) are mostly
|
||||
# vision+language conditioned anyway; the action expert is
|
||||
# where state really matters.
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
# Embed the (images + lang + state) prefix once. Image
|
||||
# encoding is the expensive part of ``embed_prefix``, so doing
|
||||
# it here and concatenating new-token embeddings into the same
|
||||
# ``current_embs`` buffer lets us avoid re-running SigLIP on
|
||||
# every decode step.
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
# Collect any image features the runtime merged in. SmolVLM
|
||||
# expects ``pixel_values`` shape ``[B, num_images, C, H, W]``;
|
||||
# individual ``observation.images.*`` features are typically
|
||||
# ``[B, C, H, W]`` after the preprocessor, so stack them on a
|
||||
# new ``num_images`` axis.
|
||||
image_tensors: list[Tensor] = []
|
||||
for k, v in batch.items():
|
||||
if (
|
||||
isinstance(k, str)
|
||||
and k.startswith("observation.images.")
|
||||
and isinstance(v, Tensor)
|
||||
):
|
||||
image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0))
|
||||
pixel_values = (
|
||||
torch.stack(image_tensors, dim=1) if image_tensors else None
|
||||
)
|
||||
|
||||
device = prefix_embs.device
|
||||
bsize = prefix_embs.shape[0]
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
text_emb_scale = math.sqrt(emb_dim)
|
||||
gen_kwargs: dict[str, Any] = {
|
||||
"input_ids": lang_tokens,
|
||||
"attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"do_sample": temperature > 0,
|
||||
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
||||
}
|
||||
if temperature > 0:
|
||||
gen_kwargs["temperature"] = temperature
|
||||
gen_kwargs["top_p"] = top_p
|
||||
if eos_token_id is not None:
|
||||
gen_kwargs["eos_token_id"] = eos_token_id
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
|
||||
# Cumulative buffers — the prefix at first, then grown by one
|
||||
# token per decode step. The attention layer's only supported
|
||||
# multi-step pattern is "pass the full embedded sequence each
|
||||
# call with no KV cache" (the underlying
|
||||
# ``vlm_with_expert.forward`` overwrites the cache instead of
|
||||
# appending, so true incremental decoding isn't supported).
|
||||
# This is O(n²) in the text length but matches the pattern
|
||||
# ``denoise_step`` already uses successfully.
|
||||
current_embs = prefix_embs
|
||||
current_pad = prefix_pad_masks
|
||||
current_att = prefix_att_masks
|
||||
|
||||
# Pre-build a one-step mask append (a generated token always
|
||||
# has ``pad=1`` and ``att=1`` — fully causal among generated
|
||||
# tokens, attends back to the entire prefix).
|
||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
|
||||
generated: list[int] = []
|
||||
for step in range(max_new_tokens):
|
||||
full_2d = make_att_2d_masks(current_pad, current_att)
|
||||
full_pos = torch.cumsum(current_pad, dim=1) - 1
|
||||
|
||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||
attention_mask=full_2d,
|
||||
position_ids=full_pos,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
raise RuntimeError(
|
||||
"select_message: vlm_with_expert.forward returned no hidden states."
|
||||
)
|
||||
|
||||
last_hidden = prefix_out[:, -1:]
|
||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
if eos_token_id is not None and tok_id == eos_token_id:
|
||||
break
|
||||
|
||||
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
||||
next_ids.unsqueeze(0)
|
||||
)
|
||||
new_emb = new_emb * text_emb_scale
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
gen_ids = vlm.generate(**gen_kwargs)
|
||||
# ``vlm.generate`` returns the prompt + new tokens; slice off
|
||||
# the prompt so the caller only sees the model's continuation.
|
||||
prompt_len = lang_tokens.shape[1]
|
||||
new_token_ids = gen_ids[0, prompt_len:].tolist()
|
||||
return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(
|
||||
|
||||
Reference in New Issue
Block a user