mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(smolvla2): bool attention mask + clean Claude-Code-style REPL
Two issues that combined to make the REPL unusable: 1. ``BatchEncoding.attention_mask`` is a ``Long`` tensor, but SmolVLA's ``eager_attention_forward`` does ``torch.where(attention_mask[..., None, :, :], ...)`` which requires a *bool* condition. Every forward raised ``where expected condition to be a boolean tensor, but got a tensor with dtype Long`` and the diagnostic surfaced it cleanly in the REPL — but generation produced nothing useful. Cast to ``bool`` in ``_build_text_batch`` so the prefix forward goes through. 2. The interactive REPL used ``rich.live.Live`` panels stacked on top of ``logging.basicConfig(level=DEBUG)`` HTTP request lines from ``httpcore`` / ``httpx`` / ``huggingface_hub``. The two rendering loops fought each other in the user's terminal and the output was illegible: hundreds of debug lines interleaved with re-rendered panels. Replace ``Live`` with a simple block redraw — clear screen, print the state block, print any robot log lines, then a single ``> `` prompt. State changes are visible above the prompt, the way Claude Code's REPL renders. No flicker, no re-render races. ``_silence_noisy_loggers`` drops the chatty third-party HTTP / download / model-init loggers to WARNING. ``-v`` still enables DEBUG on the lerobot loggers; if the user needs the HTTP traces, they can flip those individually. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -211,6 +211,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
|
||||
attn = torch.tensor(attn, dtype=torch.long)
|
||||
if attn.ndim == 1:
|
||||
attn = attn.unsqueeze(0)
|
||||
# SmolVLA's ``eager_attention_forward`` does
|
||||
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
|
||||
# requires a *bool* condition tensor; ``BatchEncoding``'s
|
||||
# attention_mask is typically Long (0/1). Cast so the prefix
|
||||
# forward doesn't blow up with ``where expected condition to be a
|
||||
# boolean tensor, but got a tensor with dtype Long``.
|
||||
if attn is not None and hasattr(attn, "dtype"):
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
if attn.dtype != _torch.bool:
|
||||
attn = attn.bool()
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
|
||||
@@ -320,12 +320,43 @@ def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _silence_noisy_loggers() -> None:
|
||||
"""Drop chatty third-party loggers down to WARNING.
|
||||
|
||||
HuggingFace / httpx / urllib3 emit one log line per HTTP request,
|
||||
which the REPL has to print between the state block and the
|
||||
prompt — completely unreadable. We never need that detail in the
|
||||
REPL and the user can opt back into it via ``-v`` (verbose mode
|
||||
keeps DEBUG on the lerobot loggers but still gates the noisy ones
|
||||
here unless they explicitly want them).
|
||||
"""
|
||||
for name in (
|
||||
"httpcore",
|
||||
"httpcore.connection",
|
||||
"httpcore.http11",
|
||||
"httpcore.proxy",
|
||||
"httpx",
|
||||
"urllib3",
|
||||
"urllib3.connectionpool",
|
||||
"huggingface_hub",
|
||||
"huggingface_hub.repocard",
|
||||
"huggingface_hub.file_download",
|
||||
"transformers",
|
||||
"transformers.modeling_utils",
|
||||
"transformers.tokenization_utils_base",
|
||||
"datasets",
|
||||
"filelock",
|
||||
):
|
||||
logging.getLogger(name).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
args = _parse_args(argv)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
_silence_noisy_loggers()
|
||||
|
||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
||||
@@ -385,15 +416,17 @@ def main(argv: list[str] | None = None) -> int:
|
||||
|
||||
|
||||
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
|
||||
"""Two-zone TUI: chat scrollback above a persistent state panel.
|
||||
"""Claude-Code-style block REPL.
|
||||
|
||||
Uses :class:`rich.live.Live` to keep the state panel rendered
|
||||
below the chat history. ``console.input`` for the prompt — Live
|
||||
auto-suspends repaint while the user is typing.
|
||||
Each turn redraws a status block (task / subtask / plan / memory)
|
||||
at the top, prints any robot log lines that came in since the last
|
||||
turn, then asks for input on a clean ``> `` prompt at the bottom.
|
||||
No live region, no panel re-renders, no rendering races with HTTP
|
||||
log lines — just clear-screen + reprint each turn, the way a
|
||||
chat-style REPL is meant to look.
|
||||
"""
|
||||
try:
|
||||
from rich.console import Console # noqa: PLC0415
|
||||
from rich.live import Live # noqa: PLC0415
|
||||
except ImportError:
|
||||
print(
|
||||
"[smolvla2] rich is required for the interactive REPL. "
|
||||
@@ -402,71 +435,92 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
|
||||
)
|
||||
return 2
|
||||
|
||||
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
||||
make_state_panel,
|
||||
print_robot_lines,
|
||||
print_user_line,
|
||||
)
|
||||
console = Console(highlight=False)
|
||||
|
||||
console = Console()
|
||||
console.print(
|
||||
"[bold]SmolVLA2[/] ready. "
|
||||
"Type a task to begin, then any line for an interjection, "
|
||||
"a line ending in '?' for VQA, or 'stop' to exit.",
|
||||
)
|
||||
def _redraw(robot_lines: list[str] | None = None) -> None:
|
||||
# ANSI clear screen + home cursor. Falls back gracefully on
|
||||
# dumb terminals — they just see scrolled output, which is
|
||||
# fine.
|
||||
console.clear()
|
||||
console.rule("[bold]SmolVLA2[/] · dry-run", style="cyan")
|
||||
st = runtime.state
|
||||
for key, label in (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
):
|
||||
value = st.get(key)
|
||||
if value:
|
||||
console.print(f" [bold cyan]{label:<8}[/] {value}")
|
||||
else:
|
||||
console.print(f" [dim]{label:<8} (not set)[/]")
|
||||
queue_len = (
|
||||
len(st["action_queue"])
|
||||
if isinstance(st.get("action_queue"), (list, tuple))
|
||||
or hasattr(st.get("action_queue"), "__len__")
|
||||
else 0
|
||||
)
|
||||
pending = len(st.get("tool_calls_pending") or [])
|
||||
console.print(
|
||||
f" [dim]queued actions: {queue_len} pending tool calls: {pending}[/]"
|
||||
)
|
||||
console.rule(style="cyan")
|
||||
if robot_lines:
|
||||
for line in robot_lines:
|
||||
console.print(f" [magenta]{line.strip()}[/]")
|
||||
console.print()
|
||||
# Help line under the divider when nothing is set yet.
|
||||
if not st.get("task"):
|
||||
console.print(
|
||||
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
|
||||
"anything else is an interjection. Type 'stop' to exit.[/]"
|
||||
)
|
||||
|
||||
last_logs: list[str] = []
|
||||
_redraw()
|
||||
if initial_task is None:
|
||||
console.print("[dim]No --task provided; first stdin line will be used.[/]")
|
||||
|
||||
panel = make_state_panel(runtime.state)
|
||||
# Already shown the help line in _redraw when task is None.
|
||||
pass
|
||||
ticks_done = 0
|
||||
with Live(
|
||||
panel,
|
||||
console=console,
|
||||
refresh_per_second=4,
|
||||
transient=False,
|
||||
screen=False,
|
||||
) as live:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
line = console.input("[bold cyan]>[/] ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not line:
|
||||
continue
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
break
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
line = console.input("[bold cyan]> [/]").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not line:
|
||||
_redraw(last_logs)
|
||||
continue
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
break
|
||||
|
||||
print_user_line(live.console, line)
|
||||
# Inject the user input as the right kind of event,
|
||||
# then run a single pipeline tick to consume it.
|
||||
if not runtime.state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
runtime.set_task(task)
|
||||
elif lower.endswith("?"):
|
||||
runtime.state["recent_vqa_query"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_vqa_query"
|
||||
)
|
||||
else:
|
||||
runtime.state["recent_interjection"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_interjection"
|
||||
)
|
||||
|
||||
# Inject the user input as the right kind of event,
|
||||
# then run a single pipeline tick to consume it.
|
||||
if not runtime.state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
runtime.set_task(task)
|
||||
elif lower.endswith("?"):
|
||||
runtime.state["recent_vqa_query"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_vqa_query"
|
||||
)
|
||||
else:
|
||||
runtime.state["recent_interjection"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_interjection"
|
||||
)
|
||||
last_logs = runtime.step_once() or []
|
||||
_redraw(last_logs)
|
||||
|
||||
logs = runtime.step_once()
|
||||
if logs:
|
||||
print_robot_lines(live.console, logs)
|
||||
live.update(make_state_panel(runtime.state))
|
||||
|
||||
ticks_done += 1
|
||||
if max_ticks is not None and ticks_done >= max_ticks:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[smolvla2] interrupted", style="dim")
|
||||
print("[smolvla2] runtime stopped", flush=True)
|
||||
ticks_done += 1
|
||||
if max_ticks is not None and ticks_done >= max_ticks:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[dim]interrupted[/]")
|
||||
console.print("[dim]runtime stopped[/]")
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user