mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fix(smolvla2): drive select_message through SmolVLM.generate
The hand-rolled AR loop in ``select_message`` was fighting the underlying ``vlm_with_expert.forward`` design, which assumes the "prefix-once + suffix-always-via-expert" pattern that ``denoise_step`` uses for action chunks. Cross-attn layers (every other layer with ``attention_mode='cross_attn'`` + ``self_attn_every_n_layers=2``) hard-require an expert input on every call: passing ``inputs_embeds=[current_embs, None]`` crashed at ``expert_layer.input_layernorm(None)`` with ``'NoneType' object has no attribute 'dtype'``. Earlier KV-cache attempts ran into the matching ``[15, 139] vs [15, 1]`` shape mismatch because the cache gets *overwritten*, not appended, on each ``fill_kv_cache=True`` call — there's just no AR-text-decode mode in this forward.

Stop fighting it: drive AR text generation through the underlying SmolVLM via ``vlm.generate(input_ids=..., attention_mask=..., pixel_values=...)``. KV caching, sampling/greedy decoding, and EOS handling all come from HF's standard implementation.

Trade-off: ``state`` drops out of the prefix at inference (no slot for it on the standard SmolVLM path), so high-level generations may drift slightly from the training distribution. That's acceptable for the dry-run REPL — the high-level branches (subtask / plan / memory / vqa) are mostly vision+language conditioned anyway, and the action expert (where state actually matters) goes through the unchanged ``select_action`` path.

Image features the runtime merged in (``observation.images.*``) are stacked into the ``[B, num_images, C, H, W]`` shape SmolVLM expects.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -303,79 +303,65 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
if eos_token_id is None:
|
if eos_token_id is None:
|
||||||
eos_token_id = tokenizer.eos_token_id
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
images, img_masks = self.prepare_images(batch)
|
# AR text generation through the underlying SmolVLM rather
|
||||||
state = self.prepare_state(batch)
|
# than ``vlm_with_expert.forward``. The latter is built around
|
||||||
|
# the action-expert decode pattern (one-shot suffix forward
|
||||||
|
# against a cached prefix, with cross-attn layers that
|
||||||
|
# *require* an expert input on every call) — it isn't a
|
||||||
|
# general-purpose AR text decoder. ``vlm.generate`` runs the
|
||||||
|
# SmolVLM exactly the way HuggingFace inference does it
|
||||||
|
# everywhere else, so KV caching, beam/greedy/sampling logic,
|
||||||
|
# and EOS handling all just work.
|
||||||
|
#
|
||||||
|
# Trade-off: ``state`` is dropped from the prefix at inference
|
||||||
|
# time (no slot for it on the standard SmolVLM path), so
|
||||||
|
# generations may drift from training distribution slightly.
|
||||||
|
# That's acceptable for the dry-run REPL. The high-level
|
||||||
|
# branches (subtask / plan / memory / vqa) are mostly
|
||||||
|
# vision+language conditioned anyway; the action expert is
|
||||||
|
# where state really matters.
|
||||||
|
vlm = self.model.vlm_with_expert.vlm
|
||||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
# Embed the (images + lang + state) prefix once. Image
|
# Collect any image features the runtime merged in. SmolVLM
|
||||||
# encoding is the expensive part of ``embed_prefix``, so doing
|
# expects ``pixel_values`` shape ``[B, num_images, C, H, W]``;
|
||||||
# it here and concatenating new-token embeddings into the same
|
# individual ``observation.images.*`` features are typically
|
||||||
# ``current_embs`` buffer lets us avoid re-running SigLIP on
|
# ``[B, C, H, W]`` after the preprocessor, so stack them on a
|
||||||
# every decode step.
|
# new ``num_images`` axis.
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
image_tensors: list[Tensor] = []
|
||||||
images, img_masks, lang_tokens, lang_masks, state=state
|
for k, v in batch.items():
|
||||||
|
if (
|
||||||
|
isinstance(k, str)
|
||||||
|
and k.startswith("observation.images.")
|
||||||
|
and isinstance(v, Tensor)
|
||||||
|
):
|
||||||
|
image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0))
|
||||||
|
pixel_values = (
|
||||||
|
torch.stack(image_tensors, dim=1) if image_tensors else None
|
||||||
)
|
)
|
||||||
|
|
||||||
device = prefix_embs.device
|
gen_kwargs: dict[str, Any] = {
|
||||||
bsize = prefix_embs.shape[0]
|
"input_ids": lang_tokens,
|
||||||
vlm = self.model.vlm_with_expert.vlm
|
"attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks,
|
||||||
emb_dim = prefix_embs.shape[-1]
|
"max_new_tokens": max_new_tokens,
|
||||||
text_emb_scale = math.sqrt(emb_dim)
|
"do_sample": temperature > 0,
|
||||||
|
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
||||||
|
}
|
||||||
|
if temperature > 0:
|
||||||
|
gen_kwargs["temperature"] = temperature
|
||||||
|
gen_kwargs["top_p"] = top_p
|
||||||
|
if eos_token_id is not None:
|
||||||
|
gen_kwargs["eos_token_id"] = eos_token_id
|
||||||
|
if pixel_values is not None:
|
||||||
|
gen_kwargs["pixel_values"] = pixel_values
|
||||||
|
|
||||||
# Cumulative buffers — the prefix at first, then grown by one
|
gen_ids = vlm.generate(**gen_kwargs)
|
||||||
# token per decode step. The attention layer's only supported
|
# ``vlm.generate`` returns the prompt + new tokens; slice off
|
||||||
# multi-step pattern is "pass the full embedded sequence each
|
# the prompt so the caller only sees the model's continuation.
|
||||||
# call with no KV cache" (the underlying
|
prompt_len = lang_tokens.shape[1]
|
||||||
# ``vlm_with_expert.forward`` overwrites the cache instead of
|
new_token_ids = gen_ids[0, prompt_len:].tolist()
|
||||||
# appending, so true incremental decoding isn't supported).
|
return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
|
||||||
# This is O(n²) in the text length but matches the pattern
|
|
||||||
# ``denoise_step`` already uses successfully.
|
|
||||||
current_embs = prefix_embs
|
|
||||||
current_pad = prefix_pad_masks
|
|
||||||
current_att = prefix_att_masks
|
|
||||||
|
|
||||||
# Pre-build a one-step mask append (a generated token always
|
|
||||||
# has ``pad=1`` and ``att=1`` — fully causal among generated
|
|
||||||
# tokens, attends back to the entire prefix).
|
|
||||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
|
||||||
|
|
||||||
generated: list[int] = []
|
|
||||||
for step in range(max_new_tokens):
|
|
||||||
full_2d = make_att_2d_masks(current_pad, current_att)
|
|
||||||
full_pos = torch.cumsum(current_pad, dim=1) - 1
|
|
||||||
|
|
||||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
|
||||||
attention_mask=full_2d,
|
|
||||||
position_ids=full_pos,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=[current_embs, None],
|
|
||||||
use_cache=False,
|
|
||||||
fill_kv_cache=False,
|
|
||||||
)
|
|
||||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
|
||||||
if prefix_out is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"select_message: vlm_with_expert.forward returned no hidden states."
|
|
||||||
)
|
|
||||||
|
|
||||||
last_hidden = prefix_out[:, -1:]
|
|
||||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
|
||||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
|
||||||
tok_id = int(next_ids[0].item())
|
|
||||||
generated.append(tok_id)
|
|
||||||
if eos_token_id is not None and tok_id == eos_token_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
|
||||||
next_ids.unsqueeze(0)
|
|
||||||
)
|
|
||||||
new_emb = new_emb * text_emb_scale
|
|
||||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
|
||||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
|
||||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
|
||||||
|
|
||||||
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sample_next_token(
|
def _sample_next_token(
|
||||||
|
|||||||
Reference in New Issue
Block a user