fix(smolvla2): handle BatchEncoding return from apply_chat_template

``tokenizer.apply_chat_template(..., tokenize=True, return_tensors='pt')``
on newer transformers returns a ``BatchEncoding`` (dict-like) rather
than a raw ``Tensor`` — particularly when the underlying call routes
through a processor. ``_build_text_batch`` only handled the ``Tensor``
and ``list`` shapes, so the encoding object reached SmolVLA's
``embed_language_tokens`` and ``F.embedding`` blew up with
``argument 'indices' must be Tensor, not BatchEncoding`` on every
high-level forward.

Normalise the return:
  * ``BatchEncoding`` / ``dict`` → take ``input_ids`` (and the encoder's
    ``attention_mask`` when present, since ``pad_token_id`` can be
    ``None`` for SmolVLM and the fall-back ``ids != pad_token_id``
    breaks then),
  * ``list[int]`` / ``list[list[int]]`` → wrap in a long tensor,
  * ``Tensor`` → keep as-is.

After unwrapping, ensure shape ``(1, seq)`` and that ``attention_mask``
is a tensor on the same device as ``ids``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 11:59:57 +02:00
parent 7296ac97af
commit 0fb5f04965
@@ -171,17 +171,46 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
tokenizer.pad_token = tokenizer.eos_token
text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
ids = tokenizer.apply_chat_template(
encoded = tokenizer.apply_chat_template(
text_messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
)
# ``apply_chat_template`` can return any of:
# - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
# - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
# - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
# (newer transformers, especially via processor.apply_chat_template
# forwarding through here).
# Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
# attention mask when available so we don't have to re-derive it
# from ``pad_token_id`` (which can be ``None`` for SmolVLM).
attn: Any = None
if hasattr(encoded, "input_ids"):
ids = encoded.input_ids
attn = getattr(encoded, "attention_mask", None)
elif isinstance(encoded, dict) and "input_ids" in encoded:
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
else:
ids = encoded
if isinstance(ids, list):
ids = ids[0] if ids else []
if ids and isinstance(ids[0], list):
ids = ids[0]
import torch # noqa: PLC0415
ids = torch.tensor(ids, dtype=torch.long)
if hasattr(ids, "ndim") and ids.ndim == 1:
ids = ids.unsqueeze(0)
attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
elif isinstance(attn, list):
import torch # noqa: PLC0415
attn = torch.tensor(attn, dtype=torch.long)
if attn.ndim == 1:
attn = attn.unsqueeze(0)
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.