fix(smolvla2): bool attention mask + clean Claude-Code-style REPL

Two issues that combined to make the REPL unusable:

1. ``BatchEncoding.attention_mask`` is a ``Long`` tensor, but SmolVLA's
   ``eager_attention_forward`` does
   ``torch.where(attention_mask[..., None, :, :], ...)`` which
   requires a *bool* condition. Every forward raised ``where expected
   condition to be a boolean tensor, but got a tensor with dtype Long``
   and the diagnostic surfaced it cleanly in the REPL — but generation
   produced nothing useful. Cast to ``bool`` in ``_build_text_batch``
   so the prefix forward goes through.

2. The interactive REPL used ``rich.live.Live`` panels stacked on top
   of ``logging.basicConfig(level=DEBUG)`` HTTP request lines from
   ``httpcore`` / ``httpx`` / ``huggingface_hub``. The two rendering
   loops fought each other in the user's terminal and the output was
   illegible: hundreds of debug lines interleaved with re-rendered
   panels.

   Replace ``Live`` with a simple block redraw — clear screen, print
   the state block, print any robot log lines, then a single ``> ``
   prompt. State changes are visible above the prompt, the way Claude
   Code's REPL renders. No flicker, no re-render races.

   ``_silence_noisy_loggers`` drops the chatty third-party HTTP /
   download / model-init loggers to WARNING. ``-v`` still enables
   DEBUG on the lerobot loggers; if the user needs the HTTP traces,
   they can flip those individually.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 12:03:47 +02:00
parent 0fb5f04965
commit 2776b57c9e
2 changed files with 129 additions and 64 deletions
@@ -211,6 +211,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
attn = torch.tensor(attn, dtype=torch.long)
if attn.ndim == 1:
attn = attn.unsqueeze(0)
# SmolVLA's ``eager_attention_forward`` does
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
# requires a *bool* condition tensor; ``BatchEncoding``'s
# attention_mask is typically Long (0/1). Cast so the prefix
# forward doesn't blow up with ``where expected condition to be a
# boolean tensor, but got a tensor with dtype Long``.
if attn is not None and hasattr(attn, "dtype"):
import torch as _torch # noqa: PLC0415
if attn.dtype != _torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
+118 -64
View File
@@ -320,12 +320,43 @@ def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
return {}
def _silence_noisy_loggers() -> None:
"""Drop chatty third-party loggers down to WARNING.
HuggingFace / httpx / urllib3 emit one log line per HTTP request,
which the REPL has to print between the state block and the
prompt — completely unreadable. We never need that detail in the
REPL and the user can opt back into it via ``-v`` (verbose mode
keeps DEBUG on the lerobot loggers but still gates the noisy ones
here unless they explicitly want them).
"""
for name in (
"httpcore",
"httpcore.connection",
"httpcore.http11",
"httpcore.proxy",
"httpx",
"urllib3",
"urllib3.connectionpool",
"huggingface_hub",
"huggingface_hub.repocard",
"huggingface_hub.file_download",
"transformers",
"transformers.modeling_utils",
"transformers.tokenization_utils_base",
"datasets",
"filelock",
):
logging.getLogger(name).setLevel(logging.WARNING)
def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv)
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
_silence_noisy_loggers()
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
@@ -385,15 +416,17 @@ def main(argv: list[str] | None = None) -> int:
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
"""Two-zone TUI: chat scrollback above a persistent state panel.
"""Claude-Code-style block REPL.
Uses :class:`rich.live.Live` to keep the state panel rendered
below the chat history. ``console.input`` for the prompt — Live
auto-suspends repaint while the user is typing.
Each turn redraws a status block (task / subtask / plan / memory)
at the top, prints any robot log lines that came in since the last
turn, then asks for input on a clean ``> `` prompt at the bottom.
No live region, no panel re-renders, no rendering races with HTTP
log lines — just clear-screen + reprint each turn, the way a
chat-style REPL is meant to look.
"""
try:
from rich.console import Console # noqa: PLC0415
from rich.live import Live # noqa: PLC0415
except ImportError:
print(
"[smolvla2] rich is required for the interactive REPL. "
@@ -402,71 +435,92 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
)
return 2
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
make_state_panel,
print_robot_lines,
print_user_line,
)
console = Console(highlight=False)
console = Console()
console.print(
"[bold]SmolVLA2[/] ready. "
"Type a task to begin, then any line for an interjection, "
"a line ending in '?' for VQA, or 'stop' to exit.",
)
def _redraw(robot_lines: list[str] | None = None) -> None:
# ANSI clear screen + home cursor. Falls back gracefully on
# dumb terminals — they just see scrolled output, which is
# fine.
console.clear()
console.rule("[bold]SmolVLA2[/] · dry-run", style="cyan")
st = runtime.state
for key, label in (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
):
value = st.get(key)
if value:
console.print(f" [bold cyan]{label:<8}[/] {value}")
else:
console.print(f" [dim]{label:<8} (not set)[/]")
queue_len = (
len(st["action_queue"])
if isinstance(st.get("action_queue"), (list, tuple))
or hasattr(st.get("action_queue"), "__len__")
else 0
)
pending = len(st.get("tool_calls_pending") or [])
console.print(
f" [dim]queued actions: {queue_len} pending tool calls: {pending}[/]"
)
console.rule(style="cyan")
if robot_lines:
for line in robot_lines:
console.print(f" [magenta]{line.strip()}[/]")
console.print()
# Help line under the divider when nothing is set yet.
if not st.get("task"):
console.print(
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
"anything else is an interjection. Type 'stop' to exit.[/]"
)
last_logs: list[str] = []
_redraw()
if initial_task is None:
console.print("[dim]No --task provided; first stdin line will be used.[/]")
panel = make_state_panel(runtime.state)
# Already shown the help line in _redraw when task is None.
pass
ticks_done = 0
with Live(
panel,
console=console,
refresh_per_second=4,
transient=False,
screen=False,
) as live:
try:
while True:
try:
line = console.input("[bold cyan]>[/] ").strip()
except EOFError:
break
if not line:
continue
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
break
try:
while True:
try:
line = console.input("[bold cyan]> [/]").strip()
except EOFError:
break
if not line:
_redraw(last_logs)
continue
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
break
print_user_line(live.console, line)
# Inject the user input as the right kind of event,
# then run a single pipeline tick to consume it.
if not runtime.state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
runtime.set_task(task)
elif lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_vqa_query"
)
else:
runtime.state["recent_interjection"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_interjection"
)
# Inject the user input as the right kind of event,
# then run a single pipeline tick to consume it.
if not runtime.state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
runtime.set_task(task)
elif lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_vqa_query"
)
else:
runtime.state["recent_interjection"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_interjection"
)
last_logs = runtime.step_once() or []
_redraw(last_logs)
logs = runtime.step_once()
if logs:
print_robot_lines(live.console, logs)
live.update(make_state_panel(runtime.state))
ticks_done += 1
if max_ticks is not None and ticks_done >= max_ticks:
break
except KeyboardInterrupt:
console.print("\n[smolvla2] interrupted", style="dim")
print("[smolvla2] runtime stopped", flush=True)
ticks_done += 1
if max_ticks is not None and ticks_done >= max_ticks:
break
except KeyboardInterrupt:
console.print("\n[dim]interrupted[/]")
console.print("[dim]runtime stopped[/]")
return 0