mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(smolvla2): make silent generation failures visible in REPL
Two failure modes were combining to make the runtime "look dead": 1. ``_build_text_batch`` produced lang tokens via ``apply_chat_template(return_tensors='pt')`` on CPU, but the policy sits on the configured device (mps / cuda). The first prefix-embed inside ``select_message`` then raised a device-mismatch error on every call. The broad ``except Exception`` in ``_generate_with_policy`` swallowed it at debug level — no logs, no chat output, no visible sign anything had run. 2. Even when generation succeeded but returned an empty string (greedy EOS, unhappy chat template, etc.), the high-level steps silently no-op'd, so users saw nothing. Move tokens to ``policy.config.device`` in ``_build_text_batch`` so the prefix forward succeeds in the common case. Raise the log level of the swallowed exception to ``warning`` (with optional traceback under ``-v``), and when ``state`` is given route the same diagnostic into the REPL log via ``push_log`` so the user sees ``[warn] subtask gen failed: ...`` inline. Also push an ``[info] ... produced no text this tick`` line when generation runs but yields nothing, so empty completions are distinguishable from "step never ran". Surface ``LowLevelForward.select_action`` failures the same way. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -118,7 +118,8 @@ class LowLevelForward(InferenceStep):
|
||||
try:
|
||||
action = self.policy.select_action(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("select_action skipped: %s", exc)
|
||||
logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}")
|
||||
return None
|
||||
# SmolVLA returns a single action; if the underlying policy
|
||||
# streams chunks, split per-step here. For v1 we just enqueue
|
||||
@@ -181,6 +182,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
|
||||
if hasattr(ids, "ndim") and ids.ndim == 1:
|
||||
ids = ids.unsqueeze(0)
|
||||
attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
@@ -208,12 +220,16 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
return None
|
||||
ctx = _control_context_messages(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
msg = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
msg = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||
)
|
||||
if msg:
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
else:
|
||||
push_log(state, " [info] subtask gen produced no text this tick")
|
||||
return None
|
||||
|
||||
|
||||
@@ -231,7 +247,9 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
return None
|
||||
ctx = _control_context_messages(state, include_completed=True)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||
)
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
@@ -253,8 +271,11 @@ class UserInterjectionFwd(InferenceStep):
|
||||
extra_user=state.get("recent_interjection"),
|
||||
)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
out = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="plan/say gen"
|
||||
)
|
||||
if not out:
|
||||
push_log(state, " [info] plan/say gen produced no text this tick")
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
@@ -293,7 +314,9 @@ class AskVQAFwd(InferenceStep):
|
||||
return None
|
||||
ctx = _control_context_messages(state, extra_user=question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
answer = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||
)
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
@@ -384,6 +407,8 @@ def _generate_with_policy(
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
@@ -393,8 +418,15 @@ def _generate_with_policy(
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
|
||||
Failures are surfaced both to the module logger (``warning``) and,
|
||||
when ``state`` is given, to the runtime's user-visible log via
|
||||
:func:`push_log`, so the REPL no longer "looks dead" when
|
||||
something goes wrong inside generation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] policy has no select_message — skipping {label}")
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
try:
|
||||
@@ -413,7 +445,9 @@ def _generate_with_policy(
|
||||
batch[k] = v
|
||||
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("select_message fell back: %s", exc)
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user