mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix(smolvla2): bool attention mask + clean Claude-Code-style REPL
Two issues that combined to make the REPL unusable: 1. ``BatchEncoding.attention_mask`` is a ``Long`` tensor, but SmolVLA's ``eager_attention_forward`` does ``torch.where(attention_mask[..., None, :, :], ...)`` which requires a *bool* condition. Every forward raised ``where expected condition to be a boolean tensor, but got a tensor with dtype Long`` and the diagnostic surfaced it cleanly in the REPL — but generation produced nothing useful. Cast to ``bool`` in ``_build_text_batch`` so the prefix forward goes through. 2. The interactive REPL used ``rich.live.Live`` panels stacked on top of ``logging.basicConfig(level=DEBUG)`` HTTP request lines from ``httpcore`` / ``httpx`` / ``huggingface_hub``. The two rendering loops fought each other in the user's terminal and the output was illegible: hundreds of debug lines interleaved with re-rendered panels. Replace ``Live`` with a simple block redraw — clear screen, print the state block, print any robot log lines, then a single ``> `` prompt. State changes are visible above the prompt, the way Claude Code's REPL renders. No flicker, no re-render races. ``_silence_noisy_loggers`` drops the chatty third-party HTTP / download / model-init loggers to WARNING. ``-v`` still enables DEBUG on the lerobot loggers; if the user needs the HTTP traces, they can flip those individually. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -211,6 +211,17 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic
|
|||||||
attn = torch.tensor(attn, dtype=torch.long)
|
attn = torch.tensor(attn, dtype=torch.long)
|
||||||
if attn.ndim == 1:
|
if attn.ndim == 1:
|
||||||
attn = attn.unsqueeze(0)
|
attn = attn.unsqueeze(0)
|
||||||
|
# SmolVLA's ``eager_attention_forward`` does
|
||||||
|
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
|
||||||
|
# requires a *bool* condition tensor; ``BatchEncoding``'s
|
||||||
|
# attention_mask is typically Long (0/1). Cast so the prefix
|
||||||
|
# forward doesn't blow up with ``where expected condition to be a
|
||||||
|
# boolean tensor, but got a tensor with dtype Long``.
|
||||||
|
if attn is not None and hasattr(attn, "dtype"):
|
||||||
|
import torch as _torch # noqa: PLC0415
|
||||||
|
|
||||||
|
if attn.dtype != _torch.bool:
|
||||||
|
attn = attn.bool()
|
||||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||||
# model), which the caller's broad except would swallow silently.
|
# model), which the caller's broad except would swallow silently.
|
||||||
|
|||||||
@@ -320,12 +320,43 @@ def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _silence_noisy_loggers() -> None:
|
||||||
|
"""Drop chatty third-party loggers down to WARNING.
|
||||||
|
|
||||||
|
HuggingFace / httpx / urllib3 emit one log line per HTTP request,
|
||||||
|
which the REPL has to print between the state block and the
|
||||||
|
prompt — completely unreadable. We never need that detail in the
|
||||||
|
REPL and the user can opt back into it via ``-v`` (verbose mode
|
||||||
|
keeps DEBUG on the lerobot loggers but still gates the noisy ones
|
||||||
|
here unless they explicitly want them).
|
||||||
|
"""
|
||||||
|
for name in (
|
||||||
|
"httpcore",
|
||||||
|
"httpcore.connection",
|
||||||
|
"httpcore.http11",
|
||||||
|
"httpcore.proxy",
|
||||||
|
"httpx",
|
||||||
|
"urllib3",
|
||||||
|
"urllib3.connectionpool",
|
||||||
|
"huggingface_hub",
|
||||||
|
"huggingface_hub.repocard",
|
||||||
|
"huggingface_hub.file_download",
|
||||||
|
"transformers",
|
||||||
|
"transformers.modeling_utils",
|
||||||
|
"transformers.tokenization_utils_base",
|
||||||
|
"datasets",
|
||||||
|
"filelock",
|
||||||
|
):
|
||||||
|
logging.getLogger(name).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
def main(argv: list[str] | None = None) -> int:
|
def main(argv: list[str] | None = None) -> int:
|
||||||
args = _parse_args(argv)
|
args = _parse_args(argv)
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(message)s",
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
)
|
)
|
||||||
|
_silence_noisy_loggers()
|
||||||
|
|
||||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||||
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
||||||
@@ -385,15 +416,17 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
|
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
|
||||||
"""Two-zone TUI: chat scrollback above a persistent state panel.
|
"""Claude-Code-style block REPL.
|
||||||
|
|
||||||
Uses :class:`rich.live.Live` to keep the state panel rendered
|
Each turn redraws a status block (task / subtask / plan / memory)
|
||||||
below the chat history. ``console.input`` for the prompt — Live
|
at the top, prints any robot log lines that came in since the last
|
||||||
auto-suspends repaint while the user is typing.
|
turn, then asks for input on a clean ``> `` prompt at the bottom.
|
||||||
|
No live region, no panel re-renders, no rendering races with HTTP
|
||||||
|
log lines — just clear-screen + reprint each turn, the way a
|
||||||
|
chat-style REPL is meant to look.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from rich.console import Console # noqa: PLC0415
|
from rich.console import Console # noqa: PLC0415
|
||||||
from rich.live import Live # noqa: PLC0415
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
print(
|
||||||
"[smolvla2] rich is required for the interactive REPL. "
|
"[smolvla2] rich is required for the interactive REPL. "
|
||||||
@@ -402,71 +435,92 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
|
|||||||
)
|
)
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
console = Console(highlight=False)
|
||||||
make_state_panel,
|
|
||||||
print_robot_lines,
|
|
||||||
print_user_line,
|
|
||||||
)
|
|
||||||
|
|
||||||
console = Console()
|
def _redraw(robot_lines: list[str] | None = None) -> None:
|
||||||
console.print(
|
# ANSI clear screen + home cursor. Falls back gracefully on
|
||||||
"[bold]SmolVLA2[/] ready. "
|
# dumb terminals — they just see scrolled output, which is
|
||||||
"Type a task to begin, then any line for an interjection, "
|
# fine.
|
||||||
"a line ending in '?' for VQA, or 'stop' to exit.",
|
console.clear()
|
||||||
)
|
console.rule("[bold]SmolVLA2[/] · dry-run", style="cyan")
|
||||||
|
st = runtime.state
|
||||||
|
for key, label in (
|
||||||
|
("task", "task"),
|
||||||
|
("current_subtask", "subtask"),
|
||||||
|
("current_plan", "plan"),
|
||||||
|
("current_memory", "memory"),
|
||||||
|
):
|
||||||
|
value = st.get(key)
|
||||||
|
if value:
|
||||||
|
console.print(f" [bold cyan]{label:<8}[/] {value}")
|
||||||
|
else:
|
||||||
|
console.print(f" [dim]{label:<8} (not set)[/]")
|
||||||
|
queue_len = (
|
||||||
|
len(st["action_queue"])
|
||||||
|
if isinstance(st.get("action_queue"), (list, tuple))
|
||||||
|
or hasattr(st.get("action_queue"), "__len__")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
pending = len(st.get("tool_calls_pending") or [])
|
||||||
|
console.print(
|
||||||
|
f" [dim]queued actions: {queue_len} pending tool calls: {pending}[/]"
|
||||||
|
)
|
||||||
|
console.rule(style="cyan")
|
||||||
|
if robot_lines:
|
||||||
|
for line in robot_lines:
|
||||||
|
console.print(f" [magenta]{line.strip()}[/]")
|
||||||
|
console.print()
|
||||||
|
# Help line under the divider when nothing is set yet.
|
||||||
|
if not st.get("task"):
|
||||||
|
console.print(
|
||||||
|
" [dim]Type the task to begin. Lines ending in '?' are VQA, "
|
||||||
|
"anything else is an interjection. Type 'stop' to exit.[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
last_logs: list[str] = []
|
||||||
|
_redraw()
|
||||||
if initial_task is None:
|
if initial_task is None:
|
||||||
console.print("[dim]No --task provided; first stdin line will be used.[/]")
|
# Already shown the help line in _redraw when task is None.
|
||||||
|
pass
|
||||||
panel = make_state_panel(runtime.state)
|
|
||||||
ticks_done = 0
|
ticks_done = 0
|
||||||
with Live(
|
try:
|
||||||
panel,
|
while True:
|
||||||
console=console,
|
try:
|
||||||
refresh_per_second=4,
|
line = console.input("[bold cyan]> [/]").strip()
|
||||||
transient=False,
|
except EOFError:
|
||||||
screen=False,
|
break
|
||||||
) as live:
|
if not line:
|
||||||
try:
|
_redraw(last_logs)
|
||||||
while True:
|
continue
|
||||||
try:
|
lower = line.lower()
|
||||||
line = console.input("[bold cyan]>[/] ").strip()
|
if lower in {"stop", "quit", "exit"}:
|
||||||
except EOFError:
|
break
|
||||||
break
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
lower = line.lower()
|
|
||||||
if lower in {"stop", "quit", "exit"}:
|
|
||||||
break
|
|
||||||
|
|
||||||
print_user_line(live.console, line)
|
# Inject the user input as the right kind of event,
|
||||||
|
# then run a single pipeline tick to consume it.
|
||||||
|
if not runtime.state.get("task"):
|
||||||
|
task = line[5:].strip() if lower.startswith("task:") else line
|
||||||
|
runtime.set_task(task)
|
||||||
|
elif lower.endswith("?"):
|
||||||
|
runtime.state["recent_vqa_query"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append(
|
||||||
|
"user_vqa_query"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
runtime.state["recent_interjection"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append(
|
||||||
|
"user_interjection"
|
||||||
|
)
|
||||||
|
|
||||||
# Inject the user input as the right kind of event,
|
last_logs = runtime.step_once() or []
|
||||||
# then run a single pipeline tick to consume it.
|
_redraw(last_logs)
|
||||||
if not runtime.state.get("task"):
|
|
||||||
task = line[5:].strip() if lower.startswith("task:") else line
|
|
||||||
runtime.set_task(task)
|
|
||||||
elif lower.endswith("?"):
|
|
||||||
runtime.state["recent_vqa_query"] = line
|
|
||||||
runtime.state.setdefault("events_this_tick", []).append(
|
|
||||||
"user_vqa_query"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
runtime.state["recent_interjection"] = line
|
|
||||||
runtime.state.setdefault("events_this_tick", []).append(
|
|
||||||
"user_interjection"
|
|
||||||
)
|
|
||||||
|
|
||||||
logs = runtime.step_once()
|
ticks_done += 1
|
||||||
if logs:
|
if max_ticks is not None and ticks_done >= max_ticks:
|
||||||
print_robot_lines(live.console, logs)
|
break
|
||||||
live.update(make_state_panel(runtime.state))
|
except KeyboardInterrupt:
|
||||||
|
console.print("\n[dim]interrupted[/]")
|
||||||
ticks_done += 1
|
console.print("[dim]runtime stopped[/]")
|
||||||
if max_ticks is not None and ticks_done >= max_ticks:
|
|
||||||
break
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
console.print("\n[smolvla2] interrupted", style="dim")
|
|
||||||
print("[smolvla2] runtime stopped", flush=True)
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user