diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 89d4a1c7a..f4d0ebf6b 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -303,79 +303,65 @@ class SmolVLA2Policy(SmolVLAPolicy): if eos_token_id is None: eos_token_id = tokenizer.eos_token_id - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) + # AR text generation through the underlying SmolVLM rather + # than ``vlm_with_expert.forward``. The latter is built around + # the action-expert decode pattern (one-shot suffix forward + # against a cached prefix, with cross-attn layers that + # *require* an expert input on every call) — it isn't a + # general-purpose AR text decoder. ``vlm.generate`` runs the + # SmolVLM exactly the way HuggingFace inference does it + # everywhere else, so KV caching, beam/greedy/sampling logic, + # and EOS handling all just work. + # + # Trade-off: ``state`` is dropped from the prefix at inference + # time (no slot for it on the standard SmolVLM path), so + # generations may drift from training distribution slightly. + # That's acceptable for the dry-run REPL. The high-level + # branches (subtask / plan / memory / vqa) are mostly + # vision+language conditioned anyway; the action expert is + # where state really matters. + vlm = self.model.vlm_with_expert.vlm lang_tokens = batch[OBS_LANGUAGE_TOKENS] lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - # Embed the (images + lang + state) prefix once. Image - # encoding is the expensive part of ``embed_prefix``, so doing - # it here and concatenating new-token embeddings into the same - # ``current_embs`` buffer lets us avoid re-running SigLIP on - # every decode step. - prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state + # Collect any image features the runtime merged in. SmolVLM + # expects ``pixel_values`` shape ``[B, num_images, C, H, W]``; + # individual ``observation.images.*`` features are typically + # ``[B, C, H, W]`` after the preprocessor, so stack them on a + # new ``num_images`` axis. + image_tensors: list[Tensor] = [] + for k, v in batch.items(): + if ( + isinstance(k, str) + and k.startswith("observation.images.") + and isinstance(v, Tensor) + ): + image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0)) + pixel_values = ( + torch.stack(image_tensors, dim=1) if image_tensors else None ) - device = prefix_embs.device - bsize = prefix_embs.shape[0] - vlm = self.model.vlm_with_expert.vlm - emb_dim = prefix_embs.shape[-1] - text_emb_scale = math.sqrt(emb_dim) + gen_kwargs: dict[str, Any] = { + "input_ids": lang_tokens, + "attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks, + "max_new_tokens": max_new_tokens, + "do_sample": temperature > 0, + "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, + } + if temperature > 0: + gen_kwargs["temperature"] = temperature + gen_kwargs["top_p"] = top_p + if eos_token_id is not None: + gen_kwargs["eos_token_id"] = eos_token_id + if pixel_values is not None: + gen_kwargs["pixel_values"] = pixel_values - # Cumulative buffers — the prefix at first, then grown by one - # token per decode step. The attention layer's only supported - # multi-step pattern is "pass the full embedded sequence each - # call with no KV cache" (the underlying - # ``vlm_with_expert.forward`` overwrites the cache instead of - # appending, so true incremental decoding isn't supported). - # This is O(n²) in the text length but matches the pattern - # ``denoise_step`` already uses successfully. - current_embs = prefix_embs - current_pad = prefix_pad_masks - current_att = prefix_att_masks - - # Pre-build a one-step mask append (a generated token always - # has ``pad=1`` and ``att=1`` — fully causal among generated - # tokens, attends back to the entire prefix). - ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) - - generated: list[int] = [] - for step in range(max_new_tokens): - full_2d = make_att_2d_masks(current_pad, current_att) - full_pos = torch.cumsum(current_pad, dim=1) - 1 - - out_pair, _ = self.model.vlm_with_expert.forward( - attention_mask=full_2d, - position_ids=full_pos, - past_key_values=None, - inputs_embeds=[current_embs, None], - use_cache=False, - fill_kv_cache=False, - ) - prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair - if prefix_out is None: - raise RuntimeError( - "select_message: vlm_with_expert.forward returned no hidden states." - ) - - last_hidden = prefix_out[:, -1:] - logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V) - next_ids = self._sample_next_token(logits_step, temperature, top_p) - tok_id = int(next_ids[0].item()) - generated.append(tok_id) - if eos_token_id is not None and tok_id == eos_token_id: - break - - new_emb = self.model.vlm_with_expert.embed_language_tokens( - next_ids.unsqueeze(0) - ) - new_emb = new_emb * text_emb_scale - current_embs = torch.cat([current_embs, new_emb], dim=1) - current_pad = torch.cat([current_pad, ones_step], dim=1) - current_att = torch.cat([current_att, ones_step], dim=1) - - return tokenizer.decode(generated, skip_special_tokens=True).strip() + gen_ids = vlm.generate(**gen_kwargs) + # ``vlm.generate`` returns the prompt + new tokens; slice off + # the prompt so the caller only sees the model's continuation. + prompt_len = lang_tokens.shape[1] + new_token_ids = gen_ids[0, prompt_len:].tolist() + return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip() @staticmethod def _sample_next_token(