mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37b1eb218a | |||
| 52e1fd35cb | |||
| 7459dfccb6 |
@@ -209,6 +209,14 @@ annotations = [
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=0.1.0,<1.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
# SmolVLA2 canonical training recipe — Hi Robot / MEM / ECoT blend.
|
||||
#
|
||||
# Same blend shape as pi05_hirobot.yaml. SmolVLA2 differs from Pi0.5 in
|
||||
# how the renderer's output is consumed:
|
||||
#
|
||||
# - SmolVLA2 calls SmolVLM's tokenizer.apply_chat_template(messages,
|
||||
# tools=DEFAULT_TOOLS) on the rendered messages, since SmolVLM is a
|
||||
# chat-pretrained backbone.
|
||||
# - The processor builds a `text_labels` tensor that masks every token
|
||||
# except those belonging to messages whose index is in
|
||||
# `target_message_indices`. Cross-entropy on those positions trains
|
||||
# the LM head.
|
||||
# - `predict_actions = bool(targets_by_stream.get("low_level"))` —
|
||||
# same convention as Pi0.5. ``low_level_execution`` is the only
|
||||
# branch that runs the action expert / flow head.
|
||||
|
||||
blend:
|
||||
|
||||
memory_update:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "emitted_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.16
|
||||
bindings:
|
||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
||||
current_plan: "emitted_at(t, style=plan)"
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.15
|
||||
bindings:
|
||||
next_subtask: "nth_next(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
|
||||
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.35
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
|
||||
|
||||
# Per-camera VQA sub-recipes (PR 1's view-dependent style routing).
|
||||
# Adjust the camera keys (and add more sub-recipes) to match the
|
||||
# cameras present on your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.top}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -140,6 +140,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from .smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -386,6 +394,17 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "smolvla2":
|
||||
# NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
|
||||
# MUST come before the SmolVLAConfig isinstance check below
|
||||
# (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
|
||||
from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors
|
||||
|
||||
processors = make_smolvla2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only does
|
||||
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
|
||||
model can train on the full Hi Robot / MEM / ECoT message blend defined in
|
||||
the steerable annotation plan (PR1 + PR2):
|
||||
|
||||
* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss
|
||||
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
|
||||
``user_interjection_response``, ``high_level_subtask``) → CE loss on
|
||||
``lm_head`` over the recipe's target message tokens
|
||||
* mixed sub-recipes → both losses summed (weighted)
|
||||
|
||||
The ``predict_actions`` toggle follows the Pi0.5 convention from Section
|
||||
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
|
||||
sample, else ``False``.
|
||||
|
||||
This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
|
||||
the model code stays in one place — only the dual-loss path and the
|
||||
chat-template processor live here.
|
||||
"""
|
||||
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
__all__ = ["SmolVLA2Config"]
|
||||
@@ -0,0 +1,271 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2's chat-template tokenization step.
|
||||
|
||||
Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
|
||||
``recipe_path`` is set. Reads the rendered messages produced by
|
||||
``RenderMessagesStep`` (PR 1) and produces:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` —
|
||||
the chat-templated prompt tokenized by SmolVLM's tokenizer, with
|
||||
``tools=meta.tools`` (PR 1's catalog).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
the positions belonging to messages whose index is in
|
||||
``target_message_indices``. The next commit's modeling forward path
|
||||
applies cross-entropy on those positions via the SmolVLM ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``. The
|
||||
modeling forward uses this to gate the flow head.
|
||||
|
||||
Image / video content blocks in the rendered messages are dropped
|
||||
before tokenization — the chat template only handles text, and SmolVLA
|
||||
already passes camera tensors out-of-band via the standard
|
||||
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
|
||||
(``embed_prefix`` puts image embeddings before language embeddings,
|
||||
matching the chat-template-stripped text order).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
|
||||
class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
This is the bridge between the recipe stack (PR 1's
|
||||
``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward
|
||||
(next commit, which reads ``text_labels`` / ``predict_actions``).
|
||||
Pure-text turns and multi-stream targets are both handled.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||
max_length: int = 2048
|
||||
padding: str = "longest"
|
||||
padding_side: str = "right"
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Lazy: don't load the tokenizer until the step actually runs,
|
||||
# so unit tests that import the module without transformers
|
||||
# installed still pass.
|
||||
self._tokenizer: Any = None
|
||||
if self.tools is None:
|
||||
# Default: ship the canonical ``say`` schema. Users who set
|
||||
# ``meta.tools`` differently can override via
|
||||
# ``with_tools(meta.tools)``.
|
||||
self.tools = list(DEFAULT_TOOLS)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
|
||||
"""Override the tools catalog rendered into the system prompt."""
|
||||
self.tools = list(tools)
|
||||
return self
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
messages = comp.get("messages")
|
||||
if not messages:
|
||||
# No recipe rendering happened — nothing to do; downstream
|
||||
# falls back to whatever ``task`` is in the transition.
|
||||
return transition
|
||||
|
||||
message_streams: list[str | None] = list(comp.get("message_streams") or [])
|
||||
target_indices: list[int] = sorted(
|
||||
int(i) for i in (comp.get("target_message_indices") or [])
|
||||
)
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||
|
||||
# Tokenize the full chat once.
|
||||
full_ids = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list):
|
||||
full_ids = full_ids[0]
|
||||
|
||||
# Build the label mask by re-rendering progressively up to each
|
||||
# target message and reading off the prefix length. This is the
|
||||
# robust way to get exact token boundaries: we use the same
|
||||
# tokenizer, the same ``tools=`` argument, and the same chat
|
||||
# template — so the prefix tokens are guaranteed to be a prefix
|
||||
# of the full sequence.
|
||||
labels = [-100] * len(full_ids)
|
||||
for tgt in target_indices:
|
||||
prefix_ids = tokenizer.apply_chat_template(
|
||||
text_messages[:tgt],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_through_target = tokenizer.apply_chat_template(
|
||||
text_messages[: tgt + 1],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list):
|
||||
prefix_ids = prefix_ids[0]
|
||||
if (
|
||||
isinstance(full_through_target, list)
|
||||
and full_through_target
|
||||
and isinstance(full_through_target[0], list)
|
||||
):
|
||||
full_through_target = full_through_target[0]
|
||||
start = len(prefix_ids)
|
||||
end = min(len(full_through_target), len(full_ids))
|
||||
for pos in range(start, end):
|
||||
labels[pos] = int(full_ids[pos])
|
||||
|
||||
# Truncate / pad to ``max_length`` so batches collate cleanly.
|
||||
# The SmolVLA pipeline downstream relies on a fixed length
|
||||
# behaviour ("longest" or "max_length") — we mirror it here.
|
||||
if len(full_ids) > self.max_length:
|
||||
full_ids = full_ids[: self.max_length]
|
||||
labels = labels[: self.max_length]
|
||||
attn = [1] * len(full_ids)
|
||||
if self.padding == "max_length" and len(full_ids) < self.max_length:
|
||||
pad_id = (
|
||||
tokenizer.pad_token_id
|
||||
if tokenizer.pad_token_id is not None
|
||||
else 0
|
||||
)
|
||||
n_pad = self.max_length - len(full_ids)
|
||||
full_ids = full_ids + [pad_id] * n_pad
|
||||
labels = labels + [-100] * n_pad
|
||||
attn = attn + [0] * n_pad
|
||||
|
||||
ids_t = torch.tensor(full_ids, dtype=torch.long)
|
||||
attn_t = torch.tensor(attn, dtype=torch.bool)
|
||||
labels_t = torch.tensor(labels, dtype=torch.long)
|
||||
predict_actions = any(
|
||||
i < len(message_streams) and message_streams[i] == "low_level"
|
||||
for i in target_indices
|
||||
)
|
||||
|
||||
new_complementary = dict(comp)
|
||||
# Drop the per-recipe sidecar keys; everything downstream needs
|
||||
# is now in the tokenized form.
|
||||
new_complementary.pop("messages", None)
|
||||
new_complementary.pop("message_streams", None)
|
||||
new_complementary.pop("target_message_indices", None)
|
||||
# SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
|
||||
# ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
|
||||
# them there — and drop ``task`` so the upstream
|
||||
# ``TokenizerProcessorStep`` (which we replace) doesn't double-
|
||||
# tokenize.
|
||||
observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
observation[OBS_LANGUAGE_TOKENS] = ids_t
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
|
||||
new_complementary["text_labels"] = labels_t
|
||||
new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool)
|
||||
new_complementary.pop("task", None)
|
||||
|
||||
new_transition = dict(transition)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass-through; this step writes runtime tensors not features."""
|
||||
return features
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_tokenizer(self): # noqa: ANN202
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SmolVLA2ChatTokenizerStep requires transformers. "
|
||||
"`pip install lerobot[transformers-dep]`."
|
||||
) from exc
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
|
||||
self._tokenizer.pad_token = self._tokenizer.eos_token
|
||||
return self._tokenizer
|
||||
|
||||
|
||||
def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove LeRobot-specific multimodal blocks from ``message`` content.
|
||||
|
||||
The recipe DSL allows authors to write multimodal content like
|
||||
``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
|
||||
tokenizer doesn't know that ``feature`` key (it expects ``url`` or
|
||||
``path``). The actual image tensor flows through SmolVLA's
|
||||
``OBS_IMAGES_*`` channels separately; the chat template only needs
|
||||
the text. So we strip non-text blocks before tokenizing.
|
||||
"""
|
||||
new = dict(message)
|
||||
content = new.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
|
||||
# If only one text block survives, flatten to a string for
|
||||
# template friendliness; some chat templates choke on a single-
|
||||
# element list.
|
||||
if len(text_parts) == 1:
|
||||
new["content"] = text_parts[0]["text"]
|
||||
elif text_parts:
|
||||
new["content"] = text_parts
|
||||
else:
|
||||
new["content"] = ""
|
||||
if "tool_calls" in new and not new["tool_calls"]:
|
||||
# Drop empty tool_calls — some templates render them as a
|
||||
# spurious empty marker.
|
||||
new.pop("tool_calls")
|
||||
# ``stream`` and ``target`` were recipe metadata; templates don't
|
||||
# know them and may warn or crash.
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
# Re-export for tests / introspection
|
||||
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..smolvla.configuration_smolvla import SmolVLAConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla2")
|
||||
@dataclass
|
||||
class SmolVLA2Config(SmolVLAConfig):
|
||||
"""SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only
|
||||
needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
|
||||
same model can train on:
|
||||
|
||||
* **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss
|
||||
on the action expert, same as SmolVLA. ``predict_actions=True``.
|
||||
* **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
|
||||
``user_interjection_response`` / ``high_level_subtask``) — cross-
|
||||
entropy loss on the LM head over the recipe's target message tokens.
|
||||
Skips the flow head entirely. ``predict_actions=False``.
|
||||
* **mixed sub-recipes** — both heads run, losses summed (weighted).
|
||||
|
||||
The split is controlled by ``predict_actions = bool(targets_by_stream
|
||||
.get("low_level"))`` per the Pi0.5 convention in the steerable
|
||||
annotation plan (Section I.7), implemented inside the processor /
|
||||
forward path. Recipes drive it via ``stream`` + ``target`` metadata.
|
||||
|
||||
Compared to ``SmolVLAConfig`` this adds:
|
||||
|
||||
- ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
|
||||
train script). When ``None``, SmolVLA2 falls back to the SmolVLA
|
||||
task-only path so unannotated datasets still work.
|
||||
- ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
|
||||
both losses are active in a single sample.
|
||||
- ``unfreeze_lm_head``: must be ``True`` for the text head to learn —
|
||||
SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
|
||||
need to undo that for SmolVLA2.
|
||||
- ``train_expert_only=False`` by default, since the VLM body now also
|
||||
participates in text-target gradients.
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/smolvla2_hirobot.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
|
||||
blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
|
||||
rendering and fall back to SmolVLA's single-task prompt path
|
||||
(unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = True
|
||||
"""Apply the SmolVLM tokenizer's chat template to the rendered messages
|
||||
before tokenizing. SmolVLM's backbone is chat-pretrained, so this
|
||||
matches its training distribution."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / SmolVLA behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 1.0
|
||||
"""Weight on the action-expert flow-matching term."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
|
||||
preceding norm + last text-model layer that SmolVLA freezes). Must be
|
||||
``True`` for the text head to learn. Setting this to ``False``
|
||||
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
|
||||
which is occasionally useful for ablations."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through its text path when the
|
||||
# LM head is producing supervised text. Override the SmolVLA
|
||||
# default (`train_expert_only=True`) unless the user explicitly
|
||||
# opts out of text training via `text_loss_weight=0`.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
# The user can still flip this back via CLI; this only
|
||||
# changes the *default* when SmolVLA2 is actually training a
|
||||
# text head.
|
||||
self.train_expert_only = False
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
||||
|
||||
This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:
|
||||
|
||||
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
||||
* a forward path that routes to the flow head, the text head, or both,
|
||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.
|
||||
|
||||
The text-head computation itself is NOT wired up in this scaffold commit
|
||||
(the processor doesn't yet produce ``text_labels`` either). This file is
|
||||
the structural placeholder that:
|
||||
|
||||
1. registers the ``SmolVLA2Policy`` class with the right config name so
|
||||
``policies/factory.py`` can build it,
|
||||
2. unfreezes ``lm_head`` at construction time when the config asks for it
|
||||
(otherwise SmolVLA's ``train_expert_only`` freezes it again on every
|
||||
``train()`` call),
|
||||
3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
|
||||
SmolVLA when no text labels are present — i.e. existing SmolVLA
|
||||
training scripts keep working.
|
||||
|
||||
The next commit on this branch fills in the actual text-loss path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"""SmolVLA + re-enabled SmolVLM language head.
|
||||
|
||||
Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
|
||||
perspective. Behaviourally identical to SmolVLA until the text-head
|
||||
code path lands in the next commit on this branch.
|
||||
"""
|
||||
|
||||
config_class = SmolVLA2Config
|
||||
name = "smolvla2"
|
||||
|
||||
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
if not isinstance(config, SmolVLA2Config):
|
||||
# Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
|
||||
# widening the config type — the new fields fall back to their
|
||||
# defaults, which preserves the existing SmolVLA behaviour.
|
||||
config = SmolVLA2Config(**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
if hasattr(config, f.name)
|
||||
})
|
||||
super().__init__(config, dataset_stats=dataset_stats)
|
||||
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backbone surgery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
|
||||
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
|
||||
the text path SmolVLA freezes) so the text-loss can flow back.
|
||||
|
||||
SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
|
||||
``lm_head``, ``text_model.model.norm.weight``, and the last
|
||||
``text_model.layers.<N-1>`` block. We undo that selectively when
|
||||
text training is enabled.
|
||||
"""
|
||||
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
||||
if vlm_with_expert is None:
|
||||
return
|
||||
vlm = getattr(vlm_with_expert, "vlm", None)
|
||||
if vlm is None:
|
||||
return
|
||||
for name, param in vlm.named_parameters():
|
||||
if (
|
||||
"lm_head" in name
|
||||
or "text_model.model.norm.weight" in name
|
||||
):
|
||||
param.requires_grad = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
noise: Tensor | None = None,
|
||||
time: Tensor | None = None,
|
||||
reduction: str = "mean",
|
||||
) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Forward pass with optional text-head loss.
|
||||
|
||||
SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
|
||||
actual text-loss / dual-head routing lands in the next commit on
|
||||
this branch — it will read ``batch["text_labels"]`` and
|
||||
``batch["predict_actions"]`` (both produced by the SmolVLA2
|
||||
processor) to decide which head(s) to run.
|
||||
"""
|
||||
return super().forward(batch, noise=noise, time=time, reduction=reduction)
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 processor pipelines.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
RenderMessagesStep(recipe) # PR 1: language_* → messages
|
||||
SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions
|
||||
DeviceProcessorStep
|
||||
NormalizerProcessorStep
|
||||
|
||||
When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
|
||||
plain task-string pipeline so unannotated datasets still work.
|
||||
|
||||
Post-processor is unchanged from SmolVLA.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
RenderMessagesStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
def make_smolvla2_pre_post_processors(
|
||||
config: SmolVLA2Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build SmolVLA2's pre/post-processor pipelines.
|
||||
|
||||
With ``recipe_path`` set, inserts the recipe-rendering step and the
|
||||
chat-template tokenizer that emits ``text_labels`` and
|
||||
``predict_actions`` for the dual-loss path. Without it, falls back
|
||||
to SmolVLA's plain task-string pipeline so unannotated datasets
|
||||
keep working unchanged.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
SmolVLA2ChatTokenizerStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding=config.pad_language_to,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/`` so recipe authors can write
|
||||
``--policy.recipe_path=recipes/smolvla2_hirobot.yaml``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""LeRobot tool implementations.
|
||||
|
||||
Storage of the tool catalog (``meta/info.json["tools"]``) and the
|
||||
``SAY_TOOL_SCHEMA`` constant live in PR 1
|
||||
(``lerobot.datasets.language``). This package holds the *runnable*
|
||||
implementations one file per tool, plus the registry that maps tool
|
||||
names to classes.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
|
||||
from .base import Tool
|
||||
from .registry import TOOL_REGISTRY, get_tools
|
||||
from .say import SayTool
|
||||
|
||||
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool protocol — the contract every runnable tool implementation honors.
|
||||
|
||||
Tools are the executable side of the OpenAI-style function-calling
|
||||
abstraction the v3.1 language schema (PR 1) carries on assistant
|
||||
messages: the schema describes *what can be called*, the tool
|
||||
implementation describes *how to call it*.
|
||||
|
||||
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
|
||||
``say.py`` for ``SayTool``) and are registered in
|
||||
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
|
||||
heavy dependencies (torch models, audio backends, network clients,
|
||||
hardware drivers) only load when the dataset actually declares the tool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
"""Minimum surface every tool must expose."""
|
||||
|
||||
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
|
||||
#: routes incoming ``tool_calls`` to the implementation by this key.
|
||||
name: str
|
||||
|
||||
#: OpenAI-style function-call schema. Same dict the dataset stores in
|
||||
#: ``meta/info.json["tools"]`` and the chat template renders into the
|
||||
#: prompt.
|
||||
schema: dict[str, Any]
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Execute the tool with the model-provided arguments.
|
||||
|
||||
``arguments`` is the parsed dict from
|
||||
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
|
||||
when the model emits a JSON-string by the chat-template
|
||||
convention). Implementations validate the dict against their own
|
||||
schema; the runtime only routes by name.
|
||||
|
||||
Return value is implementation-defined — typically a tensor
|
||||
(TTS audio), a Path (saved file), a dict (structured result), or
|
||||
``None`` (side-effect-only call).
|
||||
"""
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool registry — name → implementation class.
|
||||
|
||||
Adding a new tool:
|
||||
|
||||
1. Drop a file under ``src/lerobot/tools/`` that defines a class
|
||||
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
|
||||
``schema``, ``call(arguments)``).
|
||||
2. Register the class here under :data:`TOOL_REGISTRY`.
|
||||
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
|
||||
to advertise the schema to the chat-template + policy. The PR 2
|
||||
annotation pipeline preserves anything you put there.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the full authoring guide.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import Tool
|
||||
from .say import SayTool
|
||||
|
||||
#: Map from ``function.name`` to a class implementing :class:`Tool`.
|
||||
#: The runtime instantiates entries lazily — registering a tool here is
|
||||
#: essentially free (no model load happens until ``call`` runs).
|
||||
TOOL_REGISTRY: dict[str, type] = {
|
||||
"say": SayTool,
|
||||
}
|
||||
|
||||
|
||||
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
|
||||
"""Build name → tool-instance dict from a dataset's declared catalog.
|
||||
|
||||
``meta`` is anything with a ``.tools`` attribute returning the
|
||||
OpenAI-style schema list — typically a
|
||||
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
|
||||
Each entry whose ``function.name`` is registered here is
|
||||
instantiated with the schema dict; tools whose name is unknown to
|
||||
the registry are skipped (the schema still rides through the chat
|
||||
template, the model just can't actually invoke that tool at
|
||||
inference).
|
||||
|
||||
Extra keyword arguments are forwarded to every constructor — useful
|
||||
for runtime defaults like ``output_dir=Path("./tts_log")``.
|
||||
"""
|
||||
declared = list(meta.tools)
|
||||
instances: dict[str, Tool] = {}
|
||||
for schema in declared:
|
||||
try:
|
||||
name = schema["function"]["name"]
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
cls = TOOL_REGISTRY.get(name)
|
||||
if cls is None:
|
||||
continue
|
||||
instances[name] = cls(schema=schema, **kwargs)
|
||||
return instances
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
|
||||
|
||||
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
|
||||
runtime dispatchers consume this when the model emits an assistant
|
||||
message with ``tool_calls=[{function: {name: "say", arguments:
|
||||
{text: ...}}}]``.
|
||||
|
||||
Why pocket-tts:
|
||||
|
||||
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
|
||||
- ~100M parameters, ~200ms first-chunk latency
|
||||
- streamable, voice-cloneable
|
||||
- pip-installable, MIT-style permissive license
|
||||
|
||||
The pocket-tts model is loaded **lazily** the first time ``call(...)``
|
||||
runs (or eagerly via ``preload()``). Loading takes a few seconds and
|
||||
several hundred MB of RAM, so we don't pay the cost when the tool is
|
||||
merely *registered* — only when it's *invoked*.
|
||||
|
||||
Optional dependency. Install with::
|
||||
|
||||
pip install lerobot[tools]
|
||||
# or directly:
|
||||
pip install pocket-tts
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SayTool:
|
||||
"""Speak a short utterance via Kyutai's pocket-tts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema:
|
||||
Optional schema override; defaults to the canonical
|
||||
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
|
||||
argument shapes can pass in a modified schema, but the
|
||||
implementation only reads ``arguments["text"]``.
|
||||
voice:
|
||||
One of the pocket-tts catalog voices (``alba``, ``marius``,
|
||||
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
|
||||
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
|
||||
file for cloning. See the pocket-tts model card for licensing.
|
||||
output_dir:
|
||||
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
|
||||
file there in addition to returning the PCM tensor.
|
||||
``None`` (default) skips disk writes — useful for live
|
||||
playback paths that hand the tensor directly to a sounddevice
|
||||
/ WebAudio sink.
|
||||
"""
|
||||
|
||||
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
|
||||
voice: str = "alba"
|
||||
output_dir: Path | None = None
|
||||
|
||||
name: str = field(init=False, default="say")
|
||||
_model: Any = field(init=False, default=None, repr=False)
|
||||
_voice_state: Any = field(init=False, default=None, repr=False)
|
||||
_sample_rate: int = field(init=False, default=24000, repr=False)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lazy model load
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def preload(self) -> None:
|
||||
"""Load the pocket-tts model + voice state into memory.
|
||||
|
||||
Optional — ``call(...)`` triggers this automatically on first
|
||||
invocation. Useful when you want the multi-second load to
|
||||
happen at startup rather than on the first ``say`` the policy
|
||||
emits.
|
||||
"""
|
||||
if self._model is not None and self._voice_state is not None:
|
||||
return
|
||||
try:
|
||||
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
|
||||
except ImportError as exc: # pragma: no cover (env-dependent)
|
||||
raise ImportError(
|
||||
"SayTool requires pocket-tts. Install with `pip install "
|
||||
"lerobot[tools]` or `pip install pocket-tts`."
|
||||
) from exc
|
||||
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
|
||||
self._model = TTSModel.load_model()
|
||||
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
|
||||
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Speak ``arguments["text"]`` and return the PCM tensor.
|
||||
|
||||
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
|
||||
``self.output_dir`` is set. The returned tensor is a 1-D
|
||||
``torch.Tensor`` of float32 PCM samples at
|
||||
``self.sample_rate`` Hz — directly playable by
|
||||
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
|
||||
encodable by ``scipy.io.wavfile.write``.
|
||||
"""
|
||||
text = arguments.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError(
|
||||
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
|
||||
)
|
||||
self.preload()
|
||||
|
||||
audio = self._model.generate_audio(self._voice_state, text)
|
||||
|
||||
if self.output_dir is not None:
|
||||
self._write_wav(audio, text)
|
||||
|
||||
return audio
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""PCM sample rate of the returned tensor (Hz)."""
|
||||
return self._sample_rate
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _write_wav(self, audio: Any, text: str) -> Path:
|
||||
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
try:
|
||||
import scipy.io.wavfile # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SayTool.output_dir requires scipy. `pip install scipy`."
|
||||
) from exc
|
||||
|
||||
out_dir = Path(self.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# One file per call; suffix with a millisecond timestamp + a
|
||||
# short text snippet so a directory listing is informative.
|
||||
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
|
||||
ts_ms = int(_time.time() * 1000)
|
||||
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
|
||||
|
||||
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
|
||||
# ``.numpy()`` is safe.
|
||||
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
|
||||
return path
|
||||
Reference in New Issue
Block a user