fix(smolvla2-runtime): match training visual distribution on robot frames

Root cause for the LM head's empty-completion symptom on the live robot
(while the same checkpoint produced sensible subtask/plan/memory in
``--no_robot`` dry-run on dataset frames): the camera observation was
flowing into the model at its native resolution. A Mac/USB webcam
hands us 1280×720 or 1920×1080; the dataset was recorded at the
feature schema's ``observation.images.*['shape']`` resolution
(typically 480×640). SmolVLA's internal ``resize_with_pad(512, 512)``
*does* fit both — but with very different pad geometry, so visual
tokens at each tile carry different content than at training. Action
expert tolerates this; the tightly-supervised LM head goes OOD and
the head's distribution at position 0 collapses to its dominant mode
(``\n`` ×N then ``<end_of_utterance>`` for this checkpoint).

The fix: in ``_build_robot_observation_provider``, pre-compute the
camera-key → (H, W) target from ``ds_features`` and ``cv2.resize``
each live frame to that shape before tensorising. The downstream
``resize_with_pad`` then sees the same input geometry as training and
the LM head returns to producing readable subtask text under plain
greedy decoding — the same as dry-run.

Also drops the inference-time patches (``min_new_tokens``,
``temperature``, ``top_p`` overrides) on the four high-level callers.
They were band-aids around the visual-distribution shift, not a real
LM problem, and they drift inference off the training distribution.
Greedy argmax is what training matched. The ``select_message``
signature still accepts the knobs for callers that want them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 17:59:24 +02:00
parent 1292304c42
commit ab5c1dc392
2 changed files with 72 additions and 20 deletions
@@ -365,23 +365,24 @@ class HighLevelSubtaskFwd(InferenceStep):
return None
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Force the head to commit to ≥ 5 real tokens before it can
# close the turn, and sample at moderate temperature with
# nucleus filtering. On a memorised head whose argmax at
# position 0 is EOS, greedy decoding silently produced empty
# completions every chunk boundary (visible as the
# ``empty:N`` counter climbing). Temp 0.4 + top_p 0.9 is well
# below where SmolVLM goes incoherent and above where greedy
# collapse re-emerges.
# Match training: greedy argmax, no min_new_tokens, no
# special-token suppression. Earlier experiments forced
# min_new_tokens=5 + sampling because the LM head was
# collapsing to EOS at position 0 — but that turned out to
# be a visual-distribution shift (camera frames being fed
# at the camera's native resolution rather than the
# dataset's recorded resolution), not a head pathology.
# With the camera frame resized to the dataset's
# ``ds_features['observation.images.*']['shape']`` shape,
# the visual prefix is back on-distribution and the same
# greedy decoding that works in ``--no_robot`` dry-run also
# works on the live robot.
msg = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=5,
temperature=0.4,
top_p=0.9,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
@@ -474,9 +475,6 @@ class MemoryUpdateFwd(InferenceStep):
observation=observation,
state=state,
label="memory gen",
min_new_tokens=5,
temperature=0.4,
top_p=0.9,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
@@ -520,9 +518,6 @@ class UserInterjectionFwd(InferenceStep):
observation=observation,
state=state,
label="plan/say gen",
min_new_tokens=10,
temperature=0.5,
top_p=0.9,
)
if not out:
# Don't log every empty completion — happens repeatedly on
@@ -592,9 +587,6 @@ class AskVQAFwd(InferenceStep):
observation=observation,
state=state,
label="vqa gen",
min_new_tokens=3,
temperature=0.4,
top_p=0.9,
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
@@ -594,6 +594,40 @@ def _build_robot_observation_provider(
getattr(robot, "config", None), "type", None
)
# Pre-compute the camera-key → target (H, W) map from
# ``ds_features``. The training distribution sees frames at the
# recorded resolution (e.g. 480×640); a live Mac/USB camera will
# almost always hand us a different native size (720p / 1080p).
# SmolVLA's internal ``resize_with_pad(512, 512)`` does pad the
# input to a fixed canvas, but the *geometry* of that pad differs
# by input aspect ratio — top/left padding varies, so the visual
# tokens at each tile carry different content than what the model
# saw at training. The action expert tolerates this (flow head
# rides broad geometry); the LM head, supervised much more
# tightly on visual features, goes out of distribution and the
# head's distribution at position 0 collapses to its dominant
# mode (a memorised ``\n``-only run in this checkpoint).
target_image_shapes: dict[str, tuple[int, int]] = {}
if ds_features:
for fkey, fmeta in ds_features.items():
if not isinstance(fmeta, dict):
continue
dtype = fmeta.get("dtype")
if dtype not in ("image", "video"):
continue
shape = fmeta.get("shape")
if not shape or len(shape) != 3:
continue
names = fmeta.get("names") or []
# Feature schema stores either (H, W, C) or (C, H, W);
# disambiguate by the ``names`` ordering when present.
if names and len(names) == 3 and names[0] == "channels":
_, h, w = shape
else:
h, w, _ = shape
cam_key = fkey.removeprefix("observation.images.")
target_image_shapes[cam_key] = (int(h), int(w))
def _provider() -> dict | None:
try:
raw = robot.get_observation()
@@ -606,6 +640,32 @@ def _build_robot_observation_provider(
for k in ("language_persistent", "language_events"):
raw.pop(k, None)
# Force-match the training-time visual distribution:
# every camera frame the model trained on came from the
# dataset at its recorded (H, W). Resize the live frame to
# that exact shape so the downstream resize_with_pad geometry
# matches training. Without this the LM head is OOD on every
# tick.
if target_image_shapes:
try:
import cv2 as _cv2 # noqa: PLC0415
import numpy as _np # noqa: PLC0415
for cam_key, (target_h, target_w) in target_image_shapes.items():
img = raw.get(cam_key)
if img is None or not isinstance(img, _np.ndarray):
continue
if img.ndim != 3:
continue
cur_h, cur_w = img.shape[:2]
if (cur_h, cur_w) == (target_h, target_w):
continue
raw[cam_key] = _cv2.resize(
img, (target_w, target_h), interpolation=_cv2.INTER_AREA
)
except Exception as exc: # noqa: BLE001
logger.warning("camera resize to dataset shape failed: %s", exc)
try:
if ds_features:
# Use the dataset's feature schema to pick the right