diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index b6be411ef..015d03504 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -211,6 +211,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic attn = torch.tensor(attn, dtype=torch.long) if attn.ndim == 1: attn = attn.unsqueeze(0) + # SmolVLA's ``eager_attention_forward`` does + # ``torch.where(attention_mask[..., None, :, :], ...)`` which + # requires a *bool* condition tensor; ``BatchEncoding``'s + # attention_mask is typically Long (0/1). Cast so the prefix + # forward doesn't blow up with ``where expected condition to be a + # boolean tensor, but got a tensor with dtype Long``. + if attn is not None and hasattr(attn, "dtype"): + import torch as _torch # noqa: PLC0415 + + if attn.dtype != _torch.bool: + attn = attn.bool() # Move tokens onto the policy's device — otherwise prefix embedding # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA # model), which the caller's broad except would swallow silently. diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index f721c20d1..bd9ac8e7b 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -320,12 +320,43 @@ def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]: return {} +def _silence_noisy_loggers() -> None: + """Drop chatty third-party loggers down to WARNING. + + HuggingFace / httpx / urllib3 emit one log line per HTTP request, + which the REPL has to print between the state block and the + prompt — completely unreadable. We never need that detail in the + REPL and the user can opt back into it via ``-v`` (verbose mode + keeps DEBUG on the lerobot loggers but still gates the noisy ones + here unless they explicitly want them). + """ + for name in ( + "httpcore", + "httpcore.connection", + "httpcore.http11", + "httpcore.proxy", + "httpx", + "urllib3", + "urllib3.connectionpool", + "huggingface_hub", + "huggingface_hub.repocard", + "huggingface_hub.file_download", + "transformers", + "transformers.modeling_utils", + "transformers.tokenization_utils_base", + "datasets", + "filelock", + ): + logging.getLogger(name).setLevel(logging.WARNING) + + def main(argv: list[str] | None = None) -> int: args = _parse_args(argv) logging.basicConfig( level=logging.DEBUG if args.verbose else logging.INFO, format="%(asctime)s %(levelname)s %(message)s", ) + _silence_noisy_loggers() print(f"[smolvla2] loading policy from {args.policy_path}", flush=True) policy, preprocessor, _ds_meta = _load_policy_and_preprocessor( @@ -385,15 +416,17 @@ def main(argv: list[str] | None = None) -> int: def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int: - """Two-zone TUI: chat scrollback above a persistent state panel. + """Claude-Code-style block REPL. - Uses :class:`rich.live.Live` to keep the state panel rendered - below the chat history. ``console.input`` for the prompt — Live - auto-suspends repaint while the user is typing. + Each turn redraws a status block (task / subtask / plan / memory) + at the top, prints any robot log lines that came in since the last + turn, then asks for input on a clean ``> `` prompt at the bottom. + No live region, no panel re-renders, no rendering races with HTTP + log lines — just clear-screen + reprint each turn, the way a + chat-style REPL is meant to look. """ try: from rich.console import Console # noqa: PLC0415 - from rich.live import Live # noqa: PLC0415 except ImportError: print( "[smolvla2] rich is required for the interactive REPL. " @@ -402,71 +435,92 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) ) return 2 - from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415 - make_state_panel, - print_robot_lines, - print_user_line, - ) + console = Console(highlight=False) - console = Console() - console.print( - "[bold]SmolVLA2[/] ready. " - "Type a task to begin, then any line for an interjection, " - "a line ending in '?' for VQA, or 'stop' to exit.", - ) + def _redraw(robot_lines: list[str] | None = None) -> None: + # ANSI clear screen + home cursor. Falls back gracefully on + # dumb terminals — they just see scrolled output, which is + # fine. + console.clear() + console.rule("[bold]SmolVLA2[/] · dry-run", style="cyan") + st = runtime.state + for key, label in ( + ("task", "task"), + ("current_subtask", "subtask"), + ("current_plan", "plan"), + ("current_memory", "memory"), + ): + value = st.get(key) + if value: + console.print(f" [bold cyan]{label:<8}[/] {value}") + else: + console.print(f" [dim]{label:<8} (not set)[/]") + queue_len = ( + len(st["action_queue"]) + if isinstance(st.get("action_queue"), (list, tuple)) + or hasattr(st.get("action_queue"), "__len__") + else 0 + ) + pending = len(st.get("tool_calls_pending") or []) + console.print( + f" [dim]queued actions: {queue_len} pending tool calls: {pending}[/]" + ) + console.rule(style="cyan") + if robot_lines: + for line in robot_lines: + console.print(f" [magenta]{line.strip()}[/]") + console.print() + # Help line under the divider when nothing is set yet. + if not st.get("task"): + console.print( + " [dim]Type the task to begin. Lines ending in '?' are VQA, " + "anything else is an interjection. Type 'stop' to exit.[/]" + ) + + last_logs: list[str] = [] + _redraw() if initial_task is None: - console.print("[dim]No --task provided; first stdin line will be used.[/]") - - panel = make_state_panel(runtime.state) + # Already shown the help line in _redraw when task is None. + pass ticks_done = 0 - with Live( - panel, - console=console, - refresh_per_second=4, - transient=False, - screen=False, - ) as live: - try: - while True: - try: - line = console.input("[bold cyan]>[/] ").strip() - except EOFError: - break - if not line: - continue - lower = line.lower() - if lower in {"stop", "quit", "exit"}: - break + try: + while True: + try: + line = console.input("[bold cyan]> [/]").strip() + except EOFError: + break + if not line: + _redraw(last_logs) + continue + lower = line.lower() + if lower in {"stop", "quit", "exit"}: + break - print_user_line(live.console, line) + # Inject the user input as the right kind of event, + # then run a single pipeline tick to consume it. + if not runtime.state.get("task"): + task = line[5:].strip() if lower.startswith("task:") else line + runtime.set_task(task) + elif lower.endswith("?"): + runtime.state["recent_vqa_query"] = line + runtime.state.setdefault("events_this_tick", []).append( + "user_vqa_query" + ) + else: + runtime.state["recent_interjection"] = line + runtime.state.setdefault("events_this_tick", []).append( + "user_interjection" + ) - # Inject the user input as the right kind of event, - # then run a single pipeline tick to consume it. - if not runtime.state.get("task"): - task = line[5:].strip() if lower.startswith("task:") else line - runtime.set_task(task) - elif lower.endswith("?"): - runtime.state["recent_vqa_query"] = line - runtime.state.setdefault("events_this_tick", []).append( - "user_vqa_query" - ) - else: - runtime.state["recent_interjection"] = line - runtime.state.setdefault("events_this_tick", []).append( - "user_interjection" - ) + last_logs = runtime.step_once() or [] + _redraw(last_logs) - logs = runtime.step_once() - if logs: - print_robot_lines(live.console, logs) - live.update(make_state_panel(runtime.state)) - - ticks_done += 1 - if max_ticks is not None and ticks_done >= max_ticks: - break - except KeyboardInterrupt: - console.print("\n[smolvla2] interrupted", style="dim") - print("[smolvla2] runtime stopped", flush=True) + ticks_done += 1 + if max_ticks is not None and ticks_done >= max_ticks: + break + except KeyboardInterrupt: + console.print("\n[dim]interrupted[/]") + console.print("[dim]runtime stopped[/]") return 0