mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
feat(smolvla2): runtime accepts Hub IDs + dataset-driven dry-run
The runtime CLI's loader was broken — it imported a `make_policy_from_path` that doesn't exist in `lerobot.policies.factory` — and the high-level text steps generated plan / subtask / memory / VQA from a text-only batch with no images or state, so dry-runs drifted from the training distribution. Switch to the standard `PreTrainedConfig.from_pretrained` + `make_policy(cfg, ds_meta=...)` flow so `--policy.path` accepts both local directories and Hub repo ids, and add a `--dataset.repo_id` path that walks a chosen episode and feeds preprocessed observations into every forward pass — including the four high-level steps (`HighLevelSubtaskFwd`, `MemoryUpdateFwd`, `UserInterjectionFwd`, `AskVQAFwd`). Frames are routed through the saved preprocessor pipeline with `language_persistent` / `language_events` stripped so the recipe-render step stays a no-op (the runtime supplies its own messages from current state). Also wires the rich-based two-zone REPL layout (`ui.py`) that the script was already importing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from .steps import (
|
|||||||
UserInterjectionFwd,
|
UserInterjectionFwd,
|
||||||
)
|
)
|
||||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||||
|
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# runtime
|
# runtime
|
||||||
@@ -65,4 +66,8 @@ __all__ = [
|
|||||||
"UserInterjectionFwd",
|
"UserInterjectionFwd",
|
||||||
"AskVQAFwd",
|
"AskVQAFwd",
|
||||||
"DispatchToolCalls",
|
"DispatchToolCalls",
|
||||||
|
# UI
|
||||||
|
"make_state_panel",
|
||||||
|
"print_robot_lines",
|
||||||
|
"print_user_line",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -82,10 +82,20 @@ class SmolVLA2Runtime:
|
|||||||
HighLevelSubtaskFwd(
|
HighLevelSubtaskFwd(
|
||||||
trigger=HzTrigger(self.high_level_hz),
|
trigger=HzTrigger(self.high_level_hz),
|
||||||
policy=self.policy,
|
policy=self.policy,
|
||||||
|
observation_provider=self.observation_provider,
|
||||||
|
),
|
||||||
|
MemoryUpdateFwd(
|
||||||
|
policy=self.policy,
|
||||||
|
observation_provider=self.observation_provider,
|
||||||
|
),
|
||||||
|
UserInterjectionFwd(
|
||||||
|
policy=self.policy,
|
||||||
|
observation_provider=self.observation_provider,
|
||||||
|
),
|
||||||
|
AskVQAFwd(
|
||||||
|
policy=self.policy,
|
||||||
|
observation_provider=self.observation_provider,
|
||||||
),
|
),
|
||||||
MemoryUpdateFwd(policy=self.policy),
|
|
||||||
UserInterjectionFwd(policy=self.policy),
|
|
||||||
AskVQAFwd(policy=self.policy),
|
|
||||||
DispatchToolCalls(tools=self.tools),
|
DispatchToolCalls(tools=self.tools),
|
||||||
]
|
]
|
||||||
self.state = initial_runtime_state()
|
self.state = initial_runtime_state()
|
||||||
@@ -127,6 +137,39 @@ class SmolVLA2Runtime:
|
|||||||
|
|
||||||
self._on_shutdown()
|
self._on_shutdown()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# REPL helper: drive one full pipeline pass and return its logs
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def step_once(self) -> list[str]:
|
||||||
|
"""Run one tick of the pipeline and return the log lines.
|
||||||
|
|
||||||
|
Used by the interactive REPL: instead of a background thread,
|
||||||
|
the CLI drives ticks synchronously after each user input. Logs
|
||||||
|
are returned (not printed) so the caller can route them into
|
||||||
|
the rich-Live chat scrollback.
|
||||||
|
"""
|
||||||
|
from .triggers import Tick # noqa: PLC0415
|
||||||
|
|
||||||
|
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||||
|
# here — the REPL drives the runtime, not vice versa — but
|
||||||
|
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||||
|
# bump it generously so every Hz-triggered step considers
|
||||||
|
# itself due.
|
||||||
|
import time as _time # noqa: PLC0415
|
||||||
|
|
||||||
|
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||||
|
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||||
|
self.state["log_lines"] = []
|
||||||
|
# ``events_this_tick`` is set up by the caller before
|
||||||
|
# ``step_once`` (the REPL pushes user-driven events first).
|
||||||
|
self.state.setdefault("events_this_tick", [])
|
||||||
|
|
||||||
|
for step in self.pipeline:
|
||||||
|
self.state = step(self.state)
|
||||||
|
|
||||||
|
return list(self.state.get("log_lines") or [])
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# I/O
|
# I/O
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -171,13 +171,19 @@ class HighLevelSubtaskFwd(InferenceStep):
|
|||||||
"""At ~1 Hz, ask the policy for the next subtask."""
|
"""At ~1 Hz, ask the policy for the next subtask."""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
|
observation_provider: Any = None
|
||||||
|
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||||
|
set, the resulting observation is merged into ``select_message``'s
|
||||||
|
batch so text generation runs against real video + state."""
|
||||||
|
|
||||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||||
|
|
||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
if self.policy is None or not state.get("task"):
|
if self.policy is None or not state.get("task"):
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state)
|
ctx = _control_context_messages(state)
|
||||||
msg = _generate_with_policy(self.policy, ctx)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
|
msg = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||||
if msg:
|
if msg:
|
||||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||||
if changed:
|
if changed:
|
||||||
@@ -191,6 +197,7 @@ class MemoryUpdateFwd(InferenceStep):
|
|||||||
"""On subtask boundary, refresh the compressed memory."""
|
"""On subtask boundary, refresh the compressed memory."""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
|
observation_provider: Any = None
|
||||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||||
|
|
||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
@@ -198,7 +205,8 @@ class MemoryUpdateFwd(InferenceStep):
|
|||||||
if self.policy is None:
|
if self.policy is None:
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state, include_completed=True)
|
ctx = _control_context_messages(state, include_completed=True)
|
||||||
new_memory = _generate_with_policy(self.policy, ctx)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
|
new_memory = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||||
if new_memory:
|
if new_memory:
|
||||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||||
return None
|
return None
|
||||||
@@ -209,6 +217,7 @@ class UserInterjectionFwd(InferenceStep):
|
|||||||
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
|
observation_provider: Any = None
|
||||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||||
|
|
||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
@@ -218,7 +227,8 @@ class UserInterjectionFwd(InferenceStep):
|
|||||||
state,
|
state,
|
||||||
extra_user=state.get("recent_interjection"),
|
extra_user=state.get("recent_interjection"),
|
||||||
)
|
)
|
||||||
out = _generate_with_policy(self.policy, ctx)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
|
out = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||||
if not out:
|
if not out:
|
||||||
return None
|
return None
|
||||||
# Heuristic split: model is trained to emit one assistant turn
|
# Heuristic split: model is trained to emit one assistant turn
|
||||||
@@ -247,6 +257,7 @@ class AskVQAFwd(InferenceStep):
|
|||||||
"""On stdin question, answer a frame-grounded VQA."""
|
"""On stdin question, answer a frame-grounded VQA."""
|
||||||
|
|
||||||
policy: Any = None
|
policy: Any = None
|
||||||
|
observation_provider: Any = None
|
||||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||||
|
|
||||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
@@ -256,7 +267,8 @@ class AskVQAFwd(InferenceStep):
|
|||||||
if not question:
|
if not question:
|
||||||
return None
|
return None
|
||||||
ctx = _control_context_messages(state, extra_user=question)
|
ctx = _control_context_messages(state, extra_user=question)
|
||||||
answer = _generate_with_policy(self.policy, ctx)
|
observation = _maybe_observation(self.observation_provider)
|
||||||
|
answer = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||||
if answer:
|
if answer:
|
||||||
push_log(state, f" vqa: {answer}")
|
push_log(state, f" vqa: {answer}")
|
||||||
state["recent_vqa_query"] = None
|
state["recent_vqa_query"] = None
|
||||||
@@ -326,35 +338,54 @@ def _control_context_messages(
|
|||||||
return msgs
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
|
def _maybe_observation(provider: Any) -> dict | None:
|
||||||
"""Drive ``policy.select_message`` with a minimal text-only batch.
|
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||||
|
|
||||||
Best-effort: the runtime today doesn't construct a full
|
Errors from the provider are logged at debug level and swallowed —
|
||||||
observation batch with images / state for text generation; the
|
text generation still runs (in text-only mode) so a flaky frame
|
||||||
text-head was trained over images + lang + state, so generations
|
source doesn't kill the REPL.
|
||||||
here may differ in distribution from training. This is acceptable
|
"""
|
||||||
for a v1 REPL; a follow-up will plug in the real observation.
|
if provider is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return provider()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_with_policy(
|
||||||
|
policy: Any,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
observation: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||||
|
|
||||||
|
When ``observation`` carries ``observation.images.*`` and
|
||||||
|
``observation.state``, those are merged into the batch so
|
||||||
|
``select_message`` runs the same VLM prefix the policy was trained
|
||||||
|
on. Without an observation the runtime falls back to a text-only
|
||||||
|
prompt — the text head still runs, but generations may drift from
|
||||||
|
the training distribution.
|
||||||
"""
|
"""
|
||||||
if not hasattr(policy, "select_message"):
|
if not hasattr(policy, "select_message"):
|
||||||
return ""
|
return ""
|
||||||
text_batch = _build_text_batch(policy, messages)
|
text_batch = _build_text_batch(policy, messages)
|
||||||
# ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
|
|
||||||
# The minimal text-only batch we build doesn't have images / state,
|
|
||||||
# so we either run a text-only forward (handled by SmolVLA2 when
|
|
||||||
# supported) or skip and return empty. v1 returns empty when the
|
|
||||||
# policy can't handle it; the runtime logs and continues.
|
|
||||||
try:
|
try:
|
||||||
# Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
|
|
||||||
# keys ``select_message`` uses internally.
|
|
||||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch = {
|
batch: dict[str, Any] = {
|
||||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||||
}
|
}
|
||||||
|
if observation:
|
||||||
|
for k, v in observation.items():
|
||||||
|
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||||
|
batch[k] = v
|
||||||
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.debug("select_message fell back: %s", exc)
|
logger.debug("select_message fell back: %s", exc)
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Rich-based REPL layout for the SmolVLA2 runtime.
|
||||||
|
|
||||||
|
Two-zone terminal layout:
|
||||||
|
|
||||||
|
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||||
|
|
||||||
|
┌── State ──────────────────────────────────────────┐
|
||||||
|
│ task please clean up the kitchen │
|
||||||
|
│ subtask grasp the handle of the sponge │
|
||||||
|
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||||
|
│ memory sponge picked up; counter still dirty │
|
||||||
|
└───────────────────────────────────────────────────┘
|
||||||
|
> _
|
||||||
|
|
||||||
|
The state panel re-renders on every state change. Chat lines are
|
||||||
|
``console.print``'d above the live region so they accumulate naturally
|
||||||
|
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||||
|
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||||
|
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||||
|
panel for cursor position.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try: # rich is optional; only required for the interactive REPL.
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
_HAS_RICH = True
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
_HAS_RICH = False
|
||||||
|
Console = Any # type: ignore[assignment]
|
||||||
|
Panel = Any # type: ignore[assignment]
|
||||||
|
Table = Any # type: ignore[assignment]
|
||||||
|
Text = Any # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
_STATE_KEYS = (
|
||||||
|
("task", "task"),
|
||||||
|
("current_subtask", "subtask"),
|
||||||
|
("current_plan", "plan"),
|
||||||
|
("current_memory", "memory"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_state_panel(state: dict[str, Any]) -> Any:
|
||||||
|
"""Render the persistent state panel for the live region.
|
||||||
|
|
||||||
|
Returns a :class:`rich.panel.Panel`. Caller passes it to
|
||||||
|
``Live.update(panel)`` whenever the state changes.
|
||||||
|
"""
|
||||||
|
if not _HAS_RICH:
|
||||||
|
raise RuntimeError(
|
||||||
|
"rich is required for the interactive REPL. "
|
||||||
|
"`pip install rich` (it's a transitive dep of lerobot)."
|
||||||
|
)
|
||||||
|
table = Table.grid(padding=(0, 2), expand=True)
|
||||||
|
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
|
||||||
|
table.add_column(justify="left")
|
||||||
|
for key, label in _STATE_KEYS:
|
||||||
|
value = state.get(key)
|
||||||
|
if value is None:
|
||||||
|
rendered = Text("(not set)", style="dim italic")
|
||||||
|
else:
|
||||||
|
rendered = Text(str(value), style="bold")
|
||||||
|
table.add_row(label, rendered)
|
||||||
|
queue = state.get("action_queue")
|
||||||
|
queue_len = len(queue) if hasattr(queue, "__len__") else 0
|
||||||
|
pending = state.get("tool_calls_pending") or []
|
||||||
|
footer = Text.assemble(
|
||||||
|
("queued actions: ", "dim"),
|
||||||
|
(str(queue_len), "bold cyan"),
|
||||||
|
(" pending tool calls: ", "dim"),
|
||||||
|
(str(len(pending)), "bold magenta"),
|
||||||
|
)
|
||||||
|
table.add_row("", footer)
|
||||||
|
return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan")
|
||||||
|
|
||||||
|
|
||||||
|
def print_user_line(console: Any, line: str) -> None:
|
||||||
|
"""Append a user-typed line to the chat scrollback."""
|
||||||
|
if not _HAS_RICH:
|
||||||
|
print(f"you: {line}", flush=True)
|
||||||
|
return
|
||||||
|
console.print(f"[bold cyan]you:[/] {line}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_robot_lines(console: Any, lines: list[str]) -> None:
|
||||||
|
"""Append robot/runtime log lines to the chat scrollback."""
|
||||||
|
if not _HAS_RICH:
|
||||||
|
for line in lines:
|
||||||
|
print(f"robot: {line.lstrip()}", flush=True)
|
||||||
|
return
|
||||||
|
for line in lines:
|
||||||
|
# The runtime uses leading whitespace + "label: text"; render
|
||||||
|
# the label in green and the value in default for readability.
|
||||||
|
stripped = line.lstrip()
|
||||||
|
if ":" in stripped:
|
||||||
|
label, _, value = stripped.partition(":")
|
||||||
|
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
|
||||||
|
else:
|
||||||
|
console.print(f"[bold green]robot:[/] {stripped}")
|
||||||
@@ -23,11 +23,21 @@ speech) as they happen.
|
|||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
|
|
||||||
Dry run on a checkpoint, no robot connected — useful for sanity-
|
Dry run on a Hub checkpoint, no robot connected — useful for sanity-
|
||||||
checking text generation::
|
checking text generation::
|
||||||
|
|
||||||
uv run lerobot-smolvla2-runtime \\
|
uv run lerobot-smolvla2-runtime \\
|
||||||
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
|
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||||
|
--no_robot \\
|
||||||
|
--task="please clean the kitchen"
|
||||||
|
|
||||||
|
Same, but feed real frames from an annotated dataset so plan / subtask
|
||||||
|
/ memory / VQA generation runs against actual video + state::
|
||||||
|
|
||||||
|
uv run lerobot-smolvla2-runtime \\
|
||||||
|
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||||
|
--dataset.repo_id=pepijn223/super_poulain_annotated \\
|
||||||
|
--dataset.episode=0 \\
|
||||||
--no_robot \\
|
--no_robot \\
|
||||||
--task="please clean the kitchen"
|
--task="please clean the kitchen"
|
||||||
|
|
||||||
@@ -38,6 +48,9 @@ With a real robot::
|
|||||||
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
||||||
--tts.voice=alba
|
--tts.voice=alba
|
||||||
|
|
||||||
|
``--policy.path`` accepts either a local directory or a Hugging Face
|
||||||
|
Hub repo id. ``--dataset.repo_id`` likewise.
|
||||||
|
|
||||||
Tool dispatch (TTS via ``SayTool``) is enabled by default when
|
Tool dispatch (TTS via ``SayTool``) is enabled by default when
|
||||||
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
|
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
|
||||||
"""
|
"""
|
||||||
@@ -47,8 +60,7 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from typing import Any, Callable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||||
|
|
||||||
@@ -61,9 +73,51 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--policy.path",
|
"--policy.path",
|
||||||
dest="policy_path",
|
dest="policy_path",
|
||||||
type=Path,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
|
help=(
|
||||||
|
"Local directory or Hugging Face Hub repo id pointing at a "
|
||||||
|
"trained SmolVLA2 ``pretrained_model``."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--dataset.repo_id",
|
||||||
|
dest="dataset_repo_id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Optional dataset (local path or Hub repo id) used to drive "
|
||||||
|
"observations during dry-run inference. When set, the runtime "
|
||||||
|
"reads camera frames + state from the chosen episode and feeds "
|
||||||
|
"them into all forward passes — so plan / subtask / memory / "
|
||||||
|
"VQA generation see the same visual context the policy was "
|
||||||
|
"trained on."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--dataset.episode",
|
||||||
|
dest="dataset_episode",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Episode index to walk through (default: 0).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--dataset.start_frame",
|
||||||
|
dest="dataset_start_frame",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Frame index within the episode to start from (default: 0).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--dataset.advance_per_tick",
|
||||||
|
dest="dataset_advance_per_tick",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help=(
|
||||||
|
"How many dataset frames to advance per runtime tick. The "
|
||||||
|
"default of 1 means the runtime walks the episode forward "
|
||||||
|
"frame by frame; set to 0 to freeze on ``start_frame``."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--task",
|
"--task",
|
||||||
@@ -111,16 +165,136 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
return p.parse_args(argv)
|
return p.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
def _load_policy(path: Path): # noqa: ANN202
|
def _load_policy_and_preprocessor(
|
||||||
"""Load a SmolVLA2 checkpoint from ``path``."""
|
policy_path: str,
|
||||||
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
|
dataset_repo_id: str | None,
|
||||||
|
) -> tuple[Any, Any, Any]:
|
||||||
|
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
|
||||||
|
|
||||||
|
When ``dataset_repo_id`` is provided, the dataset's metadata is used
|
||||||
|
to derive policy features (matching the standard
|
||||||
|
``make_policy(cfg, ds_meta=...)`` flow used by ``lerobot-train`` and
|
||||||
|
``lerobot-record``). When it isn't, we fall back to instantiating
|
||||||
|
the policy directly via ``from_pretrained`` — this skips the
|
||||||
|
feature-derivation path that ``make_policy`` insists on, but also
|
||||||
|
means we can't load the saved preprocessor pipeline (which depends
|
||||||
|
on ``input_features`` / ``output_features``). For inference-only
|
||||||
|
dry-runs this is fine; the policy still loads.
|
||||||
|
|
||||||
|
Returns ``(policy, preprocessor, ds_meta)`` where ``preprocessor``
|
||||||
|
and ``ds_meta`` may be ``None`` if no dataset was provided.
|
||||||
|
"""
|
||||||
|
from lerobot.configs import PreTrainedConfig # noqa: PLC0415
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415
|
||||||
|
|
||||||
|
cfg = PreTrainedConfig.from_pretrained(policy_path)
|
||||||
|
cfg.pretrained_path = policy_path
|
||||||
|
|
||||||
|
ds_meta = None
|
||||||
|
preprocessor = None
|
||||||
|
if dataset_repo_id is not None:
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415
|
||||||
|
|
||||||
|
ds_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||||
|
policy = make_policy(cfg, ds_meta=ds_meta)
|
||||||
|
preprocessor, _ = make_pre_post_processors(
|
||||||
|
cfg,
|
||||||
|
pretrained_path=policy_path,
|
||||||
|
dataset_stats=ds_meta.stats,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No dataset: instantiate the policy class directly so we don't
|
||||||
|
# need ds_meta. This bypasses ``make_policy``'s feature-shape
|
||||||
|
# derivation, which is fine for a pretrained checkpoint where
|
||||||
|
# the saved config already carries those shapes.
|
||||||
|
from lerobot.policies.factory import get_policy_class # noqa: PLC0415
|
||||||
|
|
||||||
|
policy_cls = get_policy_class(cfg.type)
|
||||||
|
policy = policy_cls.from_pretrained(policy_path, config=cfg)
|
||||||
|
policy.to(cfg.device)
|
||||||
|
|
||||||
policy = make_policy_from_path(str(path))
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
return policy
|
return policy, preprocessor, ds_meta
|
||||||
|
|
||||||
|
|
||||||
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
def _build_observation_provider(
|
||||||
|
*,
|
||||||
|
dataset_repo_id: str,
|
||||||
|
episode: int,
|
||||||
|
start_frame: int,
|
||||||
|
advance_per_tick: int,
|
||||||
|
preprocessor: Any,
|
||||||
|
device: str,
|
||||||
|
) -> Callable[[], dict | None]:
|
||||||
|
"""Build a closure that feeds dataset frames into the runtime.
|
||||||
|
|
||||||
|
Each call returns a preprocessed observation batch (images +
|
||||||
|
state, batched, on the policy's device, normalized) suitable for
|
||||||
|
``policy.select_action`` and ``policy.select_message``. The
|
||||||
|
closure walks the chosen episode forward by ``advance_per_tick``
|
||||||
|
frames per call, looping back to the episode start when it falls
|
||||||
|
off the end.
|
||||||
|
|
||||||
|
The dataset's ``language_persistent`` / ``language_events``
|
||||||
|
columns are stripped before the sample reaches the preprocessor,
|
||||||
|
so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are
|
||||||
|
no-ops; the runtime supplies its own messages from current state.
|
||||||
|
"""
|
||||||
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||||
|
|
||||||
|
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
|
||||||
|
if len(ds) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_repo_id!r} episode {episode} is empty."
|
||||||
|
)
|
||||||
|
|
||||||
|
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
|
||||||
|
|
||||||
|
def _provider() -> dict | None:
|
||||||
|
idx = state["cursor"]
|
||||||
|
if advance_per_tick > 0:
|
||||||
|
state["cursor"] = (idx + advance_per_tick) % len(ds)
|
||||||
|
|
||||||
|
sample = ds[idx]
|
||||||
|
# Strip the language columns so the preprocessor's render step
|
||||||
|
# is a no-op — the runtime drives messages itself.
|
||||||
|
for k in ("language_persistent", "language_events"):
|
||||||
|
sample.pop(k, None)
|
||||||
|
|
||||||
|
if preprocessor is not None:
|
||||||
|
sample = preprocessor(sample)
|
||||||
|
|
||||||
|
# Keep only observation keys; the runtime's text path will
|
||||||
|
# merge these with its own lang_tokens / lang_masks.
|
||||||
|
observation = {
|
||||||
|
k: v
|
||||||
|
for k, v in sample.items()
|
||||||
|
if isinstance(k, str) and k.startswith("observation.")
|
||||||
|
}
|
||||||
|
# Defensive: if something further upstream forgot the batch
|
||||||
|
# dim, add it now so downstream Tensor ops don't crash.
|
||||||
|
for k, v in list(observation.items()):
|
||||||
|
if isinstance(v, torch.Tensor) and v.ndim > 0 and v.shape[0] != 1:
|
||||||
|
# ``add_batch_dim`` already ran inside the preprocessor;
|
||||||
|
# an unbatched tensor at this point means a step
|
||||||
|
# somewhere produced an unbatched output. Best-effort
|
||||||
|
# fix.
|
||||||
|
if v.shape[0] != 1 and v.ndim < 4 and "image" not in k:
|
||||||
|
observation[k] = v.unsqueeze(0)
|
||||||
|
# Move to device (the preprocessor's DeviceProcessorStep should
|
||||||
|
# already have done this when ``preprocessor is not None``;
|
||||||
|
# this is a belt-and-braces no-op in the common case).
|
||||||
|
for k, v in list(observation.items()):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
observation[k] = v.to(device)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
return _provider
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||||
"""Instantiate the tools declared on this dataset/policy."""
|
"""Instantiate the tools declared on this dataset/policy."""
|
||||||
if no_tts:
|
if no_tts:
|
||||||
return {}
|
return {}
|
||||||
@@ -140,20 +314,32 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
format="%(asctime)s %(levelname)s %(message)s",
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not args.policy_path.exists():
|
|
||||||
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||||
policy = _load_policy(args.policy_path)
|
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
||||||
|
args.policy_path, args.dataset_repo_id
|
||||||
|
)
|
||||||
|
|
||||||
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
|
observation_provider: Callable[[], dict | None] | None = None
|
||||||
|
if args.dataset_repo_id is not None:
|
||||||
|
print(
|
||||||
|
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||||
|
f"episode={args.dataset_episode} "
|
||||||
|
f"start_frame={args.dataset_start_frame}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
observation_provider = _build_observation_provider(
|
||||||
|
dataset_repo_id=args.dataset_repo_id,
|
||||||
|
episode=args.dataset_episode,
|
||||||
|
start_frame=args.dataset_start_frame,
|
||||||
|
advance_per_tick=args.dataset_advance_per_tick,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
device=str(getattr(policy.config, "device", "cpu")),
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||||
if tools:
|
if tools:
|
||||||
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
||||||
|
|
||||||
# Robot wiring is left as a follow-up — for v1 we run language-only
|
|
||||||
# / dry-run so REPL development doesn't require a connected robot.
|
|
||||||
observation_provider = None
|
|
||||||
robot_executor = None
|
robot_executor = None
|
||||||
if not args.no_robot:
|
if not args.no_robot:
|
||||||
print(
|
print(
|
||||||
@@ -162,33 +348,112 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
|
||||||
SmolVLA2Runtime,
|
|
||||||
StdinReader,
|
|
||||||
)
|
|
||||||
|
|
||||||
runtime = SmolVLA2Runtime(
|
runtime = SmolVLA2Runtime(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
observation_provider=observation_provider,
|
observation_provider=observation_provider,
|
||||||
robot_executor=robot_executor,
|
robot_executor=robot_executor,
|
||||||
event_collector=StdinReader().poll,
|
# No background event collector — the REPL drives ticks
|
||||||
|
# synchronously after each user input. The runtime's own
|
||||||
|
# ``run()`` loop is bypassed here in favour of ``step_once()``
|
||||||
|
# so the input prompt and the live state panel co-exist
|
||||||
|
# cleanly.
|
||||||
|
event_collector=None,
|
||||||
chunk_hz=args.chunk_hz,
|
chunk_hz=args.chunk_hz,
|
||||||
ctrl_hz=args.ctrl_hz,
|
ctrl_hz=args.ctrl_hz,
|
||||||
high_level_hz=args.high_level_hz,
|
high_level_hz=args.high_level_hz,
|
||||||
)
|
)
|
||||||
if args.task:
|
if args.task:
|
||||||
runtime.set_task(args.task)
|
runtime.set_task(args.task)
|
||||||
print(
|
|
||||||
"[smolvla2] runtime ready. Type a task to begin, then any line for "
|
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||||
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
|
|
||||||
flush=True,
|
|
||||||
)
|
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
|
||||||
|
"""Two-zone TUI: chat scrollback above a persistent state panel.
|
||||||
|
|
||||||
|
Uses :class:`rich.live.Live` to keep the state panel rendered
|
||||||
|
below the chat history. ``console.input`` for the prompt — Live
|
||||||
|
auto-suspends repaint while the user is typing.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
runtime.run(max_ticks=args.max_ticks)
|
from rich.console import Console # noqa: PLC0415
|
||||||
except KeyboardInterrupt:
|
from rich.live import Live # noqa: PLC0415
|
||||||
runtime.stop()
|
except ImportError:
|
||||||
print("\n[smolvla2] interrupted by user", flush=True)
|
print(
|
||||||
|
"[smolvla2] rich is required for the interactive REPL. "
|
||||||
|
"`pip install rich` and re-run.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
||||||
|
make_state_panel,
|
||||||
|
print_robot_lines,
|
||||||
|
print_user_line,
|
||||||
|
)
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
console.print(
|
||||||
|
"[bold]SmolVLA2[/] ready. "
|
||||||
|
"Type a task to begin, then any line for an interjection, "
|
||||||
|
"a line ending in '?' for VQA, or 'stop' to exit.",
|
||||||
|
)
|
||||||
|
if initial_task is None:
|
||||||
|
console.print("[dim]No --task provided; first stdin line will be used.[/]")
|
||||||
|
|
||||||
|
panel = make_state_panel(runtime.state)
|
||||||
|
ticks_done = 0
|
||||||
|
with Live(
|
||||||
|
panel,
|
||||||
|
console=console,
|
||||||
|
refresh_per_second=4,
|
||||||
|
transient=False,
|
||||||
|
screen=False,
|
||||||
|
) as live:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line = console.input("[bold cyan]>[/] ").strip()
|
||||||
|
except EOFError:
|
||||||
|
break
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
lower = line.lower()
|
||||||
|
if lower in {"stop", "quit", "exit"}:
|
||||||
|
break
|
||||||
|
|
||||||
|
print_user_line(live.console, line)
|
||||||
|
|
||||||
|
# Inject the user input as the right kind of event,
|
||||||
|
# then run a single pipeline tick to consume it.
|
||||||
|
if not runtime.state.get("task"):
|
||||||
|
task = line[5:].strip() if lower.startswith("task:") else line
|
||||||
|
runtime.set_task(task)
|
||||||
|
elif lower.endswith("?"):
|
||||||
|
runtime.state["recent_vqa_query"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append(
|
||||||
|
"user_vqa_query"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
runtime.state["recent_interjection"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append(
|
||||||
|
"user_interjection"
|
||||||
|
)
|
||||||
|
|
||||||
|
logs = runtime.step_once()
|
||||||
|
if logs:
|
||||||
|
print_robot_lines(live.console, logs)
|
||||||
|
live.update(make_state_panel(runtime.state))
|
||||||
|
|
||||||
|
ticks_done += 1
|
||||||
|
if max_ticks is not None and ticks_done >= max_ticks:
|
||||||
|
break
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
console.print("\n[smolvla2] interrupted", style="dim")
|
||||||
|
print("[smolvla2] runtime stopped", flush=True)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user