diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index ddb08645e..e370a88e9 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -122,19 +122,8 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): # flows through into ``COMPLEMENTARY_DATA`` unchanged. When # absent (e.g. inference) we fall back to 0 which is harmless # because the dropout probs are also 0 at inference time. - sample_idx_raw = comp.get("index") - if hasattr(sample_idx_raw, "item"): - try: - sample_idx_raw = sample_idx_raw.item() - except Exception: # noqa: BLE001 - pass - if _is_batched_messages(messages): - indices_iter = ( - sample_idx_raw - if isinstance(sample_idx_raw, (list, tuple)) - else [sample_idx_raw] * len(messages) - ) + indices_iter = _sample_indices(comp.get("index"), len(messages)) encoded = [ self._encode_messages( tokenizer, @@ -152,13 +141,14 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): ) ] else: + sample_idx = _sample_indices(comp.get("index"), 1)[0] encoded = [ self._encode_messages( tokenizer, messages, list(comp.get("message_streams") or []), sorted(int(i) for i in (comp.get("target_message_indices") or [])), - sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None, + sample_idx=sample_idx, ) ] @@ -396,6 +386,21 @@ def _is_batched_messages(messages: Any) -> bool: return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) +def _sample_indices(value: Any, batch_size: int) -> list[int | None]: + if value is None: + return [None] * batch_size + if isinstance(value, torch.Tensor): + if value.numel() == 1: + return [int(value.item())] * batch_size + values = value.reshape(-1).tolist() + return [int(v) for v in values[:batch_size]] + if isinstance(value, (list, tuple)): + if len(value) == 1: + return _sample_indices(value[0], batch_size) + return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]] + return [int(value)] * batch_size + + def _classify_message_for_dropout(message: dict[str, Any]) -> str | None: """Best-effort classification of which recipe binding contributed to this message, used for per-component dropout.