fix(smolvla2): handle batched sample indices in chat tokenizer

Normalize tensor and sequence sample indices before prompt dropout so distributed batched preprocessing does not try to cast full index tensors to scalars.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-12 16:56:13 +00:00
parent 4908433f9a
commit bfd3bb1791
@@ -122,19 +122,8 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
# absent (e.g. inference) we fall back to 0 which is harmless
# because the dropout probs are also 0 at inference time.
sample_idx_raw = comp.get("index")
if hasattr(sample_idx_raw, "item"):
try:
sample_idx_raw = sample_idx_raw.item()
except Exception: # noqa: BLE001
pass
if _is_batched_messages(messages):
indices_iter = (
sample_idx_raw
if isinstance(sample_idx_raw, (list, tuple))
else [sample_idx_raw] * len(messages)
)
indices_iter = _sample_indices(comp.get("index"), len(messages))
encoded = [
self._encode_messages(
tokenizer,
@@ -152,13 +141,14 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
)
]
else:
sample_idx = _sample_indices(comp.get("index"), 1)[0]
encoded = [
self._encode_messages(
tokenizer,
messages,
list(comp.get("message_streams") or []),
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None,
sample_idx=sample_idx,
)
]
@@ -396,6 +386,21 @@ def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
"""Best-effort classification of which recipe binding contributed
to this message, used for per-component dropout.