mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 22:27:14 +00:00
fix(smolvla2): handle batched sample indices in chat tokenizer
Normalize tensor and sequence sample indices before prompt dropout so distributed batched preprocessing does not try to cast full index tensors to scalars. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -122,19 +122,8 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
|
||||
# absent (e.g. inference) we fall back to 0 which is harmless
|
||||
# because the dropout probs are also 0 at inference time.
|
||||
sample_idx_raw = comp.get("index")
|
||||
if hasattr(sample_idx_raw, "item"):
|
||||
try:
|
||||
sample_idx_raw = sample_idx_raw.item()
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = (
|
||||
sample_idx_raw
|
||||
if isinstance(sample_idx_raw, (list, tuple))
|
||||
else [sample_idx_raw] * len(messages)
|
||||
)
|
||||
indices_iter = _sample_indices(comp.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
@@ -152,13 +141,14 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(comp.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(comp.get("message_streams") or []),
|
||||
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
||||
sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None,
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -396,6 +386,21 @@ def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Best-effort classification of which recipe binding contributed
|
||||
to this message, used for per-component dropout.
|
||||
|
||||
Reference in New Issue
Block a user