diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 1c65a26ec..79978cf95 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -303,63 +303,83 @@ class SmolVLA2Policy(SmolVLAPolicy): if eos_token_id is None: eos_token_id = tokenizer.eos_token_id - # AR text generation through the underlying SmolVLM rather - # than ``vlm_with_expert.forward``. The latter is built around - # the action-expert decode pattern (one-shot suffix forward - # against a cached prefix, with cross-attn layers that - # *require* an expert input on every call) — it isn't a - # general-purpose AR text decoder. ``vlm.generate`` runs the - # SmolVLM exactly the way HuggingFace inference does it - # everywhere else, so KV caching, beam/greedy/sampling logic, - # and EOS handling all just work. + # Match training's text-loss forward path (see + # ``_compute_text_loss`` above): build the full prefix via + # ``embed_prefix`` so images + state conditioning is intact, + # then loop AR with ``fill_kv_cache=True, use_cache=False``. + # That flag combo routes every layer through + # ``forward_attn_layer`` (which gracefully skips ``None`` + # expert inputs via ``if hidden_states is None or layer is + # None: continue``) and short-circuits the cache-update logic + # so we don't have to manage past_kv. Each step just + # re-forwards the cumulative ``[prefix + generated]`` + # sequence. # - # Trade-off: ``state`` is dropped from the prefix at inference - # time (no slot for it on the standard SmolVLM path), so - # generations may drift from training distribution slightly. - # That's acceptable for the dry-run REPL. The high-level - # branches (subtask / plan / memory / vqa) are mostly - # vision+language conditioned anyway; the action expert is - # where state really matters. - vlm = self.model.vlm_with_expert.vlm + # This is O(n²) in generated text length but cheap in + # absolute terms: image encoding happens once via the initial + # ``embed_prefix`` call, and the per-step cost is just one + # SmolVLM transformer pass over a sequence that grows by one + # token each time. Standard SmolVLM ``generate`` was the + # other tempting path, but it can't accept SmolVLA's custom + # ``state_proj`` output and its tile-grid expectations + # disagree with our preprocessor — both lead to garbage + # generation, which is what the prior approach produced. + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) lang_tokens = batch[OBS_LANGUAGE_TOKENS] lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - # NOTE: we deliberately do *not* forward ``pixel_values`` to - # ``vlm.generate``. The dataset's images go through SmolVLA's - # custom preprocessor (resize / normalise to whatever shape - # the action expert was trained on), but SmolVLM's standard - # vision tower expects images sized to its own default tile - # grid (e.g. 384/14 → 27×27 patches). The mismatch surfaces - # as ``RuntimeError: shape '[2, 34, 34, 768]' is invalid for - # input of size `` deep in the post-vision reshape. - # - # For the dry-run REPL the high-level branches (subtask / - # plan / memory) are dominated by their text context anyway, - # so running text-only generation through SmolVLM is good - # enough. Restoring full vision conditioning at inference - # would mean either re-processing the images through the - # backbone's own ``ImageProcessor`` (and matching SmolVLA2 - # training shape) or giving ``vlm_with_expert`` a real AR - # text decode mode — both are bigger follow-ups. - gen_kwargs: dict[str, Any] = { - "input_ids": lang_tokens, - "attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks, - "max_new_tokens": max_new_tokens, - "do_sample": temperature > 0, - "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, - } - if temperature > 0: - gen_kwargs["temperature"] = temperature - gen_kwargs["top_p"] = top_p - if eos_token_id is not None: - gen_kwargs["eos_token_id"] = eos_token_id + prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) - gen_ids = vlm.generate(**gen_kwargs) - # ``vlm.generate`` returns the prompt + new tokens; slice off - # the prompt so the caller only sees the model's continuation. - prompt_len = lang_tokens.shape[1] - new_token_ids = gen_ids[0, prompt_len:].tolist() - return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip() + device = prefix_embs.device + bsize = prefix_embs.shape[0] + vlm = self.model.vlm_with_expert.vlm + emb_dim = prefix_embs.shape[-1] + text_emb_scale = math.sqrt(emb_dim) + + current_embs = prefix_embs + current_pad = prefix_pad_masks + current_att = prefix_att_masks + ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) + + generated: list[int] = [] + for _ in range(max_new_tokens): + full_2d = make_att_2d_masks(current_pad, current_att) + full_pos = torch.cumsum(current_pad, dim=1) - 1 + + out_pair, _ = self.model.vlm_with_expert.forward( + attention_mask=full_2d, + position_ids=full_pos, + past_key_values=None, + inputs_embeds=[current_embs, None], + use_cache=False, + fill_kv_cache=True, + ) + prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair + if prefix_out is None: + raise RuntimeError( + "select_message: vlm_with_expert.forward returned no hidden states." + ) + + last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype) + logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V) + next_ids = self._sample_next_token(logits_step, temperature, top_p) + tok_id = int(next_ids[0].item()) + generated.append(tok_id) + if eos_token_id is not None and tok_id == eos_token_id: + break + + new_emb = self.model.vlm_with_expert.embed_language_tokens( + next_ids.unsqueeze(0) + ) + new_emb = new_emb * text_emb_scale + current_embs = torch.cat([current_embs, new_emb], dim=1) + current_pad = torch.cat([current_pad, ones_step], dim=1) + current_att = torch.cat([current_att, ones_step], dim=1) + + return tokenizer.decode(generated, skip_special_tokens=True).strip() @staticmethod def _sample_next_token(