Compare commits

...

3 Commits

Author SHA1 Message Date
Pepijn 37b1eb218a feat(smolvla2): chat-template processor + label mask + predict_actions
Wires PR 1's recipe stack into the SmolVLA2 pipeline so multi-target
sub-recipes (memory_update, ask_vqa, user_interjection_response,
high_level_subtask) carry meaningful supervision through to the model.

- New ``chat_processor_smolvla2.py`` with
  ``SmolVLA2ChatTokenizerStep``: reads ``messages`` /
  ``message_streams`` / ``target_message_indices`` from the rendered
  sample (PR 1 ``RenderMessagesStep``), calls
  ``apply_chat_template(messages, tools=DEFAULT_TOOLS, ...)`` on the
  SmolVLM tokenizer, and writes:

    OBS_LANGUAGE_TOKENS / _ATTENTION_MASK   ← chat-templated prompt
    text_labels                              ← -100 except target msg tokens
    predict_actions                          ← True iff any low_level target

  Builds the label mask robustly by re-rendering the chat through
  each target's prefix and reading off the prefix length — same
  tokenizer, same tools, so the prefix tokens are guaranteed to be
  a prefix of the full sequence. Image/video content blocks
  (LeRobot ``feature``-keyed) are stripped before tokenizing; the
  actual image tensors flow through SmolVLA's existing
  ``OBS_IMAGES_*`` channels and ``embed_prefix`` puts them before
  the language embeddings, matching the chat-template-stripped
  text order.

- ``processor_smolvla2.py``: when ``config.recipe_path`` is set,
  build a new pipeline with ``RenderMessagesStep`` +
  ``SmolVLA2ChatTokenizerStep`` instead of SmolVLA's plain
  ``TokenizerProcessorStep``. When ``recipe_path`` is ``None``,
  fall back to SmolVLA's pipeline so unannotated datasets still
  work unchanged. Resolves recipe paths relative to
  ``src/lerobot/configs/`` so ``recipes/smolvla2_hirobot.yaml``
  works directly.
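
The prefix-length trick for the label mask can be illustrated without a real tokenizer. The sketch below is a toy under stated assumptions: `render` and `tokenize` are hypothetical stand-ins for `apply_chat_template`, and the explicit prefix check is an addition that the commit does not describe. It shows how rendering `messages[:i]` and `messages[:i + 1]` brackets the tokens of target message `i`.

```python
IGNORE_INDEX = -100  # label value that cross-entropy skips, as in the commit message

_VOCAB = {}  # hypothetical growing vocabulary: token string -> integer id


def render(messages):
    """Hypothetical stand-in for a chat template: one tagged line per message."""
    return "".join(f"<{m['role']}> {m['content']} <end>\n" for m in messages)


def tokenize(text):
    """Hypothetical tokenizer: whitespace split, ids assigned on first sight."""
    return [_VOCAB.setdefault(tok, len(_VOCAB)) for tok in text.split()]


def build_labels(messages, target_indices):
    """Return (token ids, labels) with labels masked outside target messages."""
    full = tokenize(render(messages))
    labels = [IGNORE_INDEX] * len(full)
    for tgt in sorted(target_indices):
        prefix = tokenize(render(messages[:tgt]))
        through = tokenize(render(messages[: tgt + 1]))
        # The trick is only valid if both partial renders are true prefixes.
        if full[: len(through)] != through or through[: len(prefix)] != prefix:
            raise ValueError(f"template is not prefix-stable at message {tgt}")
        for pos in range(len(prefix), len(through)):
            labels[pos] = full[pos]
    return full, labels


messages = [
    {"role": "user", "content": "pick up the cube"},
    {"role": "assistant", "content": "grasp cube"},
]
full, labels = build_labels(messages, target_indices=[1])
```

In this toy the supervised span starts right after the prefix, so it includes the role tag of the target turn as well as its content.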

The next commit on this branch picks up ``text_labels`` and
``predict_actions`` from the batch and routes them through the
SmolVLM ``lm_head`` for the actual dual-loss training.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 19:21:03 +02:00
Pepijn 52e1fd35cb feat(tools): src/lerobot/tools/ — runnable tool registry + SayTool
Ships the runtime side of the OpenAI-style function-calling stack
introduced in PR 1 (catalog in ``meta/info.json["tools"]``) and PR 2
(annotation pipeline writes the catalog after a run). One file per
tool — heavy deps stay isolated.

Layout:

- ``base.py`` — :class:`Tool` Protocol: ``name``, ``schema``,
  ``call(arguments)``. Runtime-checkable so tests can use
  ``isinstance(...)``.
- ``registry.py`` — :data:`TOOL_REGISTRY` (name → class) plus
  ``get_tools(meta, **kwargs)`` that instantiates every entry whose
  ``function.name`` is registered. Tools whose name is unknown are
  silently skipped — the schema still rides through the chat
  template, the model just can't actually invoke that tool at
  inference.
- ``say.py`` — :class:`SayTool` wrapping Kyutai's pocket-tts
  (CPU-only, ~100M params, ~6× real-time on a MacBook Air M4).
  Lazy model load: pocket-tts is imported and the voice state
  computed on first ``call(...)`` (or eagerly via ``preload()``).
  Returns the PCM tensor; optionally writes a ``.wav`` to
  ``output_dir`` for offline inspection.
- ``__init__.py`` — re-exports the public surface.
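
The lazy-load behaviour described for ``SayTool`` can be sketched generically. This is not the actual ``SayTool`` code: `LazyTool` and its `loader` argument are hypothetical names, and the "model" here is any callable. It only shows the pattern of deferring an expensive load until the first ``call(...)``, with ``preload()`` as the eager alternative.

```python
class LazyTool:
    """Hypothetical tool that defers an expensive model load until first use."""

    name = "lazy_demo"

    def __init__(self, loader):
        self._loader = loader  # callable that builds the heavy model
        self._model = None     # nothing is loaded at construction time

    def preload(self):
        """Load the model now; safe to call more than once."""
        if self._model is None:
            self._model = self._loader()
        return self

    def call(self, arguments):
        """Run the tool, loading the model on the first call only."""
        self.preload()
        return self._model(arguments["text"])


load_events = []


def fake_loader():
    """Stand-in for an expensive import and model load."""
    load_events.append("loaded")
    return str.upper


tool = LazyTool(fake_loader)
```

Constructing the tool stays cheap, which matters when heavy dependencies are optional and may not be installed.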

Optional install:

    pip install lerobot[tools]

The ``[tools]`` extra in ``pyproject.toml`` pulls in ``pocket-tts`` +
``scipy`` (for the wav writer). Adding more tools later means a new
file + a registry entry — no new extras unless the tool brings new
deps.

To add your own tool, follow the three-step guide in
``docs/source/tools.mdx`` (PR 1):

  1. Drop ``src/lerobot/tools/<my_tool>.py`` with a ``Tool``-conforming
     class.
  2. Register the class in ``TOOL_REGISTRY`` (in ``registry.py``).
  3. Pre-populate ``meta/info.json["tools"]`` with the schema (or let
     ``lerobot-annotate`` add it on the next run).
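
The three steps can be sketched end to end. This is a minimal illustration under assumptions, not the code in ``src/lerobot/tools/``: `EchoTool` is a hypothetical tool, this `get_tools` takes the tool catalog as a plain list rather than a dataset ``meta`` object, and the dict return type is a guess.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Tool(Protocol):
    """Assumed protocol shape: a name, a schema, and call(arguments)."""

    name: str
    schema: dict[str, Any]

    def call(self, arguments: dict[str, Any]) -> Any: ...


class EchoTool:
    """Hypothetical example tool; not part of LeRobot."""

    name = "echo"
    schema = {
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Return the given text unchanged.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        },
    }

    def call(self, arguments: dict[str, Any]) -> Any:
        return arguments["text"]


# Step 2: name -> class mapping.
TOOL_REGISTRY: dict[str, type] = {"echo": EchoTool}


def get_tools(tool_schemas: list[dict[str, Any]], **kwargs: Any) -> dict[str, Tool]:
    """Instantiate every catalog entry whose function.name is registered."""
    tools: dict[str, Tool] = {}
    for entry in tool_schemas:
        name = entry.get("function", {}).get("name")
        cls = TOOL_REGISTRY.get(name)
        if cls is None:
            continue  # unknown tools are skipped silently
        tools[name] = cls(**kwargs)
    return tools


# Step 3: a catalog as it might appear under meta/info.json["tools"].
catalog = [EchoTool.schema, {"type": "function", "function": {"name": "say"}}]
tools = get_tools(catalog)
```

Here ``say`` is skipped because nothing is registered under that name in this sketch, which mirrors the silent-skip behaviour described for the registry.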

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:58:04 +02:00
Pepijn 7459dfccb6 feat(policies): scaffold smolvla2 (smolvla + lm_head re-enabled)
PR 3 of the steerable-annotation plan is retargeted from Pi0.5 to SmolVLA
because the recipe stack (PR 1 + PR 2) outputs HF/TRL-compatible chat
messages, which a chat-pretrained backbone consumes natively. SmolVLA strips the
SmolVLM ``lm_head`` though, so it can only do flow-matching action
prediction. SmolVLA2 keeps the LM head so the same model can train on
the full Hi Robot / MEM / ECoT blend defined in the plan:

  * action-only sub-recipes  (low_level_execution)        flow loss
  * text-only sub-recipes    (memory_update / ask_vqa /   CE loss on
                              user_interjection_response)  lm_head
  * mixed sub-recipes                                      both summed
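
The table above can be read as a small routing rule. The sketch below is an assumption about how the two losses might be combined; the real text-loss path lands in a later commit that is not shown here. `combine_losses` and `has_text_targets` are hypothetical names, while `flow_loss_weight` and `text_loss_weight` mirror the config fields.

```python
def combine_losses(
    flow_loss,
    text_loss,
    predict_actions,
    has_text_targets,
    flow_loss_weight=1.0,
    text_loss_weight=1.0,
):
    """Weighted sum of whichever losses are active for this sample."""
    total = 0.0
    if predict_actions:
        # Action-only or mixed sub-recipe: the flow head contributes.
        total += flow_loss_weight * flow_loss
    if has_text_targets and text_loss_weight > 0:
        # Text-only or mixed sub-recipe: cross-entropy on lm_head contributes.
        total += text_loss_weight * text_loss
    return total
```

Setting `text_loss_weight=0` in this sketch reduces every sample to the flow-only case.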

This first commit lays down the structural scaffold:

- ``src/lerobot/policies/smolvla2/`` — new package with thin subclasses
  of ``SmolVLAConfig`` / ``SmolVLAPolicy`` so we don't fork the 900-line
  modeling code. ``SmolVLA2Config`` adds ``recipe_path``,
  ``apply_chat_template``, ``text_loss_weight``, ``flow_loss_weight``,
  and ``unfreeze_lm_head``. ``SmolVLA2Policy`` unfreezes the SmolVLM
  ``lm_head`` (and the surrounding norm + last text-model layer SmolVLA
  freezes) when ``unfreeze_lm_head=True`` and ``text_loss_weight>0``.
- ``factory.py`` registers ``smolvla2`` in ``get_policy_class``,
  ``make_policy_config``, and the pre/post-processor builder. Important:
  the ``smolvla2`` branch lives BEFORE the ``isinstance(config,
  SmolVLAConfig)`` check because ``SmolVLA2Config`` subclasses
  ``SmolVLAConfig`` — without the ordering, SmolVLA2 would silently
  pick up SmolVLA's processor.
- ``configs/recipes/smolvla2_hirobot.yaml`` — canonical Hi Robot blend
  for SmolVLA2. Same shape as ``pi05_hirobot.yaml`` (PR 1) so the
  recipe stack stays uniform across policy backbones.
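
The ordering constraint in ``factory.py`` follows from how ``isinstance`` treats subclasses. The sketch below uses hypothetical stand-in classes (`BaseConfig`, `ChildConfig`) and string return values rather than the real configs and processors.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    """Stands in for SmolVLAConfig."""

    type: str = "smolvla"


@dataclass
class ChildConfig(BaseConfig):
    """Stands in for SmolVLA2Config, which subclasses the base config."""

    type: str = "smolvla2"


def pick_processor_wrong_order(cfg):
    """isinstance check first: the subclass matches it and never reaches its own branch."""
    if isinstance(cfg, BaseConfig):
        return "smolvla_processor"
    elif cfg.type == "smolvla2":
        return "smolvla2_processor"
    raise ValueError(cfg.type)


def pick_processor(cfg):
    """Specific type check first, generic isinstance check second."""
    if cfg.type == "smolvla2":
        return "smolvla2_processor"
    elif isinstance(cfg, BaseConfig):
        return "smolvla_processor"
    raise ValueError(cfg.type)
```

With the wrong order nothing fails loudly; the subclass simply receives the parent's processor, which is why the commit calls the failure silent.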

Behaviour today is identical to SmolVLA: the modeling forward
delegates to ``SmolVLAPolicy.forward`` and the processor delegates to
``make_smolvla_pre_post_processors``. The next commit on this branch
adds the chat-template processor + ``text_labels`` / ``predict_actions``
batch keys; the commit after that wires the actual text-loss path
through ``vlm.lm_head``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:55:23 +02:00
12 changed files with 1098 additions and 0 deletions
@@ -209,6 +209,14 @@ annotations = [
    "vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
]

# Tool implementations under src/lerobot/tools/. Each tool's dependencies
# are isolated so adding a new tool doesn't bloat the base install.
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
tools = [
    "pocket-tts>=0.1.0,<1.0.0",
    "scipy>=1.11.0,<2.0.0",  # SayTool.output_dir uses scipy.io.wavfile
]

# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -0,0 +1,88 @@
# SmolVLA2 canonical training recipe: Hi Robot / MEM / ECoT blend.
#
# Same blend shape as pi05_hirobot.yaml. SmolVLA2 differs from Pi0.5 in
# how the renderer's output is consumed:
#
#   - SmolVLA2 calls SmolVLM's tokenizer.apply_chat_template(messages,
#     tools=DEFAULT_TOOLS) on the rendered messages, since SmolVLM is a
#     chat-pretrained backbone.
#   - The processor builds a `text_labels` tensor that masks every token
#     except those belonging to messages whose index is in
#     `target_message_indices`. Cross-entropy on those positions trains
#     the LM head.
#   - `predict_actions = bool(targets_by_stream.get("low_level"))`,
#     same convention as Pi0.5. ``low_level_execution`` is the only
#     branch that runs the action expert / flow head.

blend:

  memory_update:
    weight: 0.10
    bindings:
      prior_memory: "nth_prev(style=memory, offset=1)"
      current_memory: "emitted_at(t, style=memory)"
      completed_subtask: "nth_prev(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
      - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
      - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}

  user_interjection_response:
    weight: 0.16
    bindings:
      prior_plan: "nth_prev(style=plan, offset=1)"
      current_plan: "emitted_at(t, style=plan)"
      interjection: "emitted_at(t, style=interjection)"
      speech: "emitted_at(t, role=assistant, tool_name=say)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
      - {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
      - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}

  high_level_subtask:
    weight: 0.15
    bindings:
      next_subtask: "nth_next(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
      - {role: assistant, content: "${next_subtask}", stream: high_level, target: true}

  low_level_execution:
    weight: 0.35
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: assistant, content: "${subtask}", stream: low_level, target: true}

  # Per-camera VQA sub-recipes (PR 1's view-dependent style routing).
  # Adjust the camera keys (and add more sub-recipes) to match the
  # cameras present on your dataset.
  ask_vqa_top:
    weight: 0.10
    bindings:
      vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
      vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
    messages:
      - role: user
        stream: high_level
        if_present: vqa_query
        content:
          - {type: image, feature: observation.images.top}
          - {type: text, text: "${vqa_query}"}
      - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

  ask_vqa_wrist:
    weight: 0.10
    bindings:
      vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
      vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
    messages:
      - role: user
        stream: high_level
        if_present: vqa_query
        content:
          - {type: image, feature: observation.images.wrist}
          - {type: text, text: "${vqa_query}"}
      - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
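
The six blend weights in this recipe add up to 0.96 rather than 1.0. Whether the recipe loader normalizes them is not shown in this diff, so the sketch below is only an assumption about how weighted sub-recipe sampling could work; `normalized` and `sample_sub_recipe` are hypothetical helpers.

```python
import random

# Weights copied from the recipe above.
BLEND_WEIGHTS = {
    "memory_update": 0.10,
    "user_interjection_response": 0.16,
    "high_level_subtask": 0.15,
    "low_level_execution": 0.35,
    "ask_vqa_top": 0.10,
    "ask_vqa_wrist": 0.10,
}


def normalized(weights):
    """Rescale relative weights so they sum to 1."""
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}


def sample_sub_recipe(weights, rng):
    """Draw one sub-recipe name; random.choices accepts relative weights."""
    names = list(weights)
    return rng.choices(names, weights=[weights[n] for n in names], k=1)[0]


probabilities = normalized(BLEND_WEIGHTS)
picked = sample_sub_recipe(BLEND_WEIGHTS, random.Random(0))
```

Because the sampler works on relative weights, a sum other than 1.0 changes nothing in this sketch; it would matter only if the loader validated the total.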
@@ -140,6 +140,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
        from .smolvla.modeling_smolvla import SmolVLAPolicy

        return SmolVLAPolicy
    elif name == "smolvla2":
        from .smolvla2.modeling_smolvla2 import SmolVLA2Policy

        return SmolVLA2Policy
    elif name == "sarm":
        from .sarm.modeling_sarm import SARMRewardModel

@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return SACConfig(**kwargs)
    elif policy_type == "smolvla":
        return SmolVLAConfig(**kwargs)
    elif policy_type == "smolvla2":
        from .smolvla2.configuration_smolvla2 import SmolVLA2Config

        return SmolVLA2Config(**kwargs)
    elif policy_type == "reward_classifier":
        return RewardClassifierConfig(**kwargs)
    elif policy_type == "groot":
@@ -386,6 +394,17 @@ def make_pre_post_processors(
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif policy_cfg.type == "smolvla2":
        # NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
        # MUST come before the SmolVLAConfig isinstance check below
        # (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
        from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors

        processors = make_smolvla2_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(policy_cfg, SmolVLAConfig):
        from .smolvla.processor_smolvla import make_smolvla_pre_post_processors

@@ -0,0 +1,38 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2: SmolVLA with the SmolVLM language head re-enabled.

SmolVLA strips the LM head from the SmolVLM backbone because it only does
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
model can train on the full Hi Robot / MEM / ECoT message blend defined in
the steerable annotation plan (PR1 + PR2):

* action-only sub-recipes (e.g. ``low_level_execution``): flow loss
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
  ``user_interjection_response``, ``high_level_subtask``): CE loss on
  ``lm_head`` over the recipe's target message tokens
* mixed sub-recipes: both losses summed (weighted)

The ``predict_actions`` toggle follows the Pi0.5 convention from Section
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
sample, else ``False``.

This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
the model code stays in one place; only the dual-loss path and the
chat-template processor live here.
"""

from .configuration_smolvla2 import SmolVLA2Config

__all__ = ["SmolVLA2Config"]
@@ -0,0 +1,271 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2's chat-template tokenization step.

Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
``recipe_path`` is set. Reads the rendered messages produced by
``RenderMessagesStep`` (PR 1) and produces:

* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK``:
  the chat-templated prompt tokenized by SmolVLM's tokenizer, with
  ``tools=`` set to the step's tool catalog (``DEFAULT_TOOLS`` unless
  overridden with ``meta.tools``, PR 1's catalog, via ``with_tools``).
* ``text_labels``: same shape as the token ids, ``-100`` everywhere except
  the positions belonging to messages whose index is in
  ``target_message_indices``. The next commit's modeling forward path
  applies cross-entropy on those positions via the SmolVLM ``lm_head``.
* ``predict_actions``: bool tensor, ``True`` iff any of the rendered
  target messages has ``message_streams[i] == "low_level"``. The
  modeling forward uses this to gate the flow head.

Image / video content blocks in the rendered messages are dropped
before tokenization: the chat template only handles text, and SmolVLA
already passes camera tensors out-of-band via the standard
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
(``embed_prefix`` puts image embeddings before language embeddings,
matching the chat-template-stripped text order).
"""

from __future__ import annotations

import copy
import logging
from dataclasses import dataclass
from typing import Any

import torch

from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.datasets.language import DEFAULT_TOOLS
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

logger = logging.getLogger(__name__)


@dataclass
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
class SmolVLA2ChatTokenizerStep(ProcessorStep):
    """Render messages → token ids + label mask + predict_actions flag.

    This is the bridge between the recipe stack (PR 1's
    ``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward
    (next commit, which reads ``text_labels`` / ``predict_actions``).
    Pure-text turns and multi-stream targets are both handled.
    """

    tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
    max_length: int = 2048
    padding: str = "longest"
    padding_side: str = "right"
    tools: list[dict[str, Any]] | None = None

    def __post_init__(self) -> None:
        # Lazy: don't load the tokenizer until the step actually runs,
        # so unit tests that import the module without transformers
        # installed still pass.
        self._tokenizer: Any = None
        if self.tools is None:
            # Default: ship the canonical ``say`` schema. Users who set
            # ``meta.tools`` differently can override via
            # ``with_tools(meta.tools)``.
            self.tools = list(DEFAULT_TOOLS)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
        """Override the tools catalog rendered into the system prompt."""
        self.tools = list(tools)
        return self

    def __call__(self, transition: EnvTransition) -> EnvTransition | None:
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        messages = comp.get("messages")
        if not messages:
            # No recipe rendering happened, so there is nothing to do;
            # downstream falls back to whatever ``task`` is in the transition.
            return transition

        message_streams: list[str | None] = list(comp.get("message_streams") or [])
        target_indices: list[int] = sorted(
            int(i) for i in (comp.get("target_message_indices") or [])
        )

        tokenizer = self._get_tokenizer()
        text_messages = [_strip_lerobot_blocks(m) for m in messages]

        # Tokenize the full chat once.
        full_ids = tokenizer.apply_chat_template(
            text_messages,
            tools=self.tools,
            add_generation_prompt=False,
            tokenize=True,
            return_tensors=None,
        )
        if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list):
            full_ids = full_ids[0]

        # Build the label mask by re-rendering progressively up to each
        # target message and reading off the prefix length. This is the
        # robust way to get exact token boundaries: we use the same
        # tokenizer, the same ``tools=`` argument, and the same chat
        # template, so the prefix tokens are guaranteed to be a prefix
        # of the full sequence.
        labels = [-100] * len(full_ids)
        for tgt in target_indices:
            prefix_ids = tokenizer.apply_chat_template(
                text_messages[:tgt],
                tools=self.tools,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors=None,
            )
            full_through_target = tokenizer.apply_chat_template(
                text_messages[: tgt + 1],
                tools=self.tools,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors=None,
            )
            if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list):
                prefix_ids = prefix_ids[0]
            if (
                isinstance(full_through_target, list)
                and full_through_target
                and isinstance(full_through_target[0], list)
            ):
                full_through_target = full_through_target[0]
            start = len(prefix_ids)
            end = min(len(full_through_target), len(full_ids))
            for pos in range(start, end):
                labels[pos] = int(full_ids[pos])

        # Truncate / pad to ``max_length`` so batches collate cleanly.
        # The SmolVLA pipeline downstream relies on a fixed-length
        # behaviour ("longest" or "max_length"); we mirror it here.
        if len(full_ids) > self.max_length:
            full_ids = full_ids[: self.max_length]
            labels = labels[: self.max_length]
        attn = [1] * len(full_ids)
        if self.padding == "max_length" and len(full_ids) < self.max_length:
            pad_id = (
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else 0
            )
            n_pad = self.max_length - len(full_ids)
            full_ids = full_ids + [pad_id] * n_pad
            labels = labels + [-100] * n_pad
            attn = attn + [0] * n_pad

        ids_t = torch.tensor(full_ids, dtype=torch.long)
        attn_t = torch.tensor(attn, dtype=torch.bool)
        labels_t = torch.tensor(labels, dtype=torch.long)

        predict_actions = any(
            i < len(message_streams) and message_streams[i] == "low_level"
            for i in target_indices
        )

        new_complementary = dict(comp)
        # Drop the per-recipe sidecar keys; everything downstream needs
        # is now in the tokenized form.
        new_complementary.pop("messages", None)
        new_complementary.pop("message_streams", None)
        new_complementary.pop("target_message_indices", None)

        # SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
        # ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
        # them there, and drop ``task`` so the upstream
        # ``TokenizerProcessorStep`` (which we replace) doesn't
        # double-tokenize.
        observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
        observation[OBS_LANGUAGE_TOKENS] = ids_t
        observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
        new_complementary["text_labels"] = labels_t
        new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool)
        new_complementary.pop("task", None)

        new_transition = dict(transition)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
        new_transition[TransitionKey.OBSERVATION] = observation
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Pass-through; this step writes runtime tensors, not features."""
        return features

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_tokenizer(self):  # noqa: ANN202
        if self._tokenizer is not None:
            return self._tokenizer
        try:
            from transformers import AutoTokenizer  # noqa: PLC0415
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "SmolVLA2ChatTokenizerStep requires transformers. "
                "`pip install lerobot[transformers-dep]`."
            ) from exc
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
            self._tokenizer.pad_token = self._tokenizer.eos_token
        return self._tokenizer


def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
    """Remove LeRobot-specific multimodal blocks from ``message`` content.

    The recipe DSL allows authors to write multimodal content like
    ``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
    tokenizer doesn't know that ``feature`` key (it expects ``url`` or
    ``path``). The actual image tensor flows through SmolVLA's
    ``OBS_IMAGES_*`` channels separately; the chat template only needs
    the text. So we strip non-text blocks before tokenizing.
    """
    new = dict(message)
    content = new.get("content")
    if isinstance(content, list):
        text_parts: list[dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                continue
            if block.get("type") == "text":
                text_parts.append({"type": "text", "text": str(block.get("text", ""))})
        # If only one text block survives, flatten to a string for
        # template friendliness; some chat templates choke on a
        # single-element list.
        if len(text_parts) == 1:
            new["content"] = text_parts[0]["text"]
        elif text_parts:
            new["content"] = text_parts
        else:
            new["content"] = ""
    if "tool_calls" in new and not new["tool_calls"]:
        # Drop empty tool_calls; some templates render them as a
        # spurious empty marker.
        new.pop("tool_calls")
    # ``stream`` and ``target`` were recipe metadata; templates don't
    # know them and may warn or crash.
    new.pop("stream", None)
    new.pop("target", None)
    return new


# Re-export for tests / introspection
strip_lerobot_blocks = _strip_lerobot_blocks
@@ -0,0 +1,97 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

from lerobot.configs import PreTrainedConfig

from ..smolvla.configuration_smolvla import SmolVLAConfig


@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(SmolVLAConfig):
    """SmolVLA2: SmolVLA with the underlying SmolVLM language head re-enabled.

    SmolVLA strips the LM head from the SmolVLM backbone because it only
    needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
    same model can train on:

    * **action-only sub-recipes** (e.g. ``low_level_execution``): flow loss
      on the action expert, same as SmolVLA. ``predict_actions=True``.
    * **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
      ``user_interjection_response`` / ``high_level_subtask``): cross-entropy
      loss on the LM head over the recipe's target message tokens.
      Skips the flow head entirely. ``predict_actions=False``.
    * **mixed sub-recipes**: both heads run, losses summed (weighted).

    The split is controlled by ``predict_actions = bool(targets_by_stream
    .get("low_level"))`` per the Pi0.5 convention in the steerable
    annotation plan (Section I.7), implemented inside the processor /
    forward path. Recipes drive it via ``stream`` + ``target`` metadata.

    Compared to ``SmolVLAConfig`` this adds:

    - ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
      train script). When ``None``, SmolVLA2 falls back to the SmolVLA
      task-only path so unannotated datasets still work.
    - ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
      both losses are active in a single sample.
    - ``unfreeze_lm_head``: must be ``True`` for the text head to learn.
      SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
      need to undo that for SmolVLA2.
    - ``train_expert_only=False`` by default, since the VLM body now also
      participates in text-target gradients.
    """

    # Recipe / language stack ---------------------------------------------
    recipe_path: str | None = "recipes/smolvla2_hirobot.yaml"
    """Path (absolute or relative to ``src/lerobot/configs/``) to a
    ``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
    blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
    rendering and fall back to SmolVLA's single-task prompt path
    (unannotated datasets keep working that way)."""

    apply_chat_template: bool = True
    """Apply the SmolVLM tokenizer's chat template to the rendered messages
    before tokenizing. SmolVLM's backbone is chat-pretrained, so this
    matches its training distribution."""

    # Loss weights --------------------------------------------------------
    text_loss_weight: float = 1.0
    """Weight on the LM-head cross-entropy term. Set to ``0`` to disable
    text training entirely (reverts to flow-only / SmolVLA behaviour)."""

    flow_loss_weight: float = 1.0
    """Weight on the action-expert flow-matching term."""

    # Backbone training ---------------------------------------------------
    unfreeze_lm_head: bool = True
    """Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
    preceding norm + last text-model layer that SmolVLA freezes). Must be
    ``True`` for the text head to learn. Setting this to ``False``
    effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
    which is occasionally useful for ablations."""

    def __post_init__(self) -> None:
        super().__post_init__()
        # Backbone needs gradients flowing through its text path when the
        # LM head is producing supervised text. Override the SmolVLA
        # default (`train_expert_only=True`) unless the user explicitly
        # opts out of text training via `text_loss_weight=0` or
        # `unfreeze_lm_head=False`.
        if self.text_loss_weight > 0 and self.unfreeze_lm_head:
            # NOTE: this assignment runs after the dataclass fields are
            # set, so it also overrides an explicitly passed
            # `train_expert_only=True` whenever text training is active.
            self.train_expert_only = False
@@ -0,0 +1,119 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
* a forward path that routes to the flow head, the text head, or both,
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
The text-head computation itself is NOT wired up in this scaffold commit
(the processor doesn't yet produce ``text_labels`` either). This file is
the structural placeholder that:
1. registers the ``SmolVLA2Policy`` class with the right config name so
``policies/factory.py`` can build it,
2. unfreezes ``lm_head`` at construction time when the config asks for it
(otherwise SmolVLA's ``train_expert_only`` freezes it again on every
``train()`` call),
3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
SmolVLA when no text labels are present i.e. existing SmolVLA
training scripts keep working.
The next commit on this branch fills in the actual text-loss path.
"""
from __future__ import annotations
from typing import Any
import torch
from torch import Tensor
from ..smolvla.modeling_smolvla import SmolVLAPolicy
from .configuration_smolvla2 import SmolVLA2Config
class SmolVLA2Policy(SmolVLAPolicy):
"""SmolVLA + re-enabled SmolVLM language head.
Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
perspective. Behaviourally identical to SmolVLA until the text-head
code path lands in the next commit on this branch.
"""
config_class = SmolVLA2Config
name = "smolvla2"
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
if not isinstance(config, SmolVLA2Config):
# Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
# widening the config type — the new fields fall back to their
# defaults, which preserves the existing SmolVLA behaviour.
config = SmolVLA2Config(**{
f.name: getattr(config, f.name)
for f in config.__dataclass_fields__.values()
if hasattr(config, f.name)
})
super().__init__(config, dataset_stats=dataset_stats)
if config.unfreeze_lm_head and config.text_loss_weight > 0:
self._unfreeze_lm_head()
# ------------------------------------------------------------------
# Backbone surgery
# ------------------------------------------------------------------
def _unfreeze_lm_head(self) -> None:
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
the text path SmolVLA freezes) so the text-loss can flow back.
SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
``lm_head``, ``text_model.model.norm.weight``, and the last
``text_model.layers.<N-1>`` block. We undo that selectively when
text training is enabled.
"""
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
if vlm_with_expert is None:
return
vlm = getattr(vlm_with_expert, "vlm", None)
if vlm is None:
return
for name, param in vlm.named_parameters():
if (
"lm_head" in name
or "text_model.model.norm.weight" in name
):
param.requires_grad = True

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        batch: dict[str, Tensor],
        noise: Tensor | None = None,
        time: Tensor | None = None,
        reduction: str = "mean",
    ) -> tuple[Tensor, dict[str, Any]]:
        """Forward pass with optional text-head loss.

        SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
        actual text-loss / dual-head routing lands in the next commit on
        this branch; it will read ``batch["text_labels"]`` and
        ``batch["predict_actions"]`` (both produced by the SmolVLA2
        processor) to decide which head(s) to run.
        """
        return super().forward(batch, noise=noise, time=time, reduction=reduction)
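
The config-widening step in ``__init__`` can be illustrated in isolation. The following is a minimal sketch, not the PR's code: ``BaseConfig``, ``ExtendedConfig`` and ``widen_config`` are hypothetical stand-ins for ``SmolVLAConfig``, ``SmolVLA2Config`` and the inline dict comprehension, and it uses ``dataclasses.fields`` with an ``f.init`` filter, which is an assumption on top of what the diff shows.

```python
from dataclasses import dataclass, fields


@dataclass
class BaseConfig:
    # Hypothetical stand-in for the parent policy config.
    chunk_size: int = 50
    vlm_model_name: str = "base-vlm"


@dataclass
class ExtendedConfig(BaseConfig):
    # Hypothetical stand-in for the subclass config; every new field has a default.
    text_loss_weight: float = 0.0
    unfreeze_lm_head: bool = False


def widen_config(config: BaseConfig) -> ExtendedConfig:
    """Copy every init-able field of ``config`` into an ``ExtendedConfig``."""
    if isinstance(config, ExtendedConfig):
        return config
    # Fields declared with init=False cannot be passed to the constructor,
    # so they are skipped (an assumption; the diff copies all fields).
    shared = {f.name: getattr(config, f.name) for f in fields(config) if f.init}
    return ExtendedConfig(**shared)


widened = widen_config(BaseConfig(chunk_size=10))
print(widened)
```

Because the new fields fall back to their defaults, a widened config in this sketch leaves the text loss disabled, which mirrors the "behaviourally identical" claim in the class docstring.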
@@ -0,0 +1,131 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 processor pipelines.

When ``config.recipe_path`` is set, the pre-processor pipeline becomes:

    rename observations
    add batch dim
    RenderMessagesStep(recipe)        # PR 1: language_* → messages
    SmolVLA2ChatTokenizerStep(...)    # chat template + label mask + predict_actions
    DeviceProcessorStep
    NormalizerProcessorStep

When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
plain task-string pipeline so unannotated datasets still work.

Post-processor is unchanged from SmolVLA.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
RenderMessagesStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
from .configuration_smolvla2 import SmolVLA2Config


def make_smolvla2_pre_post_processors(
    config: SmolVLA2Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build SmolVLA2's pre/post-processor pipelines.

    With ``recipe_path`` set, inserts the recipe-rendering step and the
    chat-template tokenizer that emits ``text_labels`` and
    ``predict_actions`` for the dual-loss path. Without it, falls back
    to SmolVLA's plain task-string pipeline so unannotated datasets
    keep working unchanged.
    """
    if not config.recipe_path:
        return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)

    recipe = _load_recipe(config.recipe_path)
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        RenderMessagesStep(recipe=recipe),
        SmolVLA2ChatTokenizerStep(
            tokenizer_name=config.vlm_model_name,
            max_length=config.tokenizer_max_length,
            padding=config.pad_language_to,
        ),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )


def _load_recipe(path_str: str) -> TrainingRecipe:
    """Resolve ``path_str`` to a ``TrainingRecipe``.

    Accepts an absolute path, a relative path that exists from the
    current working directory, or a path relative to
    ``src/lerobot/configs/`` so recipe authors can write
    ``--policy.recipe_path=recipes/smolvla2_hirobot.yaml``. If none of
    these resolves, the path is handed to ``TrainingRecipe.from_yaml``
    unchanged.
    """
    p = Path(path_str)
    if not p.is_absolute() and not p.exists():
        from lerobot.configs import recipe as _recipe_module  # noqa: PLC0415

        configs_dir = Path(_recipe_module.__file__).resolve().parent
        candidate = configs_dir / path_str
        if candidate.exists():
            p = candidate
    return TrainingRecipe.from_yaml(p)
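
The lookup order in ``_load_recipe`` can be exercised without LeRobot installed. This is a rough sketch under stated assumptions: ``resolve_recipe_path`` is a hypothetical helper that takes the configs directory as an explicit argument instead of deriving it from ``lerobot.configs.recipe``, and it returns the resolved ``Path`` rather than loading a recipe.

```python
import tempfile
from pathlib import Path


def resolve_recipe_path(path_str: str, configs_dir: Path) -> Path:
    """Return the path as given if it is absolute or exists; otherwise try ``configs_dir``."""
    p = Path(path_str)
    if not p.is_absolute() and not p.exists():
        candidate = configs_dir / path_str
        if candidate.exists():
            return candidate
    # Either already usable, or not found anywhere: hand it back unchanged
    # and let the loader report the error.
    return p


with tempfile.TemporaryDirectory() as tmp:
    configs_dir = Path(tmp) / "configs"
    (configs_dir / "recipes").mkdir(parents=True)
    (configs_dir / "recipes" / "demo.yaml").write_text("name: demo\n")

    found = resolve_recipe_path("recipes/demo.yaml", configs_dir)
    missing = resolve_recipe_path("recipes/absent.yaml", configs_dir)
    found_ok = found == configs_dir / "recipes" / "demo.yaml"
    missing_ok = missing == Path("recipes/absent.yaml")
```

A path relative to the working directory wins over the configs directory when both exist, since the fallback only runs if the first lookup fails.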
@@ -0,0 +1,29 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LeRobot tool implementations.
Storage of the tool catalog (``meta/info.json["tools"]``) and the
``SAY_TOOL_SCHEMA`` constant live in PR 1
(``lerobot.datasets.language``). This package holds the *runnable*
implementations, one file per tool, plus the registry that maps tool
names to classes.
See ``docs/source/tools.mdx`` for the authoring guide.
"""
from .base import Tool
from .registry import TOOL_REGISTRY, get_tools
from .say import SayTool
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
@@ -0,0 +1,58 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool protocol — the contract every runnable tool implementation honors.
Tools are the executable side of the OpenAI-style function-calling
abstraction the v3.1 language schema (PR 1) carries on assistant
messages: the schema describes *what can be called*, the tool
implementation describes *how to call it*.
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
``say.py`` for ``SayTool``) and are registered in
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
heavy dependencies (torch models, audio backends, network clients,
hardware drivers) only load when the dataset actually declares the tool.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Tool(Protocol):
    """Minimum surface every tool must expose."""

    #: Name matching ``schema["function"]["name"]``. The runtime dispatcher
    #: routes incoming ``tool_calls`` to the implementation by this key.
    name: str

    #: OpenAI-style function-call schema. Same dict the dataset stores in
    #: ``meta/info.json["tools"]`` and the chat template renders into the
    #: prompt.
    schema: dict[str, Any]

    def call(self, arguments: dict[str, Any]) -> Any:
        """Execute the tool with the model-provided arguments.

        ``arguments`` is the parsed dict from
        ``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
        when the model emits a JSON string, as the chat-template
        convention allows). Implementations validate the dict against
        their own schema; the runtime only routes by name.

        The return value is implementation-defined: typically a tensor
        (TTS audio), a ``Path`` (saved file), a dict (structured result),
        or ``None`` (side-effect-only call).
        """
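
To make the contract concrete, here is a small sketch of a class that satisfies the protocol structurally, without inheriting from it. ``EchoTool`` and its schema are hypothetical and exist only for illustration; the ``Tool`` protocol is re-declared so the snippet runs on its own.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Tool(Protocol):
    name: str
    schema: dict[str, Any]

    def call(self, arguments: dict[str, Any]) -> Any: ...


class EchoTool:
    """Hypothetical tool that returns its ``text`` argument unchanged."""

    name = "echo"

    def __init__(self, schema: dict[str, Any] | None = None):
        # Illustrative schema only; not taken from the PR.
        self.schema = schema or {
            "type": "function",
            "function": {"name": "echo", "parameters": {"type": "object"}},
        }

    def call(self, arguments: dict[str, Any]) -> Any:
        text = arguments.get("text")
        if not isinstance(text, str):
            raise ValueError(f"expected {{'text': str}}, got {arguments!r}")
        return text


tool = EchoTool()
conforms = isinstance(tool, Tool)  # structural check: no inheritance needed
rejects_plain_object = not isinstance(object(), Tool)
result = tool.call({"text": "hello"})
```

Note that ``isinstance`` against a ``runtime_checkable`` protocol only checks that the members exist; it does not verify signatures or types, which is why implementations are expected to validate ``arguments`` themselves.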
@@ -0,0 +1,70 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool registry — name → implementation class.
Adding a new tool:
1. Drop a file under ``src/lerobot/tools/`` that defines a class
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
``schema``, ``call(arguments)``).
2. Register the class here under :data:`TOOL_REGISTRY`.
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
to advertise the schema to the chat-template + policy. The PR 2
annotation pipeline preserves anything you put there.
See ``docs/source/tools.mdx`` for the full authoring guide.
"""
from __future__ import annotations
from typing import Any
from .base import Tool
from .say import SayTool
#: Map from ``function.name`` to a class implementing :class:`Tool`.
#: The runtime instantiates entries lazily — registering a tool here is
#: essentially free (no model load happens until ``call`` runs).
TOOL_REGISTRY: dict[str, type] = {
"say": SayTool,
}


def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
    """Build name → tool-instance dict from a dataset's declared catalog.

    ``meta`` is anything with a ``.tools`` attribute returning the
    OpenAI-style schema list, typically a
    :class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
    Each entry whose ``function.name`` is registered here is
    instantiated with the schema dict; tools whose name is unknown to
    the registry are skipped (the schema still rides through the chat
    template, the model just can't actually invoke that tool at
    inference). Entries without a ``function.name`` are skipped too.

    Extra keyword arguments are forwarded to every constructor, which is
    useful for runtime defaults like ``output_dir=Path("./tts_log")``.
    """
    declared = list(meta.tools)
    instances: dict[str, Tool] = {}
    for schema in declared:
        try:
            name = schema["function"]["name"]
        except (KeyError, TypeError):
            continue
        cls = TOOL_REGISTRY.get(name)
        if cls is None:
            continue
        instances[name] = cls(schema=schema, **kwargs)
    return instances
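
The registry lookup and the name-based routing it enables can be sketched end to end. Everything here is illustrative: ``UpperTool``, ``REGISTRY``, ``build_tools`` and ``dispatch`` are hypothetical names, and the dispatcher in particular is an assumption, since this diff only shows the registry side and not the runtime that consumes it.

```python
from __future__ import annotations

import json
from types import SimpleNamespace
from typing import Any


class UpperTool:
    """Hypothetical tool: upper-cases ``arguments["text"]``."""

    name = "upper"

    def __init__(self, schema: dict[str, Any]):
        self.schema = schema

    def call(self, arguments: dict[str, Any]) -> Any:
        return arguments["text"].upper()


REGISTRY: dict[str, type] = {"upper": UpperTool}


def build_tools(meta: Any) -> dict[str, Any]:
    """Instantiate every declared tool whose name is registered; skip the rest."""
    instances: dict[str, Any] = {}
    for schema in meta.tools:
        try:
            name = schema["function"]["name"]
        except (KeyError, TypeError):
            continue  # malformed entry
        cls = REGISTRY.get(name)
        if cls is not None:
            instances[name] = cls(schema=schema)
    return instances


def dispatch(tool_calls: list[dict[str, Any]], tools: dict[str, Any]) -> list[Any]:
    """Route each call by function name; unknown names yield ``None``."""
    results = []
    for tool_call in tool_calls:
        fn = tool_call["function"]
        args = fn.get("arguments", {})
        if isinstance(args, str):  # arguments may arrive as a JSON string
            args = json.loads(args)
        tool = tools.get(fn["name"])
        results.append(tool.call(args) if tool is not None else None)
    return results


meta = SimpleNamespace(tools=[
    {"type": "function", "function": {"name": "upper"}},
    {"type": "function", "function": {"name": "unregistered"}},
    {"malformed": True},
])
tools = build_tools(meta)
results = dispatch(
    [
        {"function": {"name": "upper", "arguments": '{"text": "done"}'}},
        {"function": {"name": "unregistered", "arguments": {}}},
    ],
    tools,
)
```

Skipping unknown names instead of raising keeps a dataset usable even when it declares tools that the local install cannot run.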
@@ -0,0 +1,170 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
runtime dispatchers consume this when the model emits an assistant
message with ``tool_calls=[{function: {name: "say", arguments:
{text: ...}}}]``.
Why pocket-tts:
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
- ~100M parameters, ~200ms first-chunk latency
- streamable, voice-cloneable
- pip-installable, MIT-style permissive license

The pocket-tts model is loaded **lazily** the first time ``call(...)``
runs (or eagerly via ``preload()``). Loading takes a few seconds and
several hundred MB of RAM, so we don't pay the cost when the tool is
merely *registered*, only when it's *invoked*.

Optional dependency. Install with::

    pip install lerobot[tools]
    # or directly:
    pip install pocket-tts
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import SAY_TOOL_SCHEMA
logger = logging.getLogger(__name__)


@dataclass
class SayTool:
    """Speak a short utterance via Kyutai's pocket-tts.

    Parameters
    ----------
    schema:
        Optional schema override; defaults to the canonical
        ``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
        argument shapes can pass in a modified schema, but the
        implementation only reads ``arguments["text"]``.
    voice:
        One of the pocket-tts catalog voices (``alba``, ``marius``,
        ``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
        ``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
        file for cloning. See the pocket-tts model card for licensing.
    output_dir:
        If set, every ``call(...)`` writes a
        ``say_<timestamp_ms>_<snippet>.wav`` audio file there in
        addition to returning the PCM tensor. ``None`` (default) skips
        disk writes, which is useful for live playback paths that hand
        the tensor directly to a sounddevice / WebAudio sink.
    """

    schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
    voice: str = "alba"
    output_dir: Path | None = None
    name: str = field(init=False, default="say")
    _model: Any = field(init=False, default=None, repr=False)
    _voice_state: Any = field(init=False, default=None, repr=False)
    _sample_rate: int = field(init=False, default=24000, repr=False)

    # ------------------------------------------------------------------
    # Lazy model load
    # ------------------------------------------------------------------
    def preload(self) -> None:
        """Load the pocket-tts model + voice state into memory.

        Optional: ``call(...)`` triggers this automatically on first
        invocation. Useful when you want the multi-second load to
        happen at startup rather than on the first ``say`` the policy
        emits.
        """
        if self._model is not None and self._voice_state is not None:
            return
        try:
            from pocket_tts import TTSModel  # noqa: PLC0415 (optional dep)
        except ImportError as exc:  # pragma: no cover (env-dependent)
            raise ImportError(
                "SayTool requires pocket-tts. Install with `pip install "
                "lerobot[tools]` or `pip install pocket-tts`."
            ) from exc
        logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
        self._model = TTSModel.load_model()
        self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
        self._sample_rate = int(getattr(self._model, "sample_rate", 24000))

    # ------------------------------------------------------------------
    # Tool protocol
    # ------------------------------------------------------------------
    def call(self, arguments: dict[str, Any]) -> Any:
        """Speak ``arguments["text"]`` and return the PCM tensor.

        Also writes ``<output_dir>/say_<timestamp_ms>_<snippet>.wav``
        when ``self.output_dir`` is set. The returned tensor is a 1-D
        ``torch.Tensor`` of float32 PCM samples at
        ``self.sample_rate`` Hz, directly playable by
        ``sounddevice.play(audio.numpy(), self.sample_rate)`` or
        encodable by ``scipy.io.wavfile.write``.
        """
        text = arguments.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError(
                f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
            )
        self.preload()
        audio = self._model.generate_audio(self._voice_state, text)
        if self.output_dir is not None:
            self._write_wav(audio, text)
        return audio

    @property
    def sample_rate(self) -> int:
        """PCM sample rate of the returned tensor (Hz)."""
        return self._sample_rate

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _write_wav(self, audio: Any, text: str) -> Path:
        """Write a ``.wav`` into ``output_dir`` for offline inspection."""
        import time as _time  # noqa: PLC0415

        try:
            import scipy.io.wavfile  # noqa: PLC0415
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "SayTool.output_dir requires scipy. `pip install scipy`."
            ) from exc
        out_dir = Path(self.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # One file per call; the name carries a millisecond timestamp + a
        # short text snippet so a directory listing is informative.
        snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
        ts_ms = int(_time.time() * 1000)
        path = out_dir / f"say_{ts_ms}_{snippet}.wav"
        # ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
        # ``.numpy()`` is safe.
        scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
        return path
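
The lazy-load behaviour that ``preload()`` and ``call(...)`` implement can be shown without pocket-tts. This is a minimal sketch under stated assumptions: ``LazyResource`` and ``expensive_loader`` are hypothetical stand-ins for ``SayTool`` and the model load, and the ``load_count`` field exists only to make the single load observable.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class LazyResource:
    """Hypothetical holder that defers an expensive load until first use."""

    loader: Callable[[], Any]
    _value: Any = field(init=False, default=None, repr=False)
    load_count: int = field(init=False, default=0)

    def preload(self) -> None:
        # Idempotent: once the value is cached, further calls are no-ops.
        if self._value is not None:
            return
        self._value = self.loader()
        self.load_count += 1

    def call(self, text: str) -> str:
        self.preload()  # the first call pays the load cost; later calls reuse it
        return self._value(text)


def expensive_loader() -> Callable[[str], str]:
    # Stand-in for a multi-second model load.
    return lambda text: f"spoken:{text}"


resource = LazyResource(loader=expensive_loader)
count_before = resource.load_count  # nothing loaded at construction time
first = resource.call("hello")
second = resource.call("again")
count_after = resource.load_count  # loaded once, on the first call
```

Constructing the object costs nothing, which is what lets a registry instantiate every declared tool up front; calling ``preload()`` at startup simply moves the one-time cost earlier.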