mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
Compare commits
2 Commits
37b1eb218a
...
223cc8a9e2
| Author | SHA1 | Date | |
|---|---|---|---|
| 223cc8a9e2 | |||
| af6d8ebd5b |
@@ -307,6 +307,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
|||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||||
|
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""SmolVLA2 inference / runtime orchestration.
|
||||||
|
|
||||||
|
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||||
|
|
||||||
|
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||||
|
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||||
|
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||||
|
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||||
|
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||||
|
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||||
|
|
||||||
|
The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
|
||||||
|
calls ``run()``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .repl import StdinReader
|
||||||
|
from .runtime import SmolVLA2Runtime
|
||||||
|
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||||
|
from .steps import (
|
||||||
|
AskVQAFwd,
|
||||||
|
DispatchAction,
|
||||||
|
DispatchToolCalls,
|
||||||
|
HighLevelSubtaskFwd,
|
||||||
|
InferenceStep,
|
||||||
|
LowLevelForward,
|
||||||
|
MemoryUpdateFwd,
|
||||||
|
UserInterjectionFwd,
|
||||||
|
)
|
||||||
|
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# runtime
|
||||||
|
"SmolVLA2Runtime",
|
||||||
|
"StdinReader",
|
||||||
|
# state helpers
|
||||||
|
"initial_runtime_state",
|
||||||
|
"push_log",
|
||||||
|
"set_if_changed",
|
||||||
|
"take_event",
|
||||||
|
# triggers
|
||||||
|
"Trigger",
|
||||||
|
"Tick",
|
||||||
|
"TickClock",
|
||||||
|
"HzTrigger",
|
||||||
|
"EventTrigger",
|
||||||
|
# steps
|
||||||
|
"InferenceStep",
|
||||||
|
"LowLevelForward",
|
||||||
|
"DispatchAction",
|
||||||
|
"HighLevelSubtaskFwd",
|
||||||
|
"MemoryUpdateFwd",
|
||||||
|
"UserInterjectionFwd",
|
||||||
|
"AskVQAFwd",
|
||||||
|
"DispatchToolCalls",
|
||||||
|
]
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Stdin REPL event collector for the SmolVLA2 runtime.
|
||||||
|
|
||||||
|
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||||
|
|
||||||
|
"stop" / "quit" / "exit" → state["stop"] = True
|
||||||
|
ends with "?" → user_vqa_query event
|
||||||
|
starts with "task:" or first line → set runtime task
|
||||||
|
anything else → user_interjection event
|
||||||
|
|
||||||
|
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import select
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StdinReader:
|
||||||
|
"""Non-blocking stdin line collector for the runtime loop."""
|
||||||
|
|
||||||
|
prompt: str = "> "
|
||||||
|
_seen_first_line: bool = field(default=False, init=False)
|
||||||
|
_prompted: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
def poll(self, state: dict[str, Any]) -> None:
|
||||||
|
"""Drain pending stdin lines into runtime events."""
|
||||||
|
# Print the input prompt once on every fresh tick if we don't
|
||||||
|
# already have a pending line; matches the expected REPL feel.
|
||||||
|
if not self._prompted:
|
||||||
|
print(self.prompt, end="", flush=True)
|
||||||
|
self._prompted = True
|
||||||
|
|
||||||
|
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||||
|
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||||
|
try:
|
||||||
|
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||||
|
except (ValueError, OSError):
|
||||||
|
return
|
||||||
|
if not ready:
|
||||||
|
return
|
||||||
|
|
||||||
|
line = sys.stdin.readline()
|
||||||
|
if not line: # EOF
|
||||||
|
state["stop"] = True
|
||||||
|
return
|
||||||
|
line = line.strip()
|
||||||
|
self._prompted = False # we'll re-prompt next tick
|
||||||
|
if not line:
|
||||||
|
return
|
||||||
|
|
||||||
|
lower = line.lower()
|
||||||
|
if lower in {"stop", "quit", "exit"}:
|
||||||
|
state["stop"] = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# First non-control line sets the task if no task is active.
|
||||||
|
if not state.get("task"):
|
||||||
|
task = line[5:].strip() if lower.startswith("task:") else line
|
||||||
|
state["task"] = task
|
||||||
|
print(f"[smolvla2] Task: {task}", flush=True)
|
||||||
|
self._seen_first_line = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Question → VQA; statement → interjection.
|
||||||
|
if lower.endswith("?"):
|
||||||
|
state["recent_vqa_query"] = line
|
||||||
|
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||||
|
else:
|
||||||
|
state["recent_interjection"] = line
|
||||||
|
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""SmolVLA2 runtime loop.
|
||||||
|
|
||||||
|
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||||
|
event collector, drives ticks through :class:`TickClock`, and prints
|
||||||
|
state-change updates to the user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from .runtime_state import initial_runtime_state, push_log
|
||||||
|
from .steps import (
|
||||||
|
AskVQAFwd,
|
||||||
|
DispatchAction,
|
||||||
|
DispatchToolCalls,
|
||||||
|
HighLevelSubtaskFwd,
|
||||||
|
InferenceStep,
|
||||||
|
LowLevelForward,
|
||||||
|
MemoryUpdateFwd,
|
||||||
|
UserInterjectionFwd,
|
||||||
|
)
|
||||||
|
from .triggers import HzTrigger, TickClock
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SmolVLA2Runtime:
|
||||||
|
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||||
|
|
||||||
|
policy: Any
|
||||||
|
tools: dict[str, Any] = field(default_factory=dict)
|
||||||
|
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||||
|
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||||
|
runtime."""
|
||||||
|
observation_provider: Callable[[], dict | None] | None = None
|
||||||
|
"""Closure returning the current preprocessed observation batch.
|
||||||
|
``None`` for dry-run / language-only sessions."""
|
||||||
|
robot_executor: Callable[[Any], None] | None = None
|
||||||
|
"""Closure that takes one action chunk and forwards it to the
|
||||||
|
robot. ``None`` for dry-run."""
|
||||||
|
event_collector: Callable[[dict], None] | None = None
|
||||||
|
"""Per-tick hook that polls external sources (stdin, network) and
|
||||||
|
appends event names to ``state["events_this_tick"]``."""
|
||||||
|
chunk_hz: float = 4.0
|
||||||
|
ctrl_hz: float = 50.0
|
||||||
|
high_level_hz: float = 1.0
|
||||||
|
max_rate_hz: float = 50.0
|
||||||
|
|
||||||
|
pipeline: list[InferenceStep] = field(init=False)
|
||||||
|
state: dict[str, Any] = field(init=False)
|
||||||
|
_stop: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.pipeline = [
|
||||||
|
LowLevelForward(
|
||||||
|
trigger=HzTrigger(self.chunk_hz),
|
||||||
|
policy=self.policy,
|
||||||
|
observation_provider=self.observation_provider,
|
||||||
|
),
|
||||||
|
DispatchAction(
|
||||||
|
trigger=HzTrigger(self.ctrl_hz),
|
||||||
|
robot_executor=self.robot_executor,
|
||||||
|
),
|
||||||
|
HighLevelSubtaskFwd(
|
||||||
|
trigger=HzTrigger(self.high_level_hz),
|
||||||
|
policy=self.policy,
|
||||||
|
),
|
||||||
|
MemoryUpdateFwd(policy=self.policy),
|
||||||
|
UserInterjectionFwd(policy=self.policy),
|
||||||
|
AskVQAFwd(policy=self.policy),
|
||||||
|
DispatchToolCalls(tools=self.tools),
|
||||||
|
]
|
||||||
|
self.state = initial_runtime_state()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_task(self, task: str) -> None:
|
||||||
|
"""Set or replace the active task. Logged for the REPL."""
|
||||||
|
self.state["task"] = task
|
||||||
|
push_log(self.state, f"Task: {task}")
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._stop = True
|
||||||
|
|
||||||
|
def run(self, *, max_ticks: int | None = None) -> None:
|
||||||
|
"""Main loop. Returns when ``stop()`` is called or after
|
||||||
|
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||||
|
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||||
|
while not self._stop:
|
||||||
|
tick = clock.advance()
|
||||||
|
self.state["_tick"] = tick
|
||||||
|
self.state["events_this_tick"] = []
|
||||||
|
self.state["log_lines"] = []
|
||||||
|
|
||||||
|
if self.event_collector is not None:
|
||||||
|
self.event_collector(self.state)
|
||||||
|
if self.state.get("stop"):
|
||||||
|
self._stop = True
|
||||||
|
break
|
||||||
|
|
||||||
|
for step in self.pipeline:
|
||||||
|
self.state = step(self.state)
|
||||||
|
|
||||||
|
self._flush_logs()
|
||||||
|
if max_ticks is not None and tick.index >= max_ticks:
|
||||||
|
break
|
||||||
|
|
||||||
|
self._on_shutdown()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# I/O
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _flush_logs(self) -> None:
|
||||||
|
for line in self.state.get("log_lines") or []:
|
||||||
|
print(f"[smolvla2] {line}", flush=True)
|
||||||
|
|
||||||
|
def _on_shutdown(self) -> None:
|
||||||
|
# Drain any queued action chunks safely.
|
||||||
|
queue = self.state.get("action_queue")
|
||||||
|
if isinstance(queue, deque):
|
||||||
|
queue.clear()
|
||||||
|
print("[smolvla2] runtime stopped", flush=True)
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Runtime state passed between inference steps each tick.
|
||||||
|
|
||||||
|
The runtime threads a single dict through the pipeline; this module
|
||||||
|
documents the shape and provides factories. We use a plain ``dict``
|
||||||
|
rather than a frozen dataclass because steps freely add and remove
|
||||||
|
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||||
|
…) and dataclass field churn would just get in the way.
|
||||||
|
|
||||||
|
Stable keys (read by multiple steps):
|
||||||
|
|
||||||
|
task str the current top-level task
|
||||||
|
current_plan str | None latest plan emitted by the planner
|
||||||
|
current_subtask str | None latest subtask the policy is executing
|
||||||
|
current_memory str | None latest compressed memory
|
||||||
|
recent_interjection str | None most recent user interjection text (consumed)
|
||||||
|
|
||||||
|
action_queue collections.deque[Tensor] pending action chunks
|
||||||
|
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||||
|
|
||||||
|
events_this_tick list[str] triggers consumed this tick
|
||||||
|
_tick Tick current tick (set by the loop)
|
||||||
|
|
||||||
|
log_lines list[str] human-readable status lines printed each tick
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||||
|
"""Build a fresh runtime state dict with sensible defaults."""
|
||||||
|
return {
|
||||||
|
"task": task,
|
||||||
|
"current_plan": None,
|
||||||
|
"current_subtask": None,
|
||||||
|
"current_memory": None,
|
||||||
|
"recent_interjection": None,
|
||||||
|
"action_queue": deque(),
|
||||||
|
"tool_calls_pending": [],
|
||||||
|
"events_this_tick": [],
|
||||||
|
"log_lines": [],
|
||||||
|
"stop": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||||
|
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||||
|
|
||||||
|
Steps that consume an event call this so the same event doesn't
|
||||||
|
re-fire on a sibling step within the same tick.
|
||||||
|
"""
|
||||||
|
events: list[str] = state.get("events_this_tick") or []
|
||||||
|
if event_name in events:
|
||||||
|
events.remove(event_name)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def push_log(state: dict[str, Any], line: str) -> None:
|
||||||
|
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||||
|
it at the end of the tick."""
|
||||||
|
state.setdefault("log_lines", []).append(line)
|
||||||
|
|
||||||
|
|
||||||
|
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||||
|
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||||
|
|
||||||
|
Returns ``True`` if the value actually changed.
|
||||||
|
"""
|
||||||
|
prev = state.get(key)
|
||||||
|
if prev == value:
|
||||||
|
return False
|
||||||
|
state[key] = value
|
||||||
|
if label is not None:
|
||||||
|
push_log(state, f" {label}: {value}")
|
||||||
|
return True
|
||||||
@@ -0,0 +1,382 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference steps for the SmolVLA2 multi-rate runtime.
|
||||||
|
|
||||||
|
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||||
|
the runtime applies them in order each tick. When a step's trigger
|
||||||
|
doesn't fire, the step is a no-op and the runtime moves on.
|
||||||
|
|
||||||
|
Stream-to-step mapping mirrors the ``smolvla2_hirobot.yaml`` recipe:
|
||||||
|
|
||||||
|
* ``LowLevelForward`` — calls ``policy.select_action`` for the
|
||||||
|
action chunk; trained by
|
||||||
|
``low_level_execution``
|
||||||
|
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
|
||||||
|
* ``DispatchAction`` — pops one action per control tick and
|
||||||
|
forwards to the robot
|
||||||
|
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||||
|
next subtask; trained by
|
||||||
|
``high_level_subtask``
|
||||||
|
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||||
|
``memory_update``
|
||||||
|
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||||
|
``user_interjection_response``
|
||||||
|
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||||
|
``ask_vqa_*``
|
||||||
|
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||||
|
the matching ``Tool`` instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .runtime_state import push_log, set_if_changed, take_event
|
||||||
|
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Step base + runner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceStep:
|
||||||
|
"""A trigger-gated callable. Subclasses override :meth:`run`."""
|
||||||
|
|
||||||
|
trigger: Trigger
|
||||||
|
|
||||||
|
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
if not self.trigger.should_fire(state["_tick"], state):
|
||||||
|
return state
|
||||||
|
return self.run(state) or state
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Low-level (action) path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LowLevelForward(InferenceStep):
|
||||||
|
"""Run the policy's action head and produce one action chunk."""
|
||||||
|
|
||||||
|
policy: Any = None
|
||||||
|
observation_provider: Any = None
|
||||||
|
"""Callable ``() -> dict``: returns the current observation batch
|
||||||
|
(already preprocessed). Typically wraps the robot's camera /
|
||||||
|
proprio reads. ``None`` in dry-run mode → step skips."""
|
||||||
|
|
||||||
|
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if self.policy is None or self.observation_provider is None:
|
||||||
|
return None
|
||||||
|
observation = self.observation_provider()
|
||||||
|
if observation is None:
|
||||||
|
return None
|
||||||
|
action = self.policy.select_action(observation)
|
||||||
|
# SmolVLA returns a single action; if the underlying policy
|
||||||
|
# streams chunks, split per-step here. For v1 we just enqueue
|
||||||
|
# the result.
|
||||||
|
state.setdefault("action_queue", []).append(action)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DispatchAction(InferenceStep):
|
||||||
|
"""Pop one action per tick and hand it to the robot.
|
||||||
|
|
||||||
|
In dry-run mode (``robot_executor=None``) the step still pops the
|
||||||
|
queue so it doesn't grow unbounded — the popped tensor is logged
|
||||||
|
instead of executed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
robot_executor: Any = None
|
||||||
|
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
queue = state.get("action_queue")
|
||||||
|
if not queue:
|
||||||
|
return None
|
||||||
|
action = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
|
||||||
|
if self.robot_executor is not None:
|
||||||
|
self.robot_executor(action)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# High-level (text) paths — all use policy.select_message
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
|
"""Tokenize a list of chat messages into the batch shape
|
||||||
|
``select_message`` expects.
|
||||||
|
|
||||||
|
Lazy fallback: re-uses the policy's preprocessor by piggy-backing
|
||||||
|
on the chat tokenizer step. Production use should construct the
|
||||||
|
batch from a real observation; here we focus on the *language*
|
||||||
|
path which is independent of camera observations.
|
||||||
|
"""
|
||||||
|
from transformers import AutoTokenizer # noqa: PLC0415
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
|
||||||
|
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
|
||||||
|
ids = tokenizer.apply_chat_template(
|
||||||
|
text_messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
if isinstance(ids, list):
|
||||||
|
ids = ids[0] if ids else []
|
||||||
|
if hasattr(ids, "ndim") and ids.ndim == 1:
|
||||||
|
ids = ids.unsqueeze(0)
|
||||||
|
attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
|
||||||
|
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
new = dict(m)
|
||||||
|
new.pop("stream", None)
|
||||||
|
new.pop("target", None)
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HighLevelSubtaskFwd(InferenceStep):
|
||||||
|
"""At ~1 Hz, ask the policy for the next subtask."""
|
||||||
|
|
||||||
|
policy: Any = None
|
||||||
|
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if self.policy is None or not state.get("task"):
|
||||||
|
return None
|
||||||
|
ctx = _control_context_messages(state)
|
||||||
|
msg = _generate_with_policy(self.policy, ctx)
|
||||||
|
if msg:
|
||||||
|
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||||
|
if changed:
|
||||||
|
# Subtask change is a downstream trigger.
|
||||||
|
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemoryUpdateFwd(InferenceStep):
|
||||||
|
"""On subtask boundary, refresh the compressed memory."""
|
||||||
|
|
||||||
|
policy: Any = None
|
||||||
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
# Don't consume the event — multiple steps may want to react.
|
||||||
|
if self.policy is None:
|
||||||
|
return None
|
||||||
|
ctx = _control_context_messages(state, include_completed=True)
|
||||||
|
new_memory = _generate_with_policy(self.policy, ctx)
|
||||||
|
if new_memory:
|
||||||
|
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserInterjectionFwd(InferenceStep):
|
||||||
|
"""On stdin interjection, refresh the plan + emit a paired ``say``."""
|
||||||
|
|
||||||
|
policy: Any = None
|
||||||
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if self.policy is None or not take_event(state, "user_interjection"):
|
||||||
|
return None
|
||||||
|
ctx = _control_context_messages(
|
||||||
|
state,
|
||||||
|
extra_user=state.get("recent_interjection"),
|
||||||
|
)
|
||||||
|
out = _generate_with_policy(self.policy, ctx)
|
||||||
|
if not out:
|
||||||
|
return None
|
||||||
|
# Heuristic split: model is trained to emit one assistant turn
|
||||||
|
# carrying both plan text AND a `say` tool call. Look for a
|
||||||
|
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||||
|
# text → plan, no speech.
|
||||||
|
plan_text, speech_text = _split_plan_and_say(out)
|
||||||
|
if plan_text:
|
||||||
|
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||||
|
if speech_text:
|
||||||
|
push_log(state, f" speech: {speech_text}")
|
||||||
|
state.setdefault("tool_calls_pending", []).append(
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "say", "arguments": {"text": speech_text}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
state.setdefault("events_this_tick", []).append("tool_call_pending")
|
||||||
|
# Mark interjection consumed.
|
||||||
|
state["recent_interjection"] = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AskVQAFwd(InferenceStep):
|
||||||
|
"""On stdin question, answer a frame-grounded VQA."""
|
||||||
|
|
||||||
|
policy: Any = None
|
||||||
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if self.policy is None or not take_event(state, "user_vqa_query"):
|
||||||
|
return None
|
||||||
|
question = state.get("recent_vqa_query")
|
||||||
|
if not question:
|
||||||
|
return None
|
||||||
|
ctx = _control_context_messages(state, extra_user=question)
|
||||||
|
answer = _generate_with_policy(self.policy, ctx)
|
||||||
|
if answer:
|
||||||
|
push_log(state, f" vqa: {answer}")
|
||||||
|
state["recent_vqa_query"] = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DispatchToolCalls(InferenceStep):
|
||||||
|
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
|
||||||
|
|
||||||
|
tools: dict[str, Any] = field(default_factory=dict)
|
||||||
|
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
|
||||||
|
|
||||||
|
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
take_event(state, "tool_call_pending")
|
||||||
|
pending = state.get("tool_calls_pending") or []
|
||||||
|
for call in pending:
|
||||||
|
try:
|
||||||
|
fn = (call or {}).get("function") or {}
|
||||||
|
name = fn.get("name")
|
||||||
|
args = fn.get("arguments") or {}
|
||||||
|
tool = self.tools.get(name)
|
||||||
|
if tool is None:
|
||||||
|
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
|
||||||
|
continue
|
||||||
|
tool.call(args)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
push_log(state, f" [error] tool dispatch failed: {exc}")
|
||||||
|
state["tool_calls_pending"] = []
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _control_context_messages(
|
||||||
|
state: dict[str, Any],
|
||||||
|
*,
|
||||||
|
include_completed: bool = False,
|
||||||
|
extra_user: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Build a chat-template-ready prompt from current runtime state.
|
||||||
|
|
||||||
|
Mirrors what ``smolvla2_hirobot.yaml`` renders into ``${task}\nPlan:
|
||||||
|
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
task = state.get("task") or ""
|
||||||
|
parts.append(task)
|
||||||
|
if state.get("current_plan"):
|
||||||
|
parts.append(f"Plan: {state['current_plan']}")
|
||||||
|
if state.get("current_memory"):
|
||||||
|
parts.append(f"Memory: {state['current_memory']}")
|
||||||
|
if include_completed and state.get("current_subtask"):
|
||||||
|
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||||
|
head = "\n".join(parts)
|
||||||
|
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||||
|
if extra_user:
|
||||||
|
msgs.append({"role": "user", "content": extra_user})
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
|
||||||
|
"""Drive ``policy.select_message`` with a minimal text-only batch.
|
||||||
|
|
||||||
|
Best-effort: the runtime today doesn't construct a full
|
||||||
|
observation batch with images / state for text generation; the
|
||||||
|
text-head was trained over images + lang + state, so generations
|
||||||
|
here may differ in distribution from training. This is acceptable
|
||||||
|
for a v1 REPL; a follow-up will plug in the real observation.
|
||||||
|
"""
|
||||||
|
if not hasattr(policy, "select_message"):
|
||||||
|
return ""
|
||||||
|
text_batch = _build_text_batch(policy, messages)
|
||||||
|
# ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
|
||||||
|
# The minimal text-only batch we build doesn't have images / state,
|
||||||
|
# so we either run a text-only forward (handled by SmolVLA2 when
|
||||||
|
# supported) or skip and return empty. v1 returns empty when the
|
||||||
|
# policy can't handle it; the runtime logs and continues.
|
||||||
|
try:
|
||||||
|
# Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
|
||||||
|
# keys ``select_message`` uses internally.
|
||||||
|
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||||
|
}
|
||||||
|
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("select_message fell back: %s", exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||||
|
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||||
|
|
||||||
|
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||||
|
deterministic textual marker so prefix-LM-style training learns to
|
||||||
|
emit it. The runtime parses it back here. If no marker is present,
|
||||||
|
the entire text is treated as plan with no speech.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return "", ""
|
||||||
|
match = _SAY_RE.search(text)
|
||||||
|
if not match:
|
||||||
|
return text.strip(), ""
|
||||||
|
speech = match.group(1).strip().strip('"').strip("'")
|
||||||
|
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||||
|
return plan, speech
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
|
||||||
|
|
||||||
|
Mirrors the plan's Section "Runtime orchestration": each
|
||||||
|
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||||
|
whether the step fires. Two trigger flavours cover all the cadences
|
||||||
|
the canonical recipe needs:
|
||||||
|
|
||||||
|
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||||
|
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||||
|
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||||
|
memory update; user interjection → plan refresh; user VQA query →
|
||||||
|
vqa answer; pending tool call → dispatcher)
|
||||||
|
|
||||||
|
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||||
|
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
|
||||||
|
every step shares a single time source.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tick:
|
||||||
|
"""Single tick from :class:`TickClock`. Carries time references the
|
||||||
|
runtime steps consume to gate themselves."""
|
||||||
|
|
||||||
|
index: int
|
||||||
|
"""Monotonic counter — increments by one per tick."""
|
||||||
|
|
||||||
|
monotonic_seconds: float
|
||||||
|
"""``time.monotonic()`` at the start of this tick."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TickClock:
|
||||||
|
"""Drives the runtime loop at up to ``max_rate_hz``.
|
||||||
|
|
||||||
|
Sleeps just enough between :meth:`advance` calls to enforce the
|
||||||
|
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
|
||||||
|
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_rate_hz: float = 50.0
|
||||||
|
_index: int = field(default=0, init=False)
|
||||||
|
_last_seconds: float | None = field(default=None, init=False)
|
||||||
|
|
||||||
|
def advance(self) -> Tick:
|
||||||
|
period = 1.0 / max(self.max_rate_hz, 0.1)
|
||||||
|
now = time.monotonic()
|
||||||
|
if self._last_seconds is not None:
|
||||||
|
sleep_for = (self._last_seconds + period) - now
|
||||||
|
if sleep_for > 0:
|
||||||
|
time.sleep(sleep_for)
|
||||||
|
now = time.monotonic()
|
||||||
|
self._last_seconds = now
|
||||||
|
self._index += 1
|
||||||
|
return Tick(index=self._index, monotonic_seconds=now)
|
||||||
|
|
||||||
|
|
||||||
|
class Trigger(Protocol):
|
||||||
|
"""Decide whether the next ``InferenceStep`` should fire."""
|
||||||
|
|
||||||
|
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HzTrigger:
|
||||||
|
"""Fire at most ``hz`` times per second."""
|
||||||
|
|
||||||
|
hz: float
|
||||||
|
_last_seconds: float | None = field(default=None, init=False)
|
||||||
|
|
||||||
|
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||||
|
period = 1.0 / max(self.hz, 1e-6)
|
||||||
|
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
|
||||||
|
self._last_seconds = tick.monotonic_seconds
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EventTrigger:
|
||||||
|
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
|
||||||
|
|
||||||
|
The runtime fills ``events_this_tick`` once per tick from:
|
||||||
|
|
||||||
|
* stdin / network input (``user_interjection``, ``user_vqa_query``,
|
||||||
|
``stop``)
|
||||||
|
* internal state transitions (``subtask_change``,
|
||||||
|
``tool_call_pending``)
|
||||||
|
|
||||||
|
The list is consumed (cleared at the end of the tick) so events
|
||||||
|
fire at most once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_name: str
|
||||||
|
|
||||||
|
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||||
|
events: list[str] = state.get("events_this_tick") or []
|
||||||
|
return self.event_name in events
|
||||||
@@ -13,60 +13,62 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
||||||
|
|
||||||
This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
|
Adds:
|
||||||
|
|
||||||
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
||||||
* a forward path that routes to the flow head, the text head, or both,
|
* a forward path that runs the flow head, the text head, or both,
|
||||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
|
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
|
||||||
|
produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
|
||||||
|
this branch).
|
||||||
|
|
||||||
The text-head computation itself is NOT wired up in this scaffold commit
|
Per-sample routing — within one batch:
|
||||||
(the processor doesn't yet produce ``text_labels`` either). This file is
|
|
||||||
the structural placeholder that:
|
|
||||||
|
|
||||||
1. registers the ``SmolVLA2Policy`` class with the right config name so
|
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
|
||||||
``policies/factory.py`` can build it,
|
loss (action chunk supervision).
|
||||||
2. unfreezes ``lm_head`` at construction time when the config asks for it
|
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
|
||||||
(otherwise SmolVLA's ``train_expert_only`` freezes it again on every
|
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
|
||||||
``train()`` call),
|
contribute to the LM-head cross-entropy.
|
||||||
3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
|
|
||||||
SmolVLA when no text labels are present — i.e. existing SmolVLA
|
|
||||||
training scripts keep working.
|
|
||||||
|
|
||||||
The next commit on this branch fills in the actual text-loss path.
|
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
|
||||||
|
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
|
||||||
|
datasets keep working unchanged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy
|
from lerobot.utils.constants import (
|
||||||
|
ACTION,
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||||
from .configuration_smolvla2 import SmolVLA2Config
|
from .configuration_smolvla2 import SmolVLA2Config
|
||||||
|
|
||||||
|
|
||||||
class SmolVLA2Policy(SmolVLAPolicy):
|
class SmolVLA2Policy(SmolVLAPolicy):
|
||||||
"""SmolVLA + re-enabled SmolVLM language head.
|
"""SmolVLA + re-enabled SmolVLM language head."""
|
||||||
|
|
||||||
Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
|
|
||||||
perspective. Behaviourally identical to SmolVLA until the text-head
|
|
||||||
code path lands in the next commit on this branch.
|
|
||||||
"""
|
|
||||||
|
|
||||||
config_class = SmolVLA2Config
|
config_class = SmolVLA2Config
|
||||||
name = "smolvla2"
|
name = "smolvla2"
|
||||||
|
|
||||||
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||||
if not isinstance(config, SmolVLA2Config):
|
if not isinstance(config, SmolVLA2Config):
|
||||||
# Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
|
config = SmolVLA2Config(
|
||||||
# widening the config type — the new fields fall back to their
|
**{
|
||||||
# defaults, which preserves the existing SmolVLA behaviour.
|
f.name: getattr(config, f.name)
|
||||||
config = SmolVLA2Config(**{
|
for f in config.__dataclass_fields__.values()
|
||||||
f.name: getattr(config, f.name)
|
if hasattr(config, f.name)
|
||||||
for f in config.__dataclass_fields__.values()
|
}
|
||||||
if hasattr(config, f.name)
|
)
|
||||||
})
|
|
||||||
super().__init__(config, dataset_stats=dataset_stats)
|
super().__init__(config, dataset_stats=dataset_stats)
|
||||||
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
||||||
self._unfreeze_lm_head()
|
self._unfreeze_lm_head()
|
||||||
@@ -76,13 +78,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _unfreeze_lm_head(self) -> None:
|
def _unfreeze_lm_head(self) -> None:
|
||||||
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
|
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
|
||||||
the text path SmolVLA freezes) so the text-loss can flow back.
|
of the text path SmolVLA freezes) so the text-loss can flow back.
|
||||||
|
|
||||||
SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
|
|
||||||
``lm_head``, ``text_model.model.norm.weight``, and the last
|
|
||||||
``text_model.layers.<N-1>`` block. We undo that selectively when
|
|
||||||
text training is enabled.
|
|
||||||
"""
|
"""
|
||||||
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
||||||
if vlm_with_expert is None:
|
if vlm_with_expert is None:
|
||||||
@@ -91,10 +88,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
if vlm is None:
|
if vlm is None:
|
||||||
return
|
return
|
||||||
for name, param in vlm.named_parameters():
|
for name, param in vlm.named_parameters():
|
||||||
if (
|
if "lm_head" in name or "text_model.model.norm.weight" in name:
|
||||||
"lm_head" in name
|
|
||||||
or "text_model.model.norm.weight" in name
|
|
||||||
):
|
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -108,12 +102,286 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
time: Tensor | None = None,
|
time: Tensor | None = None,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
) -> tuple[Tensor, dict[str, Any]]:
|
) -> tuple[Tensor, dict[str, Any]]:
|
||||||
"""Forward pass with optional text-head loss.
|
"""Forward pass with optional dual-head loss.
|
||||||
|
|
||||||
SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
|
Two routing knobs from the batch (produced by
|
||||||
actual text-loss / dual-head routing lands in the next commit on
|
:class:`SmolVLA2ChatTokenizerStep`):
|
||||||
this branch — it will read ``batch["text_labels"]`` and
|
|
||||||
``batch["predict_actions"]`` (both produced by the SmolVLA2
|
* ``text_labels`` — per-token labels with ``-100`` for non-target
|
||||||
processor) to decide which head(s) to run.
|
positions. Triggers the text-loss path through ``lm_head``.
|
||||||
|
* ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
|
||||||
|
include this sample's action chunk in the flow loss.
|
||||||
|
|
||||||
|
When neither is present, delegate to ``SmolVLAPolicy.forward``.
|
||||||
"""
|
"""
|
||||||
return super().forward(batch, noise=noise, time=time, reduction=reduction)
|
text_labels = batch.get("text_labels")
|
||||||
|
predict_actions_t = batch.get("predict_actions")
|
||||||
|
|
||||||
|
has_text_data = (
|
||||||
|
text_labels is not None
|
||||||
|
and isinstance(text_labels, Tensor)
|
||||||
|
and self.config.text_loss_weight > 0
|
||||||
|
)
|
||||||
|
has_per_sample_routing = (
|
||||||
|
predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_text_data and not has_per_sample_routing:
|
||||||
|
return super().forward(batch, noise=noise, time=time, reduction=reduction)
|
||||||
|
|
||||||
|
loss_dict: dict[str, Any] = {}
|
||||||
|
device = batch[OBS_STATE].device
|
||||||
|
total = torch.zeros((), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Flow loss path — only when at least one sample wants actions.
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
run_flow = self.config.flow_loss_weight > 0 and (
|
||||||
|
not has_per_sample_routing or bool(predict_actions_t.any().item())
|
||||||
|
)
|
||||||
|
if run_flow and ACTION in batch:
|
||||||
|
per_sample_flow, flow_diag = super().forward(
|
||||||
|
batch, noise=noise, time=time, reduction="none"
|
||||||
|
)
|
||||||
|
# ``per_sample_flow`` has shape (B,) from the SmolVLA
|
||||||
|
# reduction="none" branch.
|
||||||
|
if has_per_sample_routing:
|
||||||
|
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||||
|
masked = per_sample_flow * mask
|
||||||
|
denom = mask.sum().clamp(min=1.0)
|
||||||
|
flow_loss = masked.sum() / denom
|
||||||
|
else:
|
||||||
|
flow_loss = per_sample_flow.mean()
|
||||||
|
total = total + self.config.flow_loss_weight * flow_loss
|
||||||
|
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||||
|
for k, v in flow_diag.items():
|
||||||
|
loss_dict[f"flow_{k}"] = v
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Text loss path — prefix-only forward → lm_head → CE.
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
if has_text_data:
|
||||||
|
text_loss = self._compute_text_loss(batch, text_labels)
|
||||||
|
total = total + self.config.text_loss_weight * text_loss
|
||||||
|
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||||
|
|
||||||
|
loss_dict["loss"] = float(total.detach().item())
|
||||||
|
|
||||||
|
if reduction == "none":
|
||||||
|
# Per-sample loss isn't meaningfully defined for the dual
|
||||||
|
# path; broadcast the scalar to (B,) for caller compat.
|
||||||
|
return total.expand(batch[OBS_STATE].shape[0]), loss_dict
|
||||||
|
return total, loss_dict
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Text-loss internals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||||
|
"""Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
|
||||||
|
if self.config.adapt_to_pi_aloha:
|
||||||
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
|
|
||||||
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
state = self.prepare_state(batch)
|
||||||
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||||
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
|
)
|
||||||
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
|
# Prefix-only forward.
|
||||||
|
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||||
|
attention_mask=prefix_att_2d_masks,
|
||||||
|
position_ids=prefix_position_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[prefix_embs, None],
|
||||||
|
use_cache=False,
|
||||||
|
fill_kv_cache=False,
|
||||||
|
)
|
||||||
|
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||||
|
if prefix_out is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
|
||||||
|
"states — text-loss path needs them."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Lang token positions inside the prefix. ``embed_prefix`` lays
|
||||||
|
# out the prefix as ``[image_blocks..., lang, state]`` so the
|
||||||
|
# lang range is identifiable from the trailing state size and
|
||||||
|
# the known lang length.
|
||||||
|
num_lang = lang_tokens.shape[1]
|
||||||
|
state_for_dim = state if state.ndim >= 2 else state[:, None]
|
||||||
|
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
|
||||||
|
if num_state < 1:
|
||||||
|
num_state = 1
|
||||||
|
prefix_len = prefix_out.shape[1]
|
||||||
|
lang_end = prefix_len - num_state
|
||||||
|
lang_start = lang_end - num_lang
|
||||||
|
if lang_start < 0 or lang_end > prefix_len:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"SmolVLA2: could not locate lang token range in prefix "
|
||||||
|
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
|
||||||
|
f"num_state={num_state})."
|
||||||
|
)
|
||||||
|
|
||||||
|
lang_hidden = prefix_out[:, lang_start:lang_end]
|
||||||
|
vlm = self.model.vlm_with_expert.vlm
|
||||||
|
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||||
|
|
||||||
|
if text_labels.shape[1] != num_lang:
|
||||||
|
common = min(text_labels.shape[1], num_lang)
|
||||||
|
logits = logits[:, :common]
|
||||||
|
text_labels = text_labels[:, :common]
|
||||||
|
|
||||||
|
# Standard next-token CE: hidden state at position t predicts
|
||||||
|
# token at position t+1. Shift logits left, labels right by 1.
|
||||||
|
# Without this, the loss is identity-mapped and the LM head
|
||||||
|
# learns nothing useful — see HuggingFace ``LlamaForCausalLM``
|
||||||
|
# for the same convention.
|
||||||
|
shift_logits = logits[:, :-1, :].contiguous()
|
||||||
|
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||||
|
shift_labels.reshape(-1),
|
||||||
|
ignore_index=-100,
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Inference: text generation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_message(
|
||||||
|
self,
|
||||||
|
batch: dict[str, Tensor],
|
||||||
|
*,
|
||||||
|
max_new_tokens: int = 256,
|
||||||
|
eos_token_id: int | None = None,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
tokenizer: Any = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate text continuation from the chat-templated prompt.
|
||||||
|
|
||||||
|
AR decoding with KV caching reused from SmolVLA's inference
|
||||||
|
path. Batch size is assumed to be 1 (the runtime calls this
|
||||||
|
per-event). Returns the decoded string of new tokens (the
|
||||||
|
prompt itself is not included).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch:
|
||||||
|
Already through the SmolVLA2 preprocessor — expects
|
||||||
|
``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``,
|
||||||
|
``OBS_LANGUAGE_ATTENTION_MASK``.
|
||||||
|
max_new_tokens:
|
||||||
|
Hard cap on generated tokens; stops earlier on EOS.
|
||||||
|
eos_token_id:
|
||||||
|
Override the tokenizer's EOS. ``None`` ⇒ use the
|
||||||
|
tokenizer's default.
|
||||||
|
temperature, top_p:
|
||||||
|
``temperature=0`` does greedy argmax (default — matches
|
||||||
|
training distribution most closely). Set ``temperature>0``
|
||||||
|
with optional ``top_p<1`` for nucleus sampling.
|
||||||
|
tokenizer:
|
||||||
|
Optional pre-loaded tokenizer to avoid the cold-start
|
||||||
|
``AutoTokenizer.from_pretrained`` round-trip on every call.
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
if tokenizer is None:
|
||||||
|
from transformers import AutoTokenizer # noqa: PLC0415
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
|
||||||
|
if eos_token_id is None:
|
||||||
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
state = self.prepare_state(batch)
|
||||||
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
|
||||||
|
# 1) Embed prefix (images + lang + state) and run with KV cache.
|
||||||
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||||
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
|
)
|
||||||
|
prefix_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
|
prefix_pos = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
out_pair, past_kv = self.model.vlm_with_expert.forward(
|
||||||
|
attention_mask=prefix_2d,
|
||||||
|
position_ids=prefix_pos,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[prefix_embs, None],
|
||||||
|
use_cache=True,
|
||||||
|
fill_kv_cache=True,
|
||||||
|
)
|
||||||
|
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||||
|
if prefix_out is None:
|
||||||
|
raise RuntimeError("select_message: prefix forward returned no hidden states.")
|
||||||
|
|
||||||
|
vlm = self.model.vlm_with_expert.vlm
|
||||||
|
|
||||||
|
# 2) Initial logits — sample first new token from the last
|
||||||
|
# prefix position.
|
||||||
|
last_hidden = prefix_out[:, -1:]
|
||||||
|
device = last_hidden.device
|
||||||
|
bsize = prefix_embs.shape[0]
|
||||||
|
cur_pos = int(prefix_embs.shape[1])
|
||||||
|
|
||||||
|
generated: list[int] = []
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||||
|
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||||
|
tok_id = int(next_ids[0].item())
|
||||||
|
generated.append(tok_id)
|
||||||
|
if eos_token_id is not None and tok_id == eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 3) Embed the new token and forward with KV cache.
|
||||||
|
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
||||||
|
next_ids.unsqueeze(0)
|
||||||
|
)
|
||||||
|
new_emb = new_emb * math.sqrt(new_emb.shape[-1])
|
||||||
|
|
||||||
|
new_pos = torch.full((bsize, 1), cur_pos, device=device, dtype=torch.long)
|
||||||
|
new_attn = torch.ones((bsize, cur_pos + 1), device=device, dtype=torch.bool)
|
||||||
|
|
||||||
|
out_pair, past_kv = self.model.vlm_with_expert.forward(
|
||||||
|
attention_mask=new_attn,
|
||||||
|
position_ids=new_pos,
|
||||||
|
past_key_values=past_kv,
|
||||||
|
inputs_embeds=[new_emb, None],
|
||||||
|
use_cache=True,
|
||||||
|
fill_kv_cache=True,
|
||||||
|
)
|
||||||
|
new_prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||||
|
last_hidden = new_prefix_out[:, -1:]
|
||||||
|
cur_pos += 1
|
||||||
|
|
||||||
|
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sample_next_token(
|
||||||
|
logits: Tensor, temperature: float, top_p: float
|
||||||
|
) -> Tensor:
|
||||||
|
"""Pick one token id per batch row from ``logits``."""
|
||||||
|
if temperature <= 0.0:
|
||||||
|
return logits.argmax(dim=-1)
|
||||||
|
scaled = logits / max(temperature, 1e-6)
|
||||||
|
probs = F.softmax(scaled, dim=-1)
|
||||||
|
if top_p < 1.0:
|
||||||
|
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
|
||||||
|
cum = sorted_probs.cumsum(dim=-1)
|
||||||
|
mask = cum > top_p
|
||||||
|
# Always keep the most-likely token.
|
||||||
|
mask[..., 0] = False
|
||||||
|
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
|
||||||
|
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
|
||||||
|
pick = torch.multinomial(sorted_probs, num_samples=1)
|
||||||
|
return sorted_idx.gather(-1, pick).squeeze(-1)
|
||||||
|
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||||
|
|||||||
@@ -0,0 +1,196 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
|
||||||
|
|
||||||
|
Drives the multi-rate runtime defined in
|
||||||
|
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
|
||||||
|
channel: type a task, then natural-language interjections / questions.
|
||||||
|
The runtime prints state changes (plan / subtask / memory / vqa /
|
||||||
|
speech) as they happen.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
Dry run on a checkpoint, no robot connected — useful for sanity-
|
||||||
|
checking text generation::
|
||||||
|
|
||||||
|
uv run lerobot-smolvla2-runtime \\
|
||||||
|
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
|
||||||
|
--no_robot \\
|
||||||
|
--task="please clean the kitchen"
|
||||||
|
|
||||||
|
With a real robot::
|
||||||
|
|
||||||
|
uv run lerobot-smolvla2-runtime \\
|
||||||
|
--policy.path=... \\
|
||||||
|
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
||||||
|
--tts.voice=alba
|
||||||
|
|
||||||
|
Tool dispatch (TTS via ``SayTool``) is enabled by default when
|
||||||
|
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
prog="lerobot-smolvla2-runtime",
|
||||||
|
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--policy.path",
|
||||||
|
dest="policy_path",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--task",
|
||||||
|
dest="task",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Initial task. If omitted, the first stdin line is treated as the task.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--no_robot",
|
||||||
|
action="store_true",
|
||||||
|
help="Skip robot connection — language-only / dry-run mode.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--no_tts",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable the ``say`` tool dispatch.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--tts.voice",
|
||||||
|
dest="tts_voice",
|
||||||
|
type=str,
|
||||||
|
default="alba",
|
||||||
|
help="Pocket-tts voice name (or path to a .wav for cloning).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--high_level_hz",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="High-level subtask generation rate.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max_ticks",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Stop after N ticks (debug / smoke-test).",
|
||||||
|
)
|
||||||
|
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
|
||||||
|
return p.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_policy(path: Path): # noqa: ANN202
|
||||||
|
"""Load a SmolVLA2 checkpoint from ``path``."""
|
||||||
|
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
|
||||||
|
|
||||||
|
policy = make_policy_from_path(str(path))
|
||||||
|
policy.eval()
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||||
|
"""Instantiate the tools declared on this dataset/policy."""
|
||||||
|
if no_tts:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
from lerobot.tools import SayTool # noqa: PLC0415
|
||||||
|
|
||||||
|
return {"say": SayTool(voice=tts_voice)}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv: list[str] | None = None) -> int:
|
||||||
|
args = _parse_args(argv)
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args.policy_path.exists():
|
||||||
|
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||||
|
policy = _load_policy(args.policy_path)
|
||||||
|
|
||||||
|
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
|
||||||
|
if tools:
|
||||||
|
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
||||||
|
|
||||||
|
# Robot wiring is left as a follow-up — for v1 we run language-only
|
||||||
|
# / dry-run so REPL development doesn't require a connected robot.
|
||||||
|
observation_provider = None
|
||||||
|
robot_executor = None
|
||||||
|
if not args.no_robot:
|
||||||
|
print(
|
||||||
|
"[smolvla2] WARNING: real-robot integration is a follow-up. "
|
||||||
|
"Running in dry-run mode for now (no actions executed).",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
|
||||||
|
SmolVLA2Runtime,
|
||||||
|
StdinReader,
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime = SmolVLA2Runtime(
|
||||||
|
policy=policy,
|
||||||
|
tools=tools,
|
||||||
|
observation_provider=observation_provider,
|
||||||
|
robot_executor=robot_executor,
|
||||||
|
event_collector=StdinReader().poll,
|
||||||
|
chunk_hz=args.chunk_hz,
|
||||||
|
ctrl_hz=args.ctrl_hz,
|
||||||
|
high_level_hz=args.high_level_hz,
|
||||||
|
)
|
||||||
|
if args.task:
|
||||||
|
runtime.set_task(args.task)
|
||||||
|
print(
|
||||||
|
"[smolvla2] runtime ready. Type a task to begin, then any line for "
|
||||||
|
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
runtime.run(max_ticks=args.max_ticks)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
runtime.stop()
|
||||||
|
print("\n[smolvla2] interrupted by user", flush=True)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
Reference in New Issue
Block a user