mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix(smolvla2): match training's text-loss forward in select_message
Previous rewrite drove generation through ``vlm.generate()`` (the
standard SmolVLM path), which ignores SmolVLA's custom ``embed_prefix``
that interleaves images + lang + state. Result: the model received a
prompt format it had never been trained on at inference and emitted
JSON-fragment gibberish (``" " " ,",","`` ``cube lift {"...``).
Revert to the cumulative-buffer AR loop driven through
``vlm_with_expert.forward`` — the *same* forward call ``_compute_text_loss``
makes during training (``inputs_embeds=[prefix_embs, None],
use_cache=False, fill_kv_cache=True``). With ``fill_kv_cache=True``,
every layer routes through ``forward_attn_layer``, which gracefully
skips ``None`` expert inputs (``if hidden_states is None or layer is
None: continue``); cross-attention layers — which would otherwise hard-
require a non-None expert input — are bypassed entirely.
Inference now sees the same prefix structure as training: images +
lang + state, with new tokens appended to the lang region. The text
distribution matches what the model was trained to produce.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -303,63 +303,83 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
if eos_token_id is None:
|
if eos_token_id is None:
|
||||||
eos_token_id = tokenizer.eos_token_id
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
# AR text generation through the underlying SmolVLM rather
|
# Match training's text-loss forward path (see
|
||||||
# than ``vlm_with_expert.forward``. The latter is built around
|
# ``_compute_text_loss`` above): build the full prefix via
|
||||||
# the action-expert decode pattern (one-shot suffix forward
|
# ``embed_prefix`` so images + state conditioning is intact,
|
||||||
# against a cached prefix, with cross-attn layers that
|
# then loop AR with ``fill_kv_cache=True, use_cache=False``.
|
||||||
# *require* an expert input on every call) — it isn't a
|
# That flag combo routes every layer through
|
||||||
# general-purpose AR text decoder. ``vlm.generate`` runs the
|
# ``forward_attn_layer`` (which gracefully skips ``None``
|
||||||
# SmolVLM exactly the way HuggingFace inference does it
|
# expert inputs via ``if hidden_states is None or layer is
|
||||||
# everywhere else, so KV caching, beam/greedy/sampling logic,
|
# None: continue``) and short-circuits the cache-update logic
|
||||||
# and EOS handling all just work.
|
# so we don't have to manage past_kv. Each step just
|
||||||
|
# re-forwards the cumulative ``[prefix + generated]``
|
||||||
|
# sequence.
|
||||||
#
|
#
|
||||||
# Trade-off: ``state`` is dropped from the prefix at inference
|
# This is O(n²) in generated text length but cheap in
|
||||||
# time (no slot for it on the standard SmolVLM path), so
|
# absolute terms: image encoding happens once via the initial
|
||||||
# generations may drift from training distribution slightly.
|
# ``embed_prefix`` call, and the per-step cost is just one
|
||||||
# That's acceptable for the dry-run REPL. The high-level
|
# SmolVLM transformer pass over a sequence that grows by one
|
||||||
# branches (subtask / plan / memory / vqa) are mostly
|
# token each time. Standard SmolVLM ``generate`` was the
|
||||||
# vision+language conditioned anyway; the action expert is
|
# other tempting path, but it can't accept SmolVLA's custom
|
||||||
# where state really matters.
|
# ``state_proj`` output and its tile-grid expectations
|
||||||
vlm = self.model.vlm_with_expert.vlm
|
# disagree with our preprocessor — both lead to garbage
|
||||||
|
# generation, which is what the prior approach produced.
|
||||||
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
state = self.prepare_state(batch)
|
||||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
# NOTE: we deliberately do *not* forward ``pixel_values`` to
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||||
# ``vlm.generate``. The dataset's images go through SmolVLA's
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
# custom preprocessor (resize / normalise to whatever shape
|
)
|
||||||
# the action expert was trained on), but SmolVLM's standard
|
|
||||||
# vision tower expects images sized to its own default tile
|
|
||||||
# grid (e.g. 384/14 → 27×27 patches). The mismatch surfaces
|
|
||||||
# as ``RuntimeError: shape '[2, 34, 34, 768]' is invalid for
|
|
||||||
# input of size <other>`` deep in the post-vision reshape.
|
|
||||||
#
|
|
||||||
# For the dry-run REPL the high-level branches (subtask /
|
|
||||||
# plan / memory) are dominated by their text context anyway,
|
|
||||||
# so running text-only generation through SmolVLM is good
|
|
||||||
# enough. Restoring full vision conditioning at inference
|
|
||||||
# would mean either re-processing the images through the
|
|
||||||
# backbone's own ``ImageProcessor`` (and matching SmolVLA2
|
|
||||||
# training shape) or giving ``vlm_with_expert`` a real AR
|
|
||||||
# text decode mode — both are bigger follow-ups.
|
|
||||||
gen_kwargs: dict[str, Any] = {
|
|
||||||
"input_ids": lang_tokens,
|
|
||||||
"attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
"do_sample": temperature > 0,
|
|
||||||
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
|
||||||
}
|
|
||||||
if temperature > 0:
|
|
||||||
gen_kwargs["temperature"] = temperature
|
|
||||||
gen_kwargs["top_p"] = top_p
|
|
||||||
if eos_token_id is not None:
|
|
||||||
gen_kwargs["eos_token_id"] = eos_token_id
|
|
||||||
|
|
||||||
gen_ids = vlm.generate(**gen_kwargs)
|
device = prefix_embs.device
|
||||||
# ``vlm.generate`` returns the prompt + new tokens; slice off
|
bsize = prefix_embs.shape[0]
|
||||||
# the prompt so the caller only sees the model's continuation.
|
vlm = self.model.vlm_with_expert.vlm
|
||||||
prompt_len = lang_tokens.shape[1]
|
emb_dim = prefix_embs.shape[-1]
|
||||||
new_token_ids = gen_ids[0, prompt_len:].tolist()
|
text_emb_scale = math.sqrt(emb_dim)
|
||||||
return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
|
|
||||||
|
current_embs = prefix_embs
|
||||||
|
current_pad = prefix_pad_masks
|
||||||
|
current_att = prefix_att_masks
|
||||||
|
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
generated: list[int] = []
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
full_2d = make_att_2d_masks(current_pad, current_att)
|
||||||
|
full_pos = torch.cumsum(current_pad, dim=1) - 1
|
||||||
|
|
||||||
|
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||||
|
attention_mask=full_2d,
|
||||||
|
position_ids=full_pos,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[current_embs, None],
|
||||||
|
use_cache=False,
|
||||||
|
fill_kv_cache=True,
|
||||||
|
)
|
||||||
|
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||||
|
if prefix_out is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"select_message: vlm_with_expert.forward returned no hidden states."
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||||
|
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||||
|
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||||
|
tok_id = int(next_ids[0].item())
|
||||||
|
generated.append(tok_id)
|
||||||
|
if eos_token_id is not None and tok_id == eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
||||||
|
next_ids.unsqueeze(0)
|
||||||
|
)
|
||||||
|
new_emb = new_emb * text_emb_scale
|
||||||
|
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||||
|
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||||
|
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||||
|
|
||||||
|
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sample_next_token(
|
def _sample_next_token(
|
||||||
|
|||||||
Reference in New Issue
Block a user