fix(smolvla2): bool attention mask + clean Claude-Code-style REPL

Two issues that combined to make the REPL unusable:

1. ``BatchEncoding.attention_mask`` is a ``Long`` tensor, but SmolVLA's
   ``eager_attention_forward`` does
   ``torch.where(attention_mask[..., None, :, :], ...)`` which
   requires a *bool* condition. Every forward raised ``where expected
   condition to be a boolean tensor, but got a tensor with dtype Long``
   and the diagnostic surfaced it cleanly in the REPL — but generation
   produced nothing useful. Cast to ``bool`` in ``_build_text_batch``
   so the prefix forward goes through.

2. The interactive REPL used ``rich.live.Live`` panels stacked on top
   of ``logging.basicConfig(level=DEBUG)`` HTTP request lines from
   ``httpcore`` / ``httpx`` / ``huggingface_hub``. The two rendering
   loops fought each other in the user's terminal and the output was
   illegible: hundreds of debug lines interleaved with re-rendered
   panels.

   Replace ``Live`` with a simple block redraw — clear screen, print
   the state block, print any robot log lines, then a single ``> ``
   prompt. State changes are visible above the prompt, the way Claude
   Code's REPL renders. No flicker, no re-render races.

   ``_silence_noisy_loggers`` drops the chatty third-party HTTP /
   download / model-init loggers to WARNING. ``-v`` still enables
   DEBUG on the lerobot loggers; if the user needs the HTTP traces,
   they can flip those individually.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 12:03:47 +02:00
parent 0fb5f04965
commit 2776b57c9e
2 changed files with 129 additions and 64 deletions
@@ -211,6 +211,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
attn = torch.tensor(attn, dtype=torch.long) attn = torch.tensor(attn, dtype=torch.long)
if attn.ndim == 1: if attn.ndim == 1:
attn = attn.unsqueeze(0) attn = attn.unsqueeze(0)
# SmolVLA's ``eager_attention_forward`` does
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
# requires a *bool* condition tensor; ``BatchEncoding``'s
# attention_mask is typically Long (0/1). Cast so the prefix
# forward doesn't blow up with ``where expected condition to be a
# boolean tensor, but got a tensor with dtype Long``.
if attn is not None and hasattr(attn, "dtype"):
import torch as _torch # noqa: PLC0415
if attn.dtype != _torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding # Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently. # model), which the caller's broad except would swallow silently.
+118 -64
View File
@@ -320,12 +320,43 @@ def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
return {} return {}
def _silence_noisy_loggers() -> None:
    """Cap chatty third-party loggers at WARNING.

    HuggingFace / httpx / urllib3 emit one log line per HTTP request.
    In the REPL those lines land between the state block and the
    prompt and make the output unreadable, so they are gated here.

    The function takes no arguments, so the cap applies regardless of
    the level configured on the root logger. A caller who really wants
    the HTTP traces has to lower the individual loggers again after
    this function has run.

    A logger that already carries a stricter explicit level (ERROR or
    CRITICAL) is left at that level: this function only ever reduces
    verbosity, it never increases it.
    """
    # Children are named alongside their parents: a level set directly
    # on a child takes precedence over the parent's, so capping only
    # the parent would not be enough in that case.
    noisy_names = (
        "httpcore",
        "httpcore.connection",
        "httpcore.http11",
        "httpcore.proxy",
        "httpx",
        "urllib3",
        "urllib3.connectionpool",
        "huggingface_hub",
        "huggingface_hub.repocard",
        "huggingface_hub.file_download",
        "transformers",
        "transformers.modeling_utils",
        "transformers.tokenization_utils_base",
        "datasets",
        "filelock",
    )
    for name in noisy_names:
        target = logging.getLogger(name)
        # NOTSET is 0, so an unconfigured logger (or one at DEBUG/INFO)
        # becomes WARNING, while ERROR/CRITICAL are preserved.
        target.setLevel(max(target.level, logging.WARNING))
def main(argv: list[str] | None = None) -> int: def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv) args = _parse_args(argv)
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO, level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s", format="%(asctime)s %(levelname)s %(message)s",
) )
_silence_noisy_loggers()
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True) print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor( policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
@@ -385,15 +416,17 @@ def main(argv: list[str] | None = None) -> int:
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int: def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
"""Two-zone TUI: chat scrollback above a persistent state panel. """Claude-Code-style block REPL.
Uses :class:`rich.live.Live` to keep the state panel rendered Each turn redraws a status block (task / subtask / plan / memory)
below the chat history. ``console.input`` for the prompt Live at the top, prints any robot log lines that came in since the last
auto-suspends repaint while the user is typing. turn, then asks for input on a clean ``> `` prompt at the bottom.
No live region, no panel re-renders, no rendering races with HTTP
log lines just clear-screen + reprint each turn, the way a
chat-style REPL is meant to look.
""" """
try: try:
from rich.console import Console # noqa: PLC0415 from rich.console import Console # noqa: PLC0415
from rich.live import Live # noqa: PLC0415
except ImportError: except ImportError:
print( print(
"[smolvla2] rich is required for the interactive REPL. " "[smolvla2] rich is required for the interactive REPL. "
@@ -402,71 +435,92 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
) )
return 2 return 2
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415 console = Console(highlight=False)
make_state_panel,
print_robot_lines,
print_user_line,
)
console = Console() def _redraw(robot_lines: list[str] | None = None) -> None:
console.print( # ANSI clear screen + home cursor. Falls back gracefully on
"[bold]SmolVLA2[/] ready. " # dumb terminals — they just see scrolled output, which is
"Type a task to begin, then any line for an interjection, " # fine.
"a line ending in '?' for VQA, or 'stop' to exit.", console.clear()
) console.rule("[bold]SmolVLA2[/] · dry-run", style="cyan")
st = runtime.state
for key, label in (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
):
value = st.get(key)
if value:
console.print(f" [bold cyan]{label:<8}[/] {value}")
else:
console.print(f" [dim]{label:<8} (not set)[/]")
queue_len = (
len(st["action_queue"])
if isinstance(st.get("action_queue"), (list, tuple))
or hasattr(st.get("action_queue"), "__len__")
else 0
)
pending = len(st.get("tool_calls_pending") or [])
console.print(
f" [dim]queued actions: {queue_len} pending tool calls: {pending}[/]"
)
console.rule(style="cyan")
if robot_lines:
for line in robot_lines:
console.print(f" [magenta]{line.strip()}[/]")
console.print()
# Help line under the divider when nothing is set yet.
if not st.get("task"):
console.print(
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
"anything else is an interjection. Type 'stop' to exit.[/]"
)
last_logs: list[str] = []
_redraw()
if initial_task is None: if initial_task is None:
console.print("[dim]No --task provided; first stdin line will be used.[/]") # Already shown the help line in _redraw when task is None.
pass
panel = make_state_panel(runtime.state)
ticks_done = 0 ticks_done = 0
with Live( try:
panel, while True:
console=console, try:
refresh_per_second=4, line = console.input("[bold cyan]> [/]").strip()
transient=False, except EOFError:
screen=False, break
) as live: if not line:
try: _redraw(last_logs)
while True: continue
try: lower = line.lower()
line = console.input("[bold cyan]>[/] ").strip() if lower in {"stop", "quit", "exit"}:
except EOFError: break
break
if not line:
continue
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
break
print_user_line(live.console, line) # Inject the user input as the right kind of event,
# then run a single pipeline tick to consume it.
if not runtime.state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
runtime.set_task(task)
elif lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_vqa_query"
)
else:
runtime.state["recent_interjection"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_interjection"
)
# Inject the user input as the right kind of event, last_logs = runtime.step_once() or []
# then run a single pipeline tick to consume it. _redraw(last_logs)
if not runtime.state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
runtime.set_task(task)
elif lower.endswith("?"):
runtime.state["recent_vqa_query"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_vqa_query"
)
else:
runtime.state["recent_interjection"] = line
runtime.state.setdefault("events_this_tick", []).append(
"user_interjection"
)
logs = runtime.step_once() ticks_done += 1
if logs: if max_ticks is not None and ticks_done >= max_ticks:
print_robot_lines(live.console, logs) break
live.update(make_state_panel(runtime.state)) except KeyboardInterrupt:
console.print("\n[dim]interrupted[/]")
ticks_done += 1 console.print("[dim]runtime stopped[/]")
if max_ticks is not None and ticks_done >= max_ticks:
break
except KeyboardInterrupt:
console.print("\n[smolvla2] interrupted", style="dim")
print("[smolvla2] runtime stopped", flush=True)
return 0 return 0