fix(smolvla2): handle BatchEncoding return from apply_chat_template
``tokenizer.apply_chat_template(..., tokenize=True, return_tensors='pt')``
on newer transformers returns a ``BatchEncoding`` (dict-like) rather
than a raw ``Tensor`` — particularly when the underlying call routes
through a processor. ``_build_text_batch`` only handled the ``Tensor``
and ``list`` shapes, so the encoding object reached SmolVLA's
``embed_language_tokens`` and ``F.embedding`` blew up with
``argument 'indices' must be Tensor, not BatchEncoding`` on every
high-level forward.
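For illustration only (a toy repro, not code from the repo): ``F.embedding`` wants a tensor of token indices, so the dict-like encoding has to be unwrapped before it reaches the embedding layer.

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4)             # toy embedding table
    ids = torch.tensor([[1, 2, 3]])         # the shape the embedding layer expects
    print(F.embedding(ids, weight).shape)   # torch.Size([1, 3, 4])

    encoded = {"input_ids": ids, "attention_mask": torch.ones_like(ids)}
    # F.embedding(encoded, weight)          # TypeError: indices must be a Tensor
    print(F.embedding(encoded["input_ids"], weight).shape)  # unwrap first, then embed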
Normalise the return:
* ``BatchEncoding`` / ``dict`` → take ``input_ids`` (and the encoder's
  ``attention_mask`` when present, since ``pad_token_id`` can be
  ``None`` for SmolVLM, in which case the ``ids != pad_token_id``
  fallback breaks),
* ``list[int]`` / ``list[list[int]]`` → wrap in a long tensor,
* ``Tensor`` → keep as-is.
After unwrapping, ensure shape ``(1, seq)`` and that ``attention_mask``
is a tensor on the same device as ``ids``.
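As a standalone sketch of that normalisation (``normalise_chat_ids`` is an illustrative name, not the repo's; the real change lives in ``_build_text_batch`` below, and this version assumes the dict-like input supports item access and ``.get``):

    from typing import Any

    import torch


    def normalise_chat_ids(encoded: Any, pad_token_id: int | None):
        attn = None
        if hasattr(encoded, "input_ids") or (isinstance(encoded, dict) and "input_ids" in encoded):
            # BatchEncoding / dict: both expose item access and .get()
            attn = encoded.get("attention_mask")
            ids = encoded["input_ids"]
        else:
            ids = encoded
        if isinstance(ids, list):
            if ids and isinstance(ids[0], list):  # list[list[int]] -> first (only) row
                ids = ids[0]
            ids = torch.tensor(ids, dtype=torch.long)
        if ids.ndim == 1:                         # (seq,) -> (1, seq)
            ids = ids.unsqueeze(0)
        if attn is None and pad_token_id is not None:
            attn = ids != pad_token_id            # fallback only when padding is defined
        if isinstance(attn, list):
            attn = torch.tensor(attn, dtype=torch.long)
        if attn is not None and attn.ndim == 1:
            attn = attn.unsqueeze(0)
        return ids, attn

    ids, attn = normalise_chat_ids([101, 2023, 102], pad_token_id=None)
    assert ids.shape == (1, 3) and attn is None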
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -171,17 +171,46 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
         tokenizer.pad_token = tokenizer.eos_token
 
     text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
-    ids = tokenizer.apply_chat_template(
+    encoded = tokenizer.apply_chat_template(
         text_messages,
         add_generation_prompt=True,
         tokenize=True,
         return_tensors="pt",
     )
+    # ``apply_chat_template`` can return any of:
+    # - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
+    # - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
+    # - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
+    #   (newer transformers, especially via processor.apply_chat_template
+    #   forwarding through here).
+    # Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
+    # attention mask when available so we don't have to re-derive it
+    # from ``pad_token_id`` (which can be ``None`` for SmolVLM).
+    attn: Any = None
+    if hasattr(encoded, "input_ids"):
+        ids = encoded.input_ids
+        attn = getattr(encoded, "attention_mask", None)
+    elif isinstance(encoded, dict) and "input_ids" in encoded:
+        ids = encoded["input_ids"]
+        attn = encoded.get("attention_mask")
+    else:
+        ids = encoded
+    if isinstance(ids, list):
+        if ids and isinstance(ids[0], list):
+            ids = ids[0] if ids else []
+        import torch  # noqa: PLC0415
+
+        ids = torch.tensor(ids, dtype=torch.long)
+    if hasattr(ids, "ndim") and ids.ndim == 1:
+        ids = ids.unsqueeze(0)
-    attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
+    if attn is None and tokenizer.pad_token_id is not None:
+        attn = ids != tokenizer.pad_token_id
+    elif isinstance(attn, list):
+        import torch  # noqa: PLC0415
+
+        attn = torch.tensor(attn, dtype=torch.long)
+    if attn is not None and attn.ndim == 1:
+        attn = attn.unsqueeze(0)
+    # Move tokens onto the policy's device — otherwise prefix embedding
+    # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
+    # model), which the caller's broad except would swallow silently.
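The hunk above stops at the device comment, so the move itself isn't visible here. A minimal sketch of that step, assuming the policy exposes its parameters like an ``nn.Module`` (``move_text_batch`` is a hypothetical helper, not the repo's code):

    import torch
    from torch import nn


    def move_text_batch(ids: torch.Tensor, attn: torch.Tensor | None, model: nn.Module):
        # Resolve the model's device from its parameters and send both tensors
        # there, so the prefix embedding sees matching devices.
        device = next(model.parameters()).device
        ids = ids.to(device)
        attn = attn.to(device) if attn is not None else None
        return ids, attn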