fix(smolvla2): make silent generation failures visible in REPL

Two failure modes were combining to make the runtime "look dead":

1. ``_build_text_batch`` produced lang tokens via
   ``apply_chat_template(return_tensors='pt')`` on CPU, but the policy
   sits on the configured device (mps / cuda). The first prefix-embed
   inside ``select_message`` then raised a device-mismatch error on
   every call. The broad ``except Exception`` in
   ``_generate_with_policy`` swallowed it at debug level — no logs, no
   chat output, no visible sign anything had run.

2. Even when generation succeeded but returned an empty string
   (greedy EOS, unhappy chat template, etc.), the high-level steps
   silently no-op'd, so users saw nothing.

Move tokens to ``policy.config.device`` in ``_build_text_batch`` so
the prefix forward succeeds in the common case. Raise the log level of
the swallowed exception from ``debug`` to ``warning`` (with a traceback
attached when debug logging is enabled, e.g. under ``-v``), and
when ``state`` is given route the same diagnostic into the REPL log
via ``push_log`` so the user sees ``[warn] subtask gen failed: ...``
inline. Also push an ``[info] ... produced no text this tick`` line
when generation runs but yields nothing, so empty completions are
distinguishable from "step never ran". Surface
``LowLevelForward.select_action`` failures the same way.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 11:47:34 +02:00
parent 9cbbcfb6a2
commit 7296ac97af
@@ -118,7 +118,8 @@ class LowLevelForward(InferenceStep):
try:
action = self.policy.select_action(observation)
except Exception as exc: # noqa: BLE001
logger.debug("select_action skipped: %s", exc)
logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}")
return None
# SmolVLA returns a single action; if the underlying policy
# streams chunks, split per-step here. For v1 we just enqueue
@@ -181,6 +182,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
if hasattr(ids, "ndim") and ids.ndim == 1:
ids = ids.unsqueeze(0)
attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
@@ -208,12 +220,16 @@ class HighLevelSubtaskFwd(InferenceStep):
return None
ctx = _control_context_messages(state)
observation = _maybe_observation(self.observation_provider)
msg = _generate_with_policy(self.policy, ctx, observation=observation)
msg = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="subtask gen"
)
if msg:
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
# Subtask change is a downstream trigger.
state.setdefault("events_this_tick", []).append("subtask_change")
else:
push_log(state, " [info] subtask gen produced no text this tick")
return None
@@ -231,7 +247,9 @@ class MemoryUpdateFwd(InferenceStep):
return None
ctx = _control_context_messages(state, include_completed=True)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(self.policy, ctx, observation=observation)
new_memory = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="memory gen"
)
if new_memory:
set_if_changed(state, "current_memory", new_memory, label="memory")
return None
@@ -253,8 +271,11 @@ class UserInterjectionFwd(InferenceStep):
extra_user=state.get("recent_interjection"),
)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(self.policy, ctx, observation=observation)
out = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
)
if not out:
push_log(state, " [info] plan/say gen produced no text this tick")
return None
# Heuristic split: model is trained to emit one assistant turn
# carrying both plan text AND a `say` tool call. Look for a
@@ -293,7 +314,9 @@ class AskVQAFwd(InferenceStep):
return None
ctx = _control_context_messages(state, extra_user=question)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(self.policy, ctx, observation=observation)
answer = _generate_with_policy(
self.policy, ctx, observation=observation, state=state, label="vqa gen"
)
if answer:
push_log(state, f" vqa: {answer}")
state["recent_vqa_query"] = None
@@ -384,6 +407,8 @@ def _generate_with_policy(
messages: list[dict[str, Any]],
*,
observation: dict | None = None,
state: dict[str, Any] | None = None,
label: str = "select_message",
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
@@ -393,8 +418,15 @@ def _generate_with_policy(
on. Without an observation the runtime falls back to a text-only
prompt — the text head still runs, but generations may drift from
the training distribution.
Failures are surfaced both to the module logger (``warning``) and,
when ``state`` is given, to the runtime's user-visible log via
:func:`push_log`, so the REPL no longer "looks dead" when
something goes wrong inside generation.
"""
if not hasattr(policy, "select_message"):
if state is not None:
push_log(state, f" [warn] policy has no select_message — skipping {label}")
return ""
text_batch = _build_text_batch(policy, messages)
try:
@@ -413,7 +445,9 @@ def _generate_with_policy(
batch[k] = v
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
except Exception as exc: # noqa: BLE001
logger.debug("select_message fell back: %s", exc)
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None:
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
return ""