mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
2 Commits
37b1eb218a
...
223cc8a9e2
| Author | SHA1 | Date | |
|---|---|---|---|
| 223cc8a9e2 | |||
| af6d8ebd5b |
@@ -307,6 +307,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
|
||||
calls ``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import SmolVLA2Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
|
||||
# Public API of the SmolVLA2 runtime package, grouped by module of origin.
__all__ = [
    # runtime
    "SmolVLA2Runtime",
    "StdinReader",
    # state helpers
    "initial_runtime_state",
    "push_log",
    "set_if_changed",
    "take_event",
    # triggers
    "Trigger",
    "Tick",
    "TickClock",
    "HzTrigger",
    "EventTrigger",
    # steps
    "InferenceStep",
    "LowLevelForward",
    "DispatchAction",
    "HighLevelSubtaskFwd",
    "MemoryUpdateFwd",
    "UserInterjectionFwd",
    "AskVQAFwd",
    "DispatchToolCalls",
]
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the SmolVLA2 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class StdinReader:
    """Collect stdin lines without blocking and turn them into runtime events."""

    prompt: str = "> "
    _seen_first_line: bool = field(default=False, init=False)
    _prompted: bool = field(default=False, init=False)

    def poll(self, state: dict[str, Any]) -> None:
        """Drain pending stdin lines into runtime events."""
        # Show the prompt exactly once until a line is consumed, keeping
        # the familiar REPL look without re-printing on every tick.
        if not self._prompted:
            print(self.prompt, end="", flush=True)
            self._prompted = True

        # A zero-timeout select keeps this call non-blocking. stdin
        # objects select() rejects (closed / non-selectable, as in CI or
        # scripted runs) are treated as "nothing available".
        try:
            readable, _, _ = select.select([sys.stdin], [], [], 0)
        except (ValueError, OSError):
            return
        if not readable:
            return

        raw = sys.stdin.readline()
        if not raw:  # EOF → shut the runtime down
            state["stop"] = True
            return
        text = raw.strip()
        self._prompted = False  # a line was consumed; re-prompt next tick
        if not text:
            return
        self._dispatch_line(text, state)

    def _dispatch_line(self, text: str, state: dict[str, Any]) -> None:
        """Classify one stripped, non-empty line and mutate ``state``."""
        lowered = text.lower()
        if lowered in {"stop", "quit", "exit"}:
            state["stop"] = True
            return

        # The first substantive line becomes the task when none is active;
        # an explicit "task:" prefix is stripped if present.
        if not state.get("task"):
            task = text[5:].strip() if lowered.startswith("task:") else text
            state["task"] = task
            print(f"[smolvla2] Task: {task}", flush=True)
            self._seen_first_line = True
            return

        # Questions feed the VQA path; everything else is an interjection.
        if lowered.endswith("?"):
            state["recent_vqa_query"] = text
            state.setdefault("events_this_tick", []).append("user_vqa_query")
        else:
            state["recent_interjection"] = text
            state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SmolVLA2Runtime:
    """Compose the inference pipeline and drive it tick-by-tick."""

    policy: Any
    # Name → tool-instance mapping, e.g. ``{"say": SayTool(...)}``;
    # typically read from ``lerobot.tools.get_tools(meta)`` when wiring.
    tools: dict[str, Any] = field(default_factory=dict)
    # Closure returning the current preprocessed observation batch;
    # ``None`` for dry-run / language-only sessions.
    observation_provider: Callable[[], dict | None] | None = None
    # Closure forwarding one action chunk to the robot; ``None`` for dry-run.
    robot_executor: Callable[[Any], None] | None = None
    # Per-tick hook polling external sources (stdin, network); appends
    # event names to ``state["events_this_tick"]``.
    event_collector: Callable[[dict], None] | None = None
    chunk_hz: float = 4.0
    ctrl_hz: float = 50.0
    high_level_hz: float = 1.0
    max_rate_hz: float = 50.0

    pipeline: list[InferenceStep] = field(init=False)
    state: dict[str, Any] = field(init=False)
    _stop: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        # Step order matters: periodic action production and dispatch
        # run first, then the event-driven reactions, mirroring the
        # recipe's stream layout.
        self.pipeline = [
            LowLevelForward(
                trigger=HzTrigger(self.chunk_hz),
                policy=self.policy,
                observation_provider=self.observation_provider,
            ),
            DispatchAction(
                trigger=HzTrigger(self.ctrl_hz),
                robot_executor=self.robot_executor,
            ),
            HighLevelSubtaskFwd(
                trigger=HzTrigger(self.high_level_hz),
                policy=self.policy,
            ),
            MemoryUpdateFwd(policy=self.policy),
            UserInterjectionFwd(policy=self.policy),
            AskVQAFwd(policy=self.policy),
            DispatchToolCalls(tools=self.tools),
        ]
        self.state = initial_runtime_state()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def set_task(self, task: str) -> None:
        """Set or replace the active task; logged so the REPL shows it."""
        self.state["task"] = task
        push_log(self.state, f"Task: {task}")

    def stop(self) -> None:
        """Request that the loop exits after the current tick."""
        self._stop = True

    def run(self, *, max_ticks: int | None = None) -> None:
        """Drive the loop until :meth:`stop` fires or ``max_ticks`` elapse.

        ``max_ticks`` is mainly useful for tests and dry-runs.
        """
        clock = TickClock(max_rate_hz=self.max_rate_hz)
        while not self._stop:
            tick = clock.advance()
            # Fresh per-tick buffers; steps append into them.
            self.state["_tick"] = tick
            self.state["events_this_tick"] = []
            self.state["log_lines"] = []

            if self.event_collector is not None:
                self.event_collector(self.state)
            if self.state.get("stop"):
                self._stop = True
                break

            for step in self.pipeline:
                self.state = step(self.state)

            self._flush_logs()
            if max_ticks is not None and tick.index >= max_ticks:
                break

        self._on_shutdown()

    # ------------------------------------------------------------------
    # I/O
    # ------------------------------------------------------------------

    def _flush_logs(self) -> None:
        """Print every log line queued during this tick."""
        for entry in self.state.get("log_lines") or []:
            print(f"[smolvla2] {entry}", flush=True)

    def _on_shutdown(self) -> None:
        """Drop any queued action chunks and announce shutdown."""
        pending = self.state.get("action_queue")
        if isinstance(pending, deque):
            pending.clear()
        print("[smolvla2] runtime stopped", flush=True)
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
    """Create a new runtime state dict populated with default values.

    ``task`` pre-seeds the top-level task; all other text slots start
    empty, the action queue is a fresh deque, and ``stop`` is cleared.
    """
    state: dict[str, Any] = {"task": task}
    for key in ("current_plan", "current_subtask", "current_memory", "recent_interjection"):
        state[key] = None
    state["action_queue"] = deque()
    state["tool_calls_pending"] = []
    state["events_this_tick"] = []
    state["log_lines"] = []
    state["stop"] = False
    return state
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
    """Consume ``event_name`` from ``events_this_tick`` if it is queued.

    Removing the event on success means a sibling step within the same
    tick won't re-fire on it. Returns ``True`` when the event was
    present (and is now removed), ``False`` otherwise.
    """
    pending = state.get("events_this_tick") or []
    try:
        pending.remove(event_name)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
    """Queue ``line`` in the per-tick log buffer; the runtime prints the
    buffer at the end of each tick."""
    if "log_lines" not in state:
        state["log_lines"] = []
    state["log_lines"].append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
    """Write ``state[key] = value`` and report whether it differed.

    When ``label`` is given and the value actually changed, a
    human-readable diff line is queued via :func:`push_log`.
    Returns ``True`` only on a real change.
    """
    if state.get(key) == value:
        return False
    state[key] = value
    if label is not None:
        push_log(state, f"  {label}: {value}")
    return True
|
||||
@@ -0,0 +1,382 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the SmolVLA2 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
doesn't fire, the step is a no-op and the runtime moves on.
|
||||
|
||||
Stream-to-step mapping mirrors the ``smolvla2_hirobot.yaml`` recipe:
|
||||
|
||||
* ``LowLevelForward`` — calls ``policy.select_action`` for the
|
||||
action chunk; trained by
|
||||
``low_level_execution``
|
||||
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
|
||||
* ``DispatchAction`` — pops one action per control tick and
|
||||
forwards to the robot
|
||||
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||
next subtask; trained by
|
||||
``high_level_subtask``
|
||||
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||
``memory_update``
|
||||
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||
``user_interjection_response``
|
||||
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||
``ask_vqa_*``
|
||||
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||
the matching ``Tool`` instance
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log, set_if_changed, take_event
|
||||
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step base + runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class InferenceStep:
    """Base for trigger-gated pipeline steps; subclasses implement :meth:`run`."""

    trigger: Trigger

    def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
        # Gate on the shared tick: a non-firing trigger makes the whole
        # step a no-op and the unchanged state flows onward.
        if self.trigger.should_fire(state["_tick"], state):
            return self.run(state) or state
        return state

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:  # pragma: no cover
        raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level (action) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LowLevelForward(InferenceStep):
    """Produce one action chunk from the policy's action head."""

    policy: Any = None
    # Callable ``() -> dict`` returning the already-preprocessed
    # observation batch (typically wraps the robot's camera / proprio
    # reads). ``None`` in dry-run mode → the step skips itself.
    observation_provider: Any = None

    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or self.observation_provider is None:
            return None
        obs = self.observation_provider()
        if obs is None:
            return None
        # SmolVLA returns a single action today; per-step chunk splitting
        # would go here if the underlying policy ever streams chunks.
        # NOTE(review): initial_runtime_state seeds "action_queue" as a
        # deque; the [] fallback only applies to hand-built states —
        # confirm both container types are intended downstream.
        state.setdefault("action_queue", []).append(self.policy.select_action(obs))
        return None
|
||||
|
||||
|
||||
@dataclass
class DispatchAction(InferenceStep):
    """Hand one queued action to the robot each control tick.

    With ``robot_executor=None`` (dry-run) the queue is still drained so
    it cannot grow without bound; the popped action simply isn't
    forwarded anywhere.
    """

    robot_executor: Any = None
    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        pending = state.get("action_queue")
        if not pending:
            return None
        # Support both deque (popleft) and plain-list (pop(0)) queues.
        next_action = pending.popleft() if hasattr(pending, "popleft") else pending.pop(0)
        if self.robot_executor is not None:
            self.robot_executor(next_action)
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level (text) paths — all use policy.select_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Tokenize a list of chat messages into the batch shape
    ``select_message`` expects.

    Lazy fallback: re-uses the policy's preprocessor by piggy-backing
    on the chat tokenizer step. Production use should construct the
    batch from a real observation; here we focus on the *language*
    path which is independent of camera observations.

    Returns a dict with ``lang_tokens`` (2-D token-id tensor),
    ``lang_masks`` (attention mask, or ``None`` when no pad token
    exists), and the ``tokenizer`` itself so callers can decode output.
    """
    # Local import keeps transformers off the module import path.
    from transformers import AutoTokenizer  # noqa: PLC0415

    # NOTE(review): assumes policy.config.vlm_model_name names a hub
    # checkpoint whose tokenizer ships a chat template — confirm against
    # the SmolVLA2 config.
    tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so
    # the mask computation below has a valid pad id.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Recipe bookkeeping keys ("stream", "target") are not chat-template
    # fields, so strip them from every message first.
    text_messages = [_strip_recipe_keys(m) for m in prompt_messages]
    ids = tokenizer.apply_chat_template(
        text_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a list depending on tokenizer /
    # transformers version; normalize everything to a (1, seq) tensor.
    if isinstance(ids, list):
        ids = ids[0] if ids else []
    if hasattr(ids, "ndim") and ids.ndim == 1:
        ids = ids.unsqueeze(0)
    # Mask stays None when no pad id exists even after the EOS fallback.
    attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None
    return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
new = dict(m)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
class HighLevelSubtaskFwd(InferenceStep):
    """Query the policy for the next subtask at roughly 1 Hz."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not state.get("task"):
            return None
        subtask = _generate_with_policy(self.policy, _control_context_messages(state))
        if subtask and set_if_changed(state, "current_subtask", subtask, label="subtask"):
            # A fresh subtask is itself an event downstream steps react to.
            state.setdefault("events_this_tick", []).append("subtask_change")
        return None
|
||||
|
||||
|
||||
@dataclass
class MemoryUpdateFwd(InferenceStep):
    """Refresh the compressed memory whenever the subtask changes."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        # Deliberately leaves "subtask_change" in events_this_tick —
        # other steps may also want to react to the boundary.
        if self.policy is None:
            return None
        context = _control_context_messages(state, include_completed=True)
        refreshed = _generate_with_policy(self.policy, context)
        if refreshed:
            set_if_changed(state, "current_memory", refreshed, label="memory")
        return None
|
||||
|
||||
|
||||
@dataclass
class UserInterjectionFwd(InferenceStep):
    """React to a stdin interjection: refresh the plan and queue a ``say``."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not take_event(state, "user_interjection"):
            return None
        context = _control_context_messages(
            state,
            extra_user=state.get("recent_interjection"),
        )
        generated = _generate_with_policy(self.policy, context)
        if not generated:
            return None
        # The model is trained to emit one assistant turn carrying plan
        # text AND an optional "<say>...</say>" (or "say(...)") marker;
        # split them apart heuristically. No marker → everything is plan
        # and nothing is spoken.
        plan_text, speech_text = _split_plan_and_say(generated)
        if plan_text:
            set_if_changed(state, "current_plan", plan_text, label="plan")
        if speech_text:
            push_log(state, f"  speech: {speech_text}")
            call = {
                "type": "function",
                "function": {"name": "say", "arguments": {"text": speech_text}},
            }
            state.setdefault("tool_calls_pending", []).append(call)
            state.setdefault("events_this_tick", []).append("tool_call_pending")
        # The interjection has been handled; clear it.
        state["recent_interjection"] = None
        return None
|
||||
|
||||
|
||||
@dataclass
class AskVQAFwd(InferenceStep):
    """Answer a stdin question with a frame-grounded VQA response."""

    policy: Any = None
    trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        if self.policy is None or not take_event(state, "user_vqa_query"):
            return None
        question = state.get("recent_vqa_query")
        if not question:
            return None
        answer = _generate_with_policy(
            self.policy, _control_context_messages(state, extra_user=question)
        )
        if answer:
            push_log(state, f"  vqa: {answer}")
        state["recent_vqa_query"] = None
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class DispatchToolCalls(InferenceStep):
    """Execute every pending tool call against the configured ``tools`` map."""

    tools: dict[str, Any] = field(default_factory=dict)
    trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))

    def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
        take_event(state, "tool_call_pending")
        for call in state.get("tool_calls_pending") or []:
            # One failing call must not take down the dispatcher, so
            # errors are logged per call and the loop continues.
            try:
                spec = (call or {}).get("function") or {}
                name = spec.get("name")
                tool = self.tools.get(name)
                if tool is None:
                    push_log(state, f"  [warn] tool {name!r} not registered — skipping call")
                    continue
                tool.call(spec.get("arguments") or {})
            except Exception as exc:  # noqa: BLE001
                push_log(state, f"  [error] tool dispatch failed: {exc}")
        state["tool_calls_pending"] = []
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
include_completed: bool = False,
|
||||
extra_user: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a chat-template-ready prompt from current runtime state.
|
||||
|
||||
Mirrors what ``smolvla2_hirobot.yaml`` renders into ``${task}\nPlan:
|
||||
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
task = state.get("task") or ""
|
||||
parts.append(task)
|
||||
if state.get("current_plan"):
|
||||
parts.append(f"Plan: {state['current_plan']}")
|
||||
if state.get("current_memory"):
|
||||
parts.append(f"Memory: {state['current_memory']}")
|
||||
if include_completed and state.get("current_subtask"):
|
||||
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||
head = "\n".join(parts)
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||
if extra_user:
|
||||
msgs.append({"role": "user", "content": extra_user})
|
||||
return msgs
|
||||
|
||||
|
||||
def _generate_with_policy(policy: Any, messages: list[dict[str, Any]]) -> str:
|
||||
"""Drive ``policy.select_message`` with a minimal text-only batch.
|
||||
|
||||
Best-effort: the runtime today doesn't construct a full
|
||||
observation batch with images / state for text generation; the
|
||||
text-head was trained over images + lang + state, so generations
|
||||
here may differ in distribution from training. This is acceptable
|
||||
for a v1 REPL; a follow-up will plug in the real observation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
# ``select_message`` expects a real batch with OBS_LANGUAGE_TOKENS.
|
||||
# The minimal text-only batch we build doesn't have images / state,
|
||||
# so we either run a text-only forward (handled by SmolVLA2 when
|
||||
# supported) or skip and return empty. v1 returns empty when the
|
||||
# policy can't handle it; the runtime logs and continues.
|
||||
try:
|
||||
# Convert to the OBS_LANGUAGE_TOKENS / OBS_LANGUAGE_ATTENTION_MASK
|
||||
# keys ``select_message`` uses internally.
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
return policy.select_message(batch, tokenizer=text_batch["tokenizer"])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("select_message fell back: %s", exc)
|
||||
return ""
|
||||
|
||||
|
||||
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||
|
||||
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||
deterministic textual marker so prefix-LM-style training learns to
|
||||
emit it. The runtime parses it back here. If no marker is present,
|
||||
the entire text is treated as plan with no speech.
|
||||
"""
|
||||
if not text:
|
||||
return "", ""
|
||||
match = _SAY_RE.search(text)
|
||||
if not match:
|
||||
return text.strip(), ""
|
||||
speech = match.group(1).strip().strip('"').strip("'")
|
||||
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||
return plan, speech
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
whether the step fires. Two trigger flavours cover all the cadences
|
||||
the canonical recipe needs:
|
||||
|
||||
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||
memory update; user interjection → plan refresh; user VQA query →
|
||||
vqa answer; pending tool call → dispatcher)
|
||||
|
||||
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
|
||||
every step shares a single time source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
class Tick:
    """One beat emitted by :class:`TickClock`.

    Carries the time references the runtime steps consume to gate
    themselves.
    """

    # Monotonic counter — increments by one per tick.
    index: int
    # ``time.monotonic()`` sampled at the start of this tick.
    monotonic_seconds: float
|
||||
|
||||
|
||||
@dataclass
class TickClock:
    """Paces the runtime loop at up to ``max_rate_hz``.

    Between successive :meth:`advance` calls the clock sleeps exactly
    long enough to honour the rate. With ``max_rate_hz=50`` the loop
    wakes roughly every 20 ms; the finer cadences are then carved out
    of that timeline by ``HzTrigger`` instances.
    """

    max_rate_hz: float = 50.0
    _index: int = field(default=0, init=False)
    _last_seconds: float | None = field(default=None, init=False)

    def advance(self) -> Tick:
        """Block until the next beat is due, then return its :class:`Tick`."""
        # Guard against a zero/negative rate so the period stays finite.
        min_period = 1.0 / max(self.max_rate_hz, 0.1)
        now = time.monotonic()
        if self._last_seconds is not None:
            remaining = (self._last_seconds + min_period) - now
            if remaining > 0:
                time.sleep(remaining)
                now = time.monotonic()
        self._last_seconds = now
        self._index += 1
        return Tick(index=self._index, monotonic_seconds=now)
|
||||
|
||||
|
||||
class Trigger(Protocol):
    """Structural interface gating an ``InferenceStep``.

    ``should_fire`` is polled once per tick; returning ``True`` means
    the guarded step runs on this tick.
    """

    def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||
|
||||
|
||||
@dataclass
class HzTrigger:
    """Periodic trigger: fires at most ``hz`` times per second."""

    hz: float
    _last_seconds: float | None = field(default=None, init=False)

    def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
        """Fire iff at least one period elapsed since the last fire."""
        # Clamp avoids a division by zero for degenerate rates.
        min_gap = 1.0 / max(self.hz, 1e-6)
        never_fired = self._last_seconds is None
        if never_fired or (tick.monotonic_seconds - self._last_seconds) >= min_gap:
            self._last_seconds = tick.monotonic_seconds
            return True
        return False
|
||||
|
||||
|
||||
@dataclass
class EventTrigger:
    """One-shot trigger keyed on ``state["events_this_tick"]``.

    Once per tick the runtime repopulates ``events_this_tick`` from:

    * stdin / network input (``user_interjection``, ``user_vqa_query``,
      ``stop``)
    * internal state transitions (``subtask_change``,
      ``tool_call_pending``)

    and clears the list at the end of the tick, so each event fires
    at most once.
    """

    event_name: str

    def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
        # ``or []`` also covers an explicit ``None`` entry.
        pending = state.get("events_this_tick") or []
        return self.event_name in pending
|
||||
@@ -13,60 +13,62 @@
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
||||
|
||||
This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
|
||||
Adds:
|
||||
|
||||
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
||||
* a forward path that routes to the flow head, the text head, or both,
|
||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
|
||||
* a forward path that runs the flow head, the text head, or both,
|
||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
|
||||
produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
|
||||
this branch).
|
||||
|
||||
The text-head computation itself is NOT wired up in this scaffold commit
|
||||
(the processor doesn't yet produce ``text_labels`` either). This file is
|
||||
the structural placeholder that:
|
||||
Per-sample routing — within one batch:
|
||||
|
||||
1. registers the ``SmolVLA2Policy`` class with the right config name so
|
||||
``policies/factory.py`` can build it,
|
||||
2. unfreezes ``lm_head`` at construction time when the config asks for it
|
||||
(otherwise SmolVLA's ``train_expert_only`` freezes it again on every
|
||||
``train()`` call),
|
||||
3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
|
||||
SmolVLA when no text labels are present — i.e. existing SmolVLA
|
||||
training scripts keep working.
|
||||
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
|
||||
loss (action chunk supervision).
|
||||
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
|
||||
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
|
||||
contribute to the LM-head cross-entropy.
|
||||
|
||||
The next commit on this branch fills in the actual text-loss path.
|
||||
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
|
||||
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
|
||||
datasets keep working unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
class SmolVLA2Policy(SmolVLAPolicy):
    """SmolVLA + re-enabled SmolVLM language head."""

    config_class = SmolVLA2Config
    name = "smolvla2"

    def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
        """Build the policy; optionally unfreeze the LM head.

        Args:
            config: SmolVLA2 configuration. A plain SmolVLA config is
                accepted too and widened (new fields take defaults).
            dataset_stats: normalisation statistics forwarded to
                ``SmolVLAPolicy``.
        """
        if not isinstance(config, SmolVLA2Config):
            # Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
            # widening the config type — the new fields fall back to their
            # defaults, which preserves the existing SmolVLA behaviour.
            config = SmolVLA2Config(
                **{
                    f.name: getattr(config, f.name)
                    for f in config.__dataclass_fields__.values()
                    if hasattr(config, f.name)
                }
            )
        super().__init__(config, dataset_stats=dataset_stats)
        # SmolVLA's ``train_expert_only`` re-freezes the head on every
        # ``train()`` call, so unfreeze only when text training is on.
        if config.unfreeze_lm_head and config.text_loss_weight > 0:
            self._unfreeze_lm_head()
|
||||
@@ -76,13 +78,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
    """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
    of the text path SmolVLA freezes) so the text-loss can flow back.

    Best-effort: silently returns when the wrapped VLM structure is
    not present (e.g. a stripped checkpoint).
    """
    vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
    if vlm_with_expert is None:
        return
    # NOTE(review): two lines here were elided by the diff hunk
    # boundary; the ``vlm`` retrieval below mirrors the visible guard —
    # confirm the attribute name against SmolVLMWithExpertModel.
    vlm = getattr(vlm_with_expert, "vlm", None)
    if vlm is None:
        return
    for name, param in vlm.named_parameters():
        # Only the head and the final norm are unfrozen; the rest of
        # the text tower keeps SmolVLA's freeze policy.
        if "lm_head" in name or "text_model.model.norm.weight" in name:
            param.requires_grad = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -108,12 +102,286 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
def forward(
    self,
    # NOTE(review): the opening of this signature was elided by the
    # diff hunk; ``self, batch, noise`` reconstructed to match
    # SmolVLAPolicy.forward — confirm against the parent class.
    batch: dict[str, Tensor],
    noise: Tensor | None = None,
    time: Tensor | None = None,
    reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
    """Forward pass with optional dual-head loss.

    Two routing knobs from the batch (produced by
    :class:`SmolVLA2ChatTokenizerStep`):

    * ``text_labels`` — per-token labels with ``-100`` for non-target
      positions. Triggers the text-loss path through ``lm_head``.
    * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
      include this sample's action chunk in the flow loss.

    When neither is present, delegate to ``SmolVLAPolicy.forward`` so
    unannotated datasets keep working unchanged.

    Args:
        batch: observation/action batch after preprocessing.
        noise: optional pre-drawn flow-matching noise.
        time: optional pre-drawn flow-matching timesteps.
        reduction: ``"mean"`` returns a scalar loss; ``"none"``
            broadcasts the scalar to ``(B,)`` for caller compat.

    Returns:
        ``(loss, loss_dict)`` with per-head diagnostics in
        ``loss_dict``.
    """
    text_labels = batch.get("text_labels")
    predict_actions_t = batch.get("predict_actions")

    has_text_data = (
        text_labels is not None
        and isinstance(text_labels, Tensor)
        and self.config.text_loss_weight > 0
    )
    has_per_sample_routing = predict_actions_t is not None and isinstance(predict_actions_t, Tensor)

    if not has_text_data and not has_per_sample_routing:
        return super().forward(batch, noise=noise, time=time, reduction=reduction)

    loss_dict: dict[str, Any] = {}
    device = batch[OBS_STATE].device
    total = torch.zeros((), device=device, dtype=torch.float32)

    # ------------------------------------------------------------
    # Flow loss path — only when at least one sample wants actions.
    # ------------------------------------------------------------
    run_flow = self.config.flow_loss_weight > 0 and (
        not has_per_sample_routing or bool(predict_actions_t.any().item())
    )
    if run_flow and ACTION in batch:
        per_sample_flow, flow_diag = super().forward(
            batch, noise=noise, time=time, reduction="none"
        )
        # ``per_sample_flow`` has shape (B,) from the SmolVLA
        # reduction="none" branch.
        if has_per_sample_routing:
            mask = predict_actions_t.to(per_sample_flow.dtype)
            masked = per_sample_flow * mask
            # clamp keeps the mean well-defined when no sample is
            # selected (should not happen given ``run_flow``).
            denom = mask.sum().clamp(min=1.0)
            flow_loss = masked.sum() / denom
        else:
            flow_loss = per_sample_flow.mean()
        total = total + self.config.flow_loss_weight * flow_loss
        loss_dict["flow_loss"] = float(flow_loss.detach().item())
        for k, v in flow_diag.items():
            loss_dict[f"flow_{k}"] = v

    # ------------------------------------------------------------
    # Text loss path — prefix-only forward → lm_head → CE.
    # ------------------------------------------------------------
    if has_text_data:
        text_loss = self._compute_text_loss(batch, text_labels)
        total = total + self.config.text_loss_weight * text_loss
        loss_dict["text_loss"] = float(text_loss.detach().item())

    loss_dict["loss"] = float(total.detach().item())

    if reduction == "none":
        # Per-sample loss isn't meaningfully defined for the dual
        # path; broadcast the scalar to (B,) for caller compat.
        return total.expand(batch[OBS_STATE].shape[0]), loss_dict
    return total, loss_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text-loss internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
    """Cross-entropy on the SmolVLM ``lm_head`` over target tokens.

    Runs a prefix-only forward (images + language + state), projects
    the hidden states at the language positions through ``lm_head``,
    and scores them against ``text_labels`` with the usual shifted
    next-token objective. Positions labelled ``-100`` are ignored.
    """
    if self.config.adapt_to_pi_aloha:
        batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

    images, img_masks = self.prepare_images(batch)
    state = self.prepare_state(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

    prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
        images, img_masks, lang_tokens, lang_masks, state=state
    )
    attn_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

    # Prefix-only forward — no expert stream, no KV cache.
    outputs, _ = self.model.vlm_with_expert.forward(
        attention_mask=attn_2d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, None],
        use_cache=False,
        fill_kv_cache=False,
    )
    prefix_out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    if prefix_out is None:
        raise RuntimeError(
            "SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
            "states — text-loss path needs them."
        )

    # ``embed_prefix`` lays the prefix out as ``[image_blocks...,
    # lang, state]``, so the lang span is recovered from the trailing
    # state width and the known lang length.
    # NOTE(review): assumes the state occupies ``state.shape[1]``
    # prefix positions — confirm against ``embed_prefix``.
    num_lang = lang_tokens.shape[1]
    state_2d = state if state.ndim >= 2 else state[:, None]
    num_state = state_2d.shape[1] if state_2d.ndim >= 2 else 1
    num_state = max(num_state, 1)
    prefix_len = prefix_out.shape[1]
    lang_end = prefix_len - num_state
    lang_start = lang_end - num_lang
    if lang_start < 0 or lang_end > prefix_len:
        raise RuntimeError(
            f"SmolVLA2: could not locate lang token range in prefix "
            f"(prefix_len={prefix_len}, num_lang={num_lang}, "
            f"num_state={num_state})."
        )

    lang_hidden = prefix_out[:, lang_start:lang_end]
    logits = self.model.vlm_with_expert.vlm.lm_head(lang_hidden)  # (B, num_lang, vocab)

    # Tolerate label/logit length drift by truncating to the overlap.
    if text_labels.shape[1] != num_lang:
        common = min(text_labels.shape[1], num_lang)
        logits = logits[:, :common]
        text_labels = text_labels[:, :common]

    # Shifted next-token CE: hidden state at t predicts token t+1
    # (same convention as HuggingFace ``LlamaForCausalLM``). Without
    # the shift the objective degenerates to an identity mapping.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = text_labels[:, 1:].contiguous().long()
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.shape[-1]),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference: text generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
def select_message(
    self,
    batch: dict[str, Tensor],
    *,
    max_new_tokens: int = 256,
    eos_token_id: int | None = None,
    temperature: float = 0.0,
    top_p: float = 1.0,
    tokenizer: Any = None,
) -> str:
    """Generate a text continuation of the chat-templated prompt.

    Autoregressive decoding with KV caching, reusing SmolVLA's
    inference path. Assumes batch size 1 (the runtime invokes this
    per-event) and returns only the newly generated tokens, decoded;
    the prompt is not echoed back.

    Args:
        batch: output of the SmolVLA2 preprocessor — must contain
            ``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``
            and ``OBS_LANGUAGE_ATTENTION_MASK``.
        max_new_tokens: hard cap on generated tokens; decoding stops
            earlier on EOS.
        eos_token_id: override for the tokenizer's EOS id; ``None``
            ⇒ use the tokenizer's default.
        temperature: ``0`` ⇒ greedy argmax (default — closest to the
            training distribution); ``> 0`` enables sampling.
        top_p: optional nucleus threshold used when sampling.
        tokenizer: pre-loaded tokenizer, to skip the cold-start
            ``AutoTokenizer.from_pretrained`` round-trip per call.
    """
    self.eval()

    if tokenizer is None:
        from transformers import AutoTokenizer  # noqa: PLC0415

        tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    images, img_masks = self.prepare_images(batch)
    state = self.prepare_state(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

    # Step 1 — embed the full prefix (images + lang + state) and run
    # it once, priming the KV cache.
    prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
        images, img_masks, lang_tokens, lang_masks, state=state
    )
    outputs, kv_cache = self.model.vlm_with_expert.forward(
        attention_mask=make_att_2d_masks(prefix_pad_masks, prefix_att_masks),
        position_ids=torch.cumsum(prefix_pad_masks, dim=1) - 1,
        past_key_values=None,
        inputs_embeds=[prefix_embs, None],
        use_cache=True,
        fill_kv_cache=True,
    )
    hidden = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    if hidden is None:
        raise RuntimeError("select_message: prefix forward returned no hidden states.")

    vlm = self.model.vlm_with_expert.vlm

    # Step 2 — decode token by token; the first new token is sampled
    # from the last prefix position.
    step_hidden = hidden[:, -1:]
    device = step_hidden.device
    batch_size = prefix_embs.shape[0]
    next_position = int(prefix_embs.shape[1])

    new_token_ids: list[int] = []
    for _ in range(max_new_tokens):
        step_logits = vlm.lm_head(step_hidden)[:, -1]  # (B, V)
        next_ids = self._sample_next_token(step_logits, temperature, top_p)
        token_id = int(next_ids[0].item())
        new_token_ids.append(token_id)
        if eos_token_id is not None and token_id == eos_token_id:
            break

        # Step 3 — embed the fresh token and advance using the cache.
        # NOTE(review): ``fill_kv_cache=True`` on decode steps assumes
        # the underlying forward appends to (rather than rebuilds) the
        # cache — confirm against SmolVLMWithExpertModel.
        token_emb = self.model.vlm_with_expert.embed_language_tokens(next_ids.unsqueeze(0))
        token_emb = token_emb * math.sqrt(token_emb.shape[-1])

        position = torch.full((batch_size, 1), next_position, device=device, dtype=torch.long)
        attn = torch.ones((batch_size, next_position + 1), device=device, dtype=torch.bool)

        outputs, kv_cache = self.model.vlm_with_expert.forward(
            attention_mask=attn,
            position_ids=position,
            past_key_values=kv_cache,
            inputs_embeds=[token_emb, None],
            use_cache=True,
            fill_kv_cache=True,
        )
        step_out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        step_hidden = step_out[:, -1:]
        next_position += 1

    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
|
||||
|
||||
@staticmethod
def _sample_next_token(
    logits: Tensor, temperature: float, top_p: float
) -> Tensor:
    """Pick one token id per batch row from ``logits``.

    Args:
        logits: ``(B, vocab)`` unnormalised scores.
        temperature: ``<= 0`` ⇒ greedy argmax; ``> 0`` ⇒ sample from
            the temperature-scaled softmax.
        top_p: nucleus threshold; ``< 1`` restricts sampling to the
            smallest set of tokens whose cumulative probability
            reaches ``top_p``.

    Returns:
        ``(B,)`` tensor of sampled token ids.
    """
    if temperature <= 0.0:
        return logits.argmax(dim=-1)
    scaled = logits / max(temperature, 1e-6)
    probs = F.softmax(scaled, dim=-1)
    if top_p < 1.0:
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        # Standard nucleus sampling keeps the smallest prefix whose
        # mass reaches ``top_p``: a token is dropped only when the
        # mass *before* it already exceeds the threshold. Masking on
        # ``cum > top_p`` alone would also drop the boundary token
        # and leave less than ``top_p`` probability mass.
        mask = (cum - sorted_probs) > top_p
        # Always keep the most-likely token.
        mask[..., 0] = False
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
        pick = torch.multinomial(sorted_probs, num_samples=1)
        return sorted_idx.gather(-1, pick).squeeze(-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
|
||||
|
||||
Drives the multi-rate runtime defined in
|
||||
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
|
||||
channel: type a task, then natural-language interjections / questions.
|
||||
The runtime prints state changes (plan / subtask / memory / vqa /
|
||||
speech) as they happen.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Dry run on a checkpoint, no robot connected — useful for sanity-
|
||||
checking text generation::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
|
||||
--no_robot \\
|
||||
--task="please clean the kitchen"
|
||||
|
||||
With a real robot::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=... \\
|
||||
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
||||
--tts.voice=alba
|
||||
|
||||
Tool dispatch (TTS via ``SayTool``) is enabled by default when
|
||||
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||
|
||||
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="lerobot-smolvla2-runtime",
|
||||
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--policy.path",
|
||||
dest="policy_path",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--task",
|
||||
dest="task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Initial task. If omitted, the first stdin line is treated as the task.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no_robot",
|
||||
action="store_true",
|
||||
help="Skip robot connection — language-only / dry-run mode.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no_tts",
|
||||
action="store_true",
|
||||
help="Disable the ``say`` tool dispatch.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--tts.voice",
|
||||
dest="tts_voice",
|
||||
type=str,
|
||||
default="alba",
|
||||
help="Pocket-tts voice name (or path to a .wav for cloning).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
|
||||
)
|
||||
p.add_argument(
|
||||
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
|
||||
)
|
||||
p.add_argument(
|
||||
"--high_level_hz",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="High-level subtask generation rate.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_ticks",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Stop after N ticks (debug / smoke-test).",
|
||||
)
|
||||
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
|
||||
return p.parse_args(argv)
|
||||
|
||||
|
||||
def _load_policy(path: Path):  # noqa: ANN202
    """Load a SmolVLA2 checkpoint from ``path`` and switch it to eval mode."""
    from lerobot.policies.factory import make_policy_from_path  # noqa: PLC0415

    loaded = make_policy_from_path(str(path))
    loaded.eval()
    return loaded
|
||||
|
||||
|
||||
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
"""Instantiate the tools declared on this dataset/policy."""
|
||||
if no_tts:
|
||||
return {}
|
||||
try:
|
||||
from lerobot.tools import SayTool # noqa: PLC0415
|
||||
|
||||
return {"say": SayTool(voice=tts_voice)}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
|
||||
return {}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Entry point for ``lerobot-smolvla2-runtime``.

    Loads the checkpoint, wires tools and the stdin event channel into
    an ``SmolVLA2Runtime``, then blocks in the runtime loop until
    'stop', ``max_ticks``, or Ctrl-C.
    """
    args = _parse_args(argv)
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(message)s")

    if not args.policy_path.exists():
        print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
        return 1

    print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
    policy = _load_policy(args.policy_path)

    tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
    if tools:
        print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)

    # Robot wiring is left as a follow-up — for v1 we run language-only
    # / dry-run so REPL development doesn't require a connected robot.
    observation_provider = None
    robot_executor = None
    if not args.no_robot:
        print(
            "[smolvla2] WARNING: real-robot integration is a follow-up. "
            "Running in dry-run mode for now (no actions executed).",
            flush=True,
        )

    from lerobot.policies.smolvla2.inference import (  # noqa: PLC0415
        SmolVLA2Runtime,
        StdinReader,
    )

    runtime = SmolVLA2Runtime(
        policy=policy,
        tools=tools,
        observation_provider=observation_provider,
        robot_executor=robot_executor,
        event_collector=StdinReader().poll,
        chunk_hz=args.chunk_hz,
        ctrl_hz=args.ctrl_hz,
        high_level_hz=args.high_level_hz,
    )
    if args.task:
        runtime.set_task(args.task)

    print(
        "[smolvla2] runtime ready. Type a task to begin, then any line for "
        "interjections, questions ending in '?' for VQA, or 'stop' to exit.",
        flush=True,
    )
    try:
        runtime.run(max_ticks=args.max_ticks)
    except KeyboardInterrupt:
        # Graceful shutdown on Ctrl-C; still a successful exit.
        runtime.stop()
        print("\n[smolvla2] interrupted by user", flush=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user