mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
feat(smolvla2): runtime accepts Hub IDs + dataset-driven dry-run
The runtime CLI's loader was broken — it imported a `make_policy_from_path` that doesn't exist in `lerobot.policies.factory` — and the high-level text steps generated plan / subtask / memory / VQA from a text-only batch with no images or state, so dry-runs drifted from the training distribution. Switch to the standard `PreTrainedConfig.from_pretrained` + `make_policy(cfg, ds_meta=...)` flow so `--policy.path` accepts both local directories and Hub repo ids, and add a `--dataset.repo_id` path that walks a chosen episode and feeds preprocessed observations into every forward pass — including the four high-level steps (`HighLevelSubtaskFwd`, `MemoryUpdateFwd`, `UserInterjectionFwd`, `AskVQAFwd`). Frames are routed through the saved preprocessor pipeline with `language_persistent` / `language_events` stripped so the recipe-render step stays a no-op (the runtime supplies its own messages from current state). Also wires the rich-based two-zone REPL layout (`ui.py`) that the script was already importing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from .steps import (
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
@@ -65,4 +66,8 @@ __all__ = [
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
|
||||
@@ -82,10 +82,20 @@ class SmolVLA2Runtime:
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
MemoryUpdateFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
UserInterjectionFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
MemoryUpdateFwd(policy=self.policy),
|
||||
UserInterjectionFwd(policy=self.policy),
|
||||
AskVQAFwd(policy=self.policy),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
@@ -127,6 +137,39 @@ class SmolVLA2Runtime:
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
    """Drive a single synchronous pass through the pipeline.

    Intended for the interactive REPL, which calls this after each
    user input instead of running a background loop. Log lines
    produced during the pass are handed back to the caller rather
    than printed, so the caller decides where they are rendered.

    Returns:
        A new list holding the ``log_lines`` accumulated in
        ``self.state`` during this pass (empty if none).
    """
    import time as _time  # noqa: PLC0415

    from .triggers import Tick  # noqa: PLC0415

    # Build the next tick: index continues from the previous tick when
    # one is stored, otherwise counting starts from 0 (+1 below).
    # NOTE(review): the timestamp is the current monotonic clock, not
    # an artificially advanced value — steps that gate on elapsed time
    # may skip when ``step_once`` is called in quick succession;
    # confirm against ``HzTrigger``.
    last_tick = self.state.get("_tick")
    base_index = last_tick.index if isinstance(last_tick, Tick) else 0
    self.state["_tick"] = Tick(
        index=base_index + 1,
        monotonic_seconds=_time.monotonic(),
    )

    # Fresh log buffer for this pass.
    self.state["log_lines"] = []
    # The caller may already have queued events for this pass; only
    # create the list when it is missing so those are not lost.
    self.state.setdefault("events_this_tick", [])

    for pipeline_step in self.pipeline:
        self.state = pipeline_step(self.state)

    collected = self.state.get("log_lines")
    return list(collected) if collected else []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -171,13 +171,19 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||
set, the resulting observation is merged into ``select_message``'s
|
||||
batch so text generation runs against real video + state."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
ctx = _control_context_messages(state)
|
||||
msg = _generate_with_policy(self.policy, ctx)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
msg = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
if msg:
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
@@ -191,6 +197,7 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -198,7 +205,8 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _control_context_messages(state, include_completed=True)
|
||||
new_memory = _generate_with_policy(self.policy, ctx)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
@@ -209,6 +217,7 @@ class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -218,7 +227,8 @@ class UserInterjectionFwd(InferenceStep):
|
||||
state,
|
||||
extra_user=state.get("recent_interjection"),
|
||||
)
|
||||
out = _generate_with_policy(self.policy, ctx)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
if not out:
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
@@ -247,6 +257,7 @@ class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -256,7 +267,8 @@ class AskVQAFwd(InferenceStep):
|
||||
if not question:
|
||||
return None
|
||||
ctx = _control_context_messages(state, extra_user=question)
|
||||
answer = _generate_with_policy(self.policy, ctx)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(self.policy, ctx, observation=observation)
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
@@ -326,35 +338,54 @@ def _control_context_messages(
|
||||
return msgs
|
||||
|
||||
|
||||
def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
|
||||
"""Drive ``policy.select_message`` with a minimal text-only batch.
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
Best-effort: the runtime today doesn't construct a full
|
||||
observation batch with images / state for text generation; the
|
||||
text-head was trained over images + lang + state, so generations
|
||||
here may differ in distribution from training. This is acceptable
|
||||
for a v1 REPL; a follow-up will plug in the real observation.
|
||||
Errors from the provider are logged at debug level and swallowed —
|
||||
text generation still runs (in text-only mode) so a flaky frame
|
||||
source doesn't kill the REPL.
|
||||
"""
|
||||
if provider is None:
|
||||
return None
|
||||
try:
|
||||
return provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_with_policy(
|
||||
policy: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
When ``observation`` carries ``observation.images.*`` and
|
||||
``observation.state``, those are merged into the batch so
|
||||
``select_message`` runs the same VLM prefix the policy was trained
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
# ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
|
||||
# The minimal text-only batch we build doesn't have images / state,
|
||||
# so we either run a text-only forward (handled by SmolVLA2 when
|
||||
# supported) or skip and return empty. v1 returns empty when the
|
||||
# policy can't handle it; the runtime logs and continues.
|
||||
try:
|
||||
# Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
|
||||
# keys ``select_message`` uses internally.
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch = {
|
||||
batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("select_message fell back: %s", exc)
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the SmolVLA2 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||
|
||||
┌── State ──────────────────────────────────────────┐
|
||||
│ task please clean up the kitchen │
|
||||
│ subtask grasp the handle of the sponge │
|
||||
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||
│ memory sponge picked up; counter still dirty │
|
||||
└───────────────────────────────────────────────────┘
|
||||
> _
|
||||
|
||||
The state panel re-renders on every state change. Chat lines are
|
||||
``console.print``'d above the live region so they accumulate naturally
|
||||
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||
panel for cursor position.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try: # rich is optional; only required for the interactive REPL.
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
_HAS_RICH = True
|
||||
except ImportError: # pragma: no cover
|
||||
_HAS_RICH = False
|
||||
Console = Any # type: ignore[assignment]
|
||||
Panel = Any # type: ignore[assignment]
|
||||
Table = Any # type: ignore[assignment]
|
||||
Text = Any # type: ignore[assignment]
|
||||
|
||||
|
||||
_STATE_KEYS = (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
)
|
||||
|
||||
|
||||
def make_state_panel(state: dict[str, Any]) -> Any:
    """Build the state panel shown in the live region.

    Args:
        state: Runtime state dict. The keys listed in ``_STATE_KEYS``
            plus ``action_queue`` and ``tool_calls_pending`` are read;
            missing keys are tolerated.

    Returns:
        A ``rich.panel.Panel`` wrapping a two-column grid: one row per
        ``_STATE_KEYS`` entry, followed by a footer row with the
        queued-action and pending-tool-call counts.

    Raises:
        RuntimeError: If ``rich`` could not be imported.
    """
    if not _HAS_RICH:
        raise RuntimeError(
            "rich is required for the interactive REPL. "
            "`pip install rich` (it's a transitive dep of lerobot)."
        )

    grid = Table.grid(padding=(0, 2), expand=True)
    grid.add_column(justify="right", style="dim", no_wrap=True, width=10)
    grid.add_column(justify="left")

    for state_key, caption in _STATE_KEYS:
        current = state.get(state_key)
        # Values are wrapped in ``Text`` so their content is shown
        # verbatim rather than parsed as console markup.
        cell = (
            Text("(not set)", style="dim italic")
            if current is None
            else Text(str(current), style="bold")
        )
        grid.add_row(caption, cell)

    # Anything without ``__len__`` (including a missing queue) counts as 0.
    action_queue = state.get("action_queue")
    queued_count = len(action_queue) if hasattr(action_queue, "__len__") else 0
    pending_count = len(state.get("tool_calls_pending") or [])

    grid.add_row(
        "",
        Text.assemble(
            ("queued actions: ", "dim"),
            (str(queued_count), "bold cyan"),
            (" pending tool calls: ", "dim"),
            (str(pending_count), "bold magenta"),
        ),
    )
    return Panel(grid, title="[bold]SmolVLA2 state[/]", border_style="cyan")
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
    """Append a user-typed line to the chat scrollback.

    Args:
        console: Object exposing a Rich-style ``print`` method. Unused
            when ``rich`` is unavailable.
        line: Raw text typed by the user.

    Without ``rich`` the line is written with the built-in ``print``.
    With ``rich`` the user text is escaped before being placed in the
    markup string, so square-bracket sequences typed by the user are
    shown literally instead of being interpreted as style tags (which
    could hide text or raise a markup error on an unmatched closing
    tag).
    """
    if not _HAS_RICH:
        print(f"you: {line}", flush=True)
        return

    from rich.markup import escape  # noqa: PLC0415

    console.print(f"[bold cyan]you:[/] {escape(line)}")
|
||||
|
||||
|
||||
def print_robot_lines(console: Any, lines: list[str]) -> None:
    """Append robot/runtime log lines to the chat scrollback.

    Args:
        console: Object exposing a Rich-style ``print`` method. Unused
            when ``rich`` is unavailable.
        lines: Log lines to render; leading whitespace is dropped.

    Without ``rich`` each line is written with the built-in ``print``
    behind a ``robot:`` prefix. With ``rich``, a line containing ``:``
    is split at the first colon into a label (rendered dim, in
    parentheses) and a value; other lines are printed whole. All
    log-derived text is escaped before it is placed in the markup
    string, so square-bracket sequences in generated text are shown
    literally rather than parsed as style tags.
    """
    if not _HAS_RICH:
        for line in lines:
            print(f"robot: {line.lstrip()}", flush=True)
        return

    from rich.markup import escape  # noqa: PLC0415

    for line in lines:
        stripped = line.lstrip()
        # ``separator`` is empty exactly when the line has no colon.
        label, separator, value = stripped.partition(":")
        if separator:
            console.print(
                f"[bold green]robot[/] [dim]({escape(label.strip())})[/] {escape(value.strip())}"
            )
        else:
            console.print(f"[bold green]robot:[/] {escape(stripped)}")
|
||||
@@ -23,11 +23,21 @@ speech) as they happen.
|
||||
Examples
|
||||
--------
|
||||
|
||||
Dry run on a checkpoint, no robot connected — useful for sanity-
|
||||
Dry run on a Hub checkpoint, no robot connected — useful for sanity-
|
||||
checking text generation::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
|
||||
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||
--no_robot \\
|
||||
--task="please clean the kitchen"
|
||||
|
||||
Same, but feed real frames from an annotated dataset so plan / subtask
|
||||
/ memory / VQA generation runs against actual video + state::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||
--dataset.repo_id=pepijn223/super_poulain_annotated \\
|
||||
--dataset.episode=0 \\
|
||||
--no_robot \\
|
||||
--task="please clean the kitchen"
|
||||
|
||||
@@ -38,6 +48,9 @@ With a real robot::
|
||||
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
||||
--tts.voice=alba
|
||||
|
||||
``--policy.path`` accepts either a local directory or a Hugging Face
|
||||
Hub repo id. ``--dataset.repo_id`` likewise.
|
||||
|
||||
Tool dispatch (TTS via ``SayTool``) is enabled by default when
|
||||
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
|
||||
"""
|
||||
@@ -47,8 +60,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||
|
||||
@@ -61,9 +73,51 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
p.add_argument(
|
||||
"--policy.path",
|
||||
dest="policy_path",
|
||||
type=Path,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
|
||||
help=(
|
||||
"Local directory or Hugging Face Hub repo id pointing at a "
|
||||
"trained SmolVLA2 ``pretrained_model``."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--dataset.repo_id",
|
||||
dest="dataset_repo_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional dataset (local path or Hub repo id) used to drive "
|
||||
"observations during dry-run inference. When set, the runtime "
|
||||
"reads camera frames + state from the chosen episode and feeds "
|
||||
"them into all forward passes — so plan / subtask / memory / "
|
||||
"VQA generation see the same visual context the policy was "
|
||||
"trained on."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--dataset.episode",
|
||||
dest="dataset_episode",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Episode index to walk through (default: 0).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--dataset.start_frame",
|
||||
dest="dataset_start_frame",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Frame index within the episode to start from (default: 0).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--dataset.advance_per_tick",
|
||||
dest="dataset_advance_per_tick",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"How many dataset frames to advance per runtime tick. The "
|
||||
"default of 1 means the runtime walks the episode forward "
|
||||
"frame by frame; set to 0 to freeze on ``start_frame``."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--task",
|
||||
@@ -111,16 +165,136 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
return p.parse_args(argv)
|
||||
|
||||
|
||||
def _load_policy(path: Path): # noqa: ANN202
|
||||
"""Load a SmolVLA2 checkpoint from ``path``."""
|
||||
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
|
||||
def _load_policy_and_preprocessor(
|
||||
policy_path: str,
|
||||
dataset_repo_id: str | None,
|
||||
) -> tuple[Any, Any, Any]:
|
||||
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
|
||||
|
||||
When ``dataset_repo_id`` is provided, the dataset's metadata is used
|
||||
to derive policy features (matching the standard
|
||||
``make_policy(cfg, ds_meta=...)`` flow used by ``lerobot-train`` and
|
||||
``lerobot-record``). When it isn't, we fall back to instantiating
|
||||
the policy directly via ``from_pretrained`` — this skips the
|
||||
feature-derivation path that ``make_policy`` insists on, but also
|
||||
means we can't load the saved preprocessor pipeline (which depends
|
||||
on ``input_features`` / ``output_features``). For inference-only
|
||||
dry-runs this is fine; the policy still loads.
|
||||
|
||||
Returns ``(policy, preprocessor, ds_meta)`` where ``preprocessor``
|
||||
and ``ds_meta`` may be ``None`` if no dataset was provided.
|
||||
"""
|
||||
from lerobot.configs import PreTrainedConfig # noqa: PLC0415
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(policy_path)
|
||||
cfg.pretrained_path = policy_path
|
||||
|
||||
ds_meta = None
|
||||
preprocessor = None
|
||||
if dataset_repo_id is not None:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
ds_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
policy = make_policy(cfg, ds_meta=ds_meta)
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
cfg,
|
||||
pretrained_path=policy_path,
|
||||
dataset_stats=ds_meta.stats,
|
||||
)
|
||||
else:
|
||||
# No dataset: instantiate the policy class directly so we don't
|
||||
# need ds_meta. This bypasses ``make_policy``'s feature-shape
|
||||
# derivation, which is fine for a pretrained checkpoint where
|
||||
# the saved config already carries those shapes.
|
||||
from lerobot.policies.factory import get_policy_class # noqa: PLC0415
|
||||
|
||||
policy_cls = get_policy_class(cfg.type)
|
||||
policy = policy_cls.from_pretrained(policy_path, config=cfg)
|
||||
policy.to(cfg.device)
|
||||
|
||||
policy = make_policy_from_path(str(path))
|
||||
policy.eval()
|
||||
return policy
|
||||
return policy, preprocessor, ds_meta
|
||||
|
||||
|
||||
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
def _build_observation_provider(
    *,
    dataset_repo_id: str,
    episode: int,
    start_frame: int,
    advance_per_tick: int,
    preprocessor: Any,
    device: str,
) -> Callable[[], dict | None]:
    """Create a zero-argument callable that replays dataset frames.

    Args:
        dataset_repo_id: Identifier passed to ``LeRobotDataset``.
        episode: The single episode loaded from the dataset.
        start_frame: Initial cursor position; clamped into
            ``[0, len(dataset) - 1]``.
        advance_per_tick: Number of frames the cursor moves after each
            call of the returned callable (wrapping around at the end
            of the episode). Values ``<= 0`` keep the cursor fixed.
            Note the advance happens per *call*, so several calls
            within one runtime tick each move the cursor.
        preprocessor: Optional callable applied to the sample after
            the language columns are removed; skipped when ``None``.
        device: Target passed to ``Tensor.to`` for every tensor value.

    Returns:
        A callable returning a dict restricted to keys that start with
        ``"observation."``, with tensor values moved to ``device``.

    Raises:
        ValueError: If the selected episode contains no frames.

    ``language_persistent`` and ``language_events`` are dropped from
    each sample before preprocessing; the runtime supplies its own
    messages (presumably this keeps the pipeline's message-rendering
    steps inactive — confirm against the preprocessor definition).
    """
    import torch  # noqa: PLC0415

    from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415

    ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
    if len(ds) == 0:
        raise ValueError(
            f"Dataset {dataset_repo_id!r} episode {episode} is empty."
        )

    cursor = max(0, min(start_frame, len(ds) - 1))

    def _provider() -> dict | None:
        nonlocal cursor

        frame_index = cursor
        if advance_per_tick > 0:
            # Wrap around so the episode loops indefinitely.
            cursor = (frame_index + advance_per_tick) % len(ds)

        sample = ds[frame_index]
        # Remove the language columns before preprocessing.
        sample.pop("language_persistent", None)
        sample.pop("language_events", None)

        if preprocessor is not None:
            sample = preprocessor(sample)

        # Only observation entries are forwarded; the text path adds
        # its own language tokens / masks.
        observation = {
            key: value
            for key, value in sample.items()
            if isinstance(key, str) and key.startswith("observation.")
        }

        for key, value in list(observation.items()):
            if not isinstance(value, torch.Tensor):
                continue
            # Best-effort batch-dim repair for non-image tensors of
            # rank 1-3 whose leading dim is not 1. Image keys and
            # rank-4+ tensors are left untouched.
            lacks_batch_dim = (
                0 < value.ndim < 4
                and value.shape[0] != 1
                and "image" not in key
            )
            if lacks_batch_dim:
                value = value.unsqueeze(0)
            # Usually a no-op when the preprocessor already placed the
            # tensors on the target device.
            observation[key] = value.to(device)
        return observation

    return _provider
|
||||
|
||||
|
||||
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
"""Instantiate the tools declared on this dataset/policy."""
|
||||
if no_tts:
|
||||
return {}
|
||||
@@ -140,20 +314,32 @@ def main(argv: list[str] | None = None) -> int:
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
|
||||
if not args.policy_path.exists():
|
||||
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||
policy = _load_policy(args.policy_path)
|
||||
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
||||
args.policy_path, args.dataset_repo_id
|
||||
)
|
||||
|
||||
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
if args.dataset_repo_id is not None:
|
||||
print(
|
||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||
f"episode={args.dataset_episode} "
|
||||
f"start_frame={args.dataset_start_frame}",
|
||||
flush=True,
|
||||
)
|
||||
observation_provider = _build_observation_provider(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
episode=args.dataset_episode,
|
||||
start_frame=args.dataset_start_frame,
|
||||
advance_per_tick=args.dataset_advance_per_tick,
|
||||
preprocessor=preprocessor,
|
||||
device=str(getattr(policy.config, "device", "cpu")),
|
||||
)
|
||||
|
||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||
if tools:
|
||||
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
||||
|
||||
# Robot wiring is left as a follow-up — for v1 we run language-only
|
||||
# / dry-run so REPL development doesn't require a connected robot.
|
||||
observation_provider = None
|
||||
robot_executor = None
|
||||
if not args.no_robot:
|
||||
print(
|
||||
@@ -162,33 +348,112 @@ def main(argv: list[str] | None = None) -> int:
|
||||
flush=True,
|
||||
)
|
||||
|
||||
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
||||
SmolVLA2Runtime,
|
||||
StdinReader,
|
||||
)
|
||||
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
|
||||
|
||||
runtime = SmolVLA2Runtime(
|
||||
policy=policy,
|
||||
tools=tools,
|
||||
observation_provider=observation_provider,
|
||||
robot_executor=robot_executor,
|
||||
event_collector=StdinReader().poll,
|
||||
# No background event collector — the REPL drives ticks
|
||||
# synchronously after each user input. The runtime's own
|
||||
# ``run()`` loop is bypassed here in favour of ``step_once()``
|
||||
# so the input prompt and the live state panel co-exist
|
||||
# cleanly.
|
||||
event_collector=None,
|
||||
chunk_hz=args.chunk_hz,
|
||||
ctrl_hz=args.ctrl_hz,
|
||||
high_level_hz=args.high_level_hz,
|
||||
)
|
||||
if args.task:
|
||||
runtime.set_task(args.task)
|
||||
print(
|
||||
"[smolvla2] runtime ready. Type a task to begin, then any line for "
|
||||
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||
|
||||
|
||||
def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) -> int:
|
||||
"""Two-zone TUI: chat scrollback above a persistent state panel.
|
||||
|
||||
Uses :class:`rich.live.Live` to keep the state panel rendered
|
||||
below the chat history. ``console.input`` for the prompt — Live
|
||||
auto-suspends repaint while the user is typing.
|
||||
"""
|
||||
try:
|
||||
runtime.run(max_ticks=args.max_ticks)
|
||||
except KeyboardInterrupt:
|
||||
runtime.stop()
|
||||
print("\n[smolvla2] interrupted by user", flush=True)
|
||||
from rich.console import Console # noqa: PLC0415
|
||||
from rich.live import Live # noqa: PLC0415
|
||||
except ImportError:
|
||||
print(
|
||||
"[smolvla2] rich is required for the interactive REPL. "
|
||||
"`pip install rich` and re-run.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
||||
make_state_panel,
|
||||
print_robot_lines,
|
||||
print_user_line,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
console.print(
|
||||
"[bold]SmolVLA2[/] ready. "
|
||||
"Type a task to begin, then any line for an interjection, "
|
||||
"a line ending in '?' for VQA, or 'stop' to exit.",
|
||||
)
|
||||
if initial_task is None:
|
||||
console.print("[dim]No --task provided; first stdin line will be used.[/]")
|
||||
|
||||
panel = make_state_panel(runtime.state)
|
||||
ticks_done = 0
|
||||
with Live(
|
||||
panel,
|
||||
console=console,
|
||||
refresh_per_second=4,
|
||||
transient=False,
|
||||
screen=False,
|
||||
) as live:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
line = console.input("[bold cyan]>[/] ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not line:
|
||||
continue
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
break
|
||||
|
||||
print_user_line(live.console, line)
|
||||
|
||||
# Inject the user input as the right kind of event,
|
||||
# then run a single pipeline tick to consume it.
|
||||
if not runtime.state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
runtime.set_task(task)
|
||||
elif lower.endswith("?"):
|
||||
runtime.state["recent_vqa_query"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_vqa_query"
|
||||
)
|
||||
else:
|
||||
runtime.state["recent_interjection"] = line
|
||||
runtime.state.setdefault("events_this_tick", []).append(
|
||||
"user_interjection"
|
||||
)
|
||||
|
||||
logs = runtime.step_once()
|
||||
if logs:
|
||||
print_robot_lines(live.console, logs)
|
||||
live.update(make_state_panel(runtime.state))
|
||||
|
||||
ticks_done += 1
|
||||
if max_ticks is not None and ticks_done >= max_ticks:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[smolvla2] interrupted", style="dim")
|
||||
print("[smolvla2] runtime stopped", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user