mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
fix(smolvla2): handle BatchEncoding return from apply_chat_template
``tokenizer.apply_chat_template(..., tokenize=True, return_tensors='pt')``
on newer transformers returns a ``BatchEncoding`` (dict-like) rather
than a raw ``Tensor`` — particularly when the underlying call routes
through a processor. ``_build_text_batch`` only handled the ``Tensor``
and ``list`` shapes, so the encoding object reached SmolVLA's
``embed_language_tokens`` and ``F.embedding`` blew up with
``argument 'indices' must be Tensor, not BatchEncoding`` on every
high-level forward.
Normalise the return:
* ``BatchEncoding`` / ``dict`` → take ``input_ids`` (and the encoder's
``attention_mask`` when present, since ``pad_token_id`` can be
``None`` for SmolVLM and the fall-back ``ids != pad_token_id``
breaks then),
* ``list[int]`` / ``list[list[int]]`` → wrap in a long tensor,
* ``Tensor`` → keep as-is.
After unwrapping, ensure shape ``(1, seq)`` and that ``attention_mask``
is a tensor on the same device as ``ids``.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -171,17 +171,46 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
|
text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
|
||||||
ids = tokenizer.apply_chat_template(
|
encoded = tokenizer.apply_chat_template(
|
||||||
text_messages,
|
text_messages,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
# ``apply_chat_template`` can return any of:
|
||||||
|
# - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
|
||||||
|
# - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
|
||||||
|
# - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
|
||||||
|
# (newer transformers, especially via processor.apply_chat_template
|
||||||
|
# forwarding through here).
|
||||||
|
# Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
|
||||||
|
# attention mask when available so we don't have to re-derive it
|
||||||
|
# from ``pad_token_id`` (which can be ``None`` for SmolVLM).
|
||||||
|
attn: Any = None
|
||||||
|
if hasattr(encoded, "input_ids"):
|
||||||
|
ids = encoded.input_ids
|
||||||
|
attn = getattr(encoded, "attention_mask", None)
|
||||||
|
elif isinstance(encoded, dict) and "input_ids" in encoded:
|
||||||
|
ids = encoded["input_ids"]
|
||||||
|
attn = encoded.get("attention_mask")
|
||||||
|
else:
|
||||||
|
ids = encoded
|
||||||
if isinstance(ids, list):
|
if isinstance(ids, list):
|
||||||
ids = ids[0] if ids else []
|
if ids and isinstance(ids[0], list):
|
||||||
|
ids = ids[0]
|
||||||
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
|
ids = torch.tensor(ids, dtype=torch.long)
|
||||||
if hasattr(ids, "ndim") and ids.ndim == 1:
|
if hasattr(ids, "ndim") and ids.ndim == 1:
|
||||||
ids = ids.unsqueeze(0)
|
ids = ids.unsqueeze(0)
|
||||||
attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
|
if attn is None and tokenizer.pad_token_id is not None:
|
||||||
|
attn = ids != tokenizer.pad_token_id
|
||||||
|
elif isinstance(attn, list):
|
||||||
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
|
attn = torch.tensor(attn, dtype=torch.long)
|
||||||
|
if attn.ndim == 1:
|
||||||
|
attn = attn.unsqueeze(0)
|
||||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||||
# model), which the caller's broad except would swallow silently.
|
# model), which the caller's broad except would swallow silently.
|
||||||
|
|||||||
Reference in New Issue
Block a user