diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 8acb19d40..b0951c789 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -118,7 +118,8 @@ class LowLevelForward(InferenceStep): try: action = self.policy.select_action(observation) except Exception as exc: # noqa: BLE001 - logger.debug("select_action skipped: %s", exc) + logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) + push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}") return None # SmolVLA returns a single action; if the underlying policy # streams chunks, split per-step here. For v1 we just enqueue @@ -181,6 +182,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic if hasattr(ids, "ndim") and ids.ndim == 1: ids = ids.unsqueeze(0) attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None + # Move tokens onto the policy's device — otherwise prefix embedding + # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA + # model), which the caller's broad except would swallow silently. + device = getattr(getattr(policy, "config", None), "device", None) + if device is not None: + try: + ids = ids.to(device) + if attn is not None and hasattr(attn, "to"): + attn = attn.to(device) + except Exception as exc: # noqa: BLE001 + logger.debug("could not move lang tokens to %s: %s", device, exc) return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer} @@ -208,12 +220,16 @@ class HighLevelSubtaskFwd(InferenceStep): return None ctx = _control_context_messages(state) observation = _maybe_observation(self.observation_provider) - msg = _generate_with_policy(self.policy, ctx, observation=observation) + msg = _generate_with_policy( + self.policy, ctx, observation=observation, state=state, label="subtask gen" + ) if msg: changed = set_if_changed(state, "current_subtask", msg, label="subtask") if changed: # Subtask change is a downstream trigger. state.setdefault("events_this_tick", []).append("subtask_change") + else: + push_log(state, " [info] subtask gen produced no text this tick") return None @@ -231,7 +247,9 @@ class MemoryUpdateFwd(InferenceStep): return None ctx = _control_context_messages(state, include_completed=True) observation = _maybe_observation(self.observation_provider) - new_memory = _generate_with_policy(self.policy, ctx, observation=observation) + new_memory = _generate_with_policy( + self.policy, ctx, observation=observation, state=state, label="memory gen" + ) if new_memory: set_if_changed(state, "current_memory", new_memory, label="memory") return None @@ -253,8 +271,11 @@ class UserInterjectionFwd(InferenceStep): extra_user=state.get("recent_interjection"), ) observation = _maybe_observation(self.observation_provider) - out = _generate_with_policy(self.policy, ctx, observation=observation) + out = _generate_with_policy( + self.policy, ctx, observation=observation, state=state, label="plan/say gen" + ) if not out: + push_log(state, " [info] plan/say gen produced no text this tick") return None # Heuristic split: model is trained to emit one assistant turn # carrying both plan text AND a `say` tool call. Look for a @@ -293,7 +314,9 @@ class AskVQAFwd(InferenceStep): return None ctx = _control_context_messages(state, extra_user=question) observation = _maybe_observation(self.observation_provider) - answer = _generate_with_policy(self.policy, ctx, observation=observation) + answer = _generate_with_policy( + self.policy, ctx, observation=observation, state=state, label="vqa gen" + ) if answer: push_log(state, f" vqa: {answer}") state["recent_vqa_query"] = None @@ -384,6 +407,8 @@ def _generate_with_policy( messages: list[dict[str, Any]], *, observation: dict | None = None, + state: dict[str, Any] | None = None, + label: str = "select_message", ) -> str: """Drive ``policy.select_message`` with a chat batch (and optional obs). @@ -393,8 +418,15 @@ def _generate_with_policy( on. Without an observation the runtime falls back to a text-only prompt — the text head still runs, but generations may drift from the training distribution. + + Failures are surfaced both to the module logger (``warning``) and, + when ``state`` is given, to the runtime's user-visible log via + :func:`push_log`, so the REPL no longer "looks dead" when + something goes wrong inside generation. """ if not hasattr(policy, "select_message"): + if state is not None: + push_log(state, f" [warn] policy has no select_message — skipping {label}") return "" text_batch = _build_text_batch(policy, messages) try: @@ -413,7 +445,9 @@ def _generate_with_policy( batch[k] = v return policy.select_message(batch, tokenizer=text_batch["tokenizer"]) except Exception as exc: # noqa: BLE001 - logger.debug("select_message fell back: %s", exc) + logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG)) + if state is not None: + push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}") return ""