mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
feat(tools): src/lerobot/tools/ — runnable tool registry + SayTool
Ships the runtime side of the OpenAI-style function-calling stack
introduced in PR 1 (catalog in ``meta/info.json["tools"]``) and PR 2
(annotation pipeline writes the catalog after a run). One file per
tool — heavy deps stay isolated.
Layout:
- ``base.py`` — :class:`Tool` Protocol: ``name``, ``schema``,
``call(arguments)``. Runtime-checkable so tests can use
``isinstance(...)``.
- ``registry.py`` — :data:`TOOL_REGISTRY` (name → class) plus
``get_tools(meta, **kwargs)`` that instantiates every entry whose
``function.name`` is registered. Tools whose name is unknown are
silently skipped — the schema still rides through the chat
template, the model just can't actually invoke that tool at
inference.
- ``say.py`` — :class:`SayTool` wrapping Kyutai's pocket-tts
(CPU-only, ~100M params, ~6× real-time on a MacBook Air M4).
Lazy model load: pocket-tts is imported and the voice state
computed on first ``call(...)`` (or eagerly via ``preload()``).
Returns the PCM tensor; optionally writes a ``.wav`` to
``output_dir`` for offline inspection.
- ``__init__.py`` — re-exports the public surface.
Optional install:
pip install lerobot[tools]
The ``[tools]`` extra in ``pyproject.toml`` pulls in ``pocket-tts`` +
``scipy`` (for the wav writer). Adding more tools later means a new
file + a registry entry — no new extras unless the tool brings new
deps.
To add your own tool, follow the three-step guide in
``docs/source/tools.mdx`` (PR 1):
1. Drop ``src/lerobot/tools/<my_tool>.py`` with a ``Tool``-conforming
class.
2. Register the class in ``TOOL_REGISTRY`` (this file).
3. Pre-populate ``meta/info.json["tools"]`` with the schema (or let
``lerobot-annotate`` add it on the next run).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -209,6 +209,14 @@ annotations = [
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=0.1.0,<1.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""LeRobot tool implementations.
|
||||
|
||||
Storage of the tool catalog (``meta/info.json["tools"]``) and the
|
||||
``SAY_TOOL_SCHEMA`` constant live in PR 1
|
||||
(``lerobot.datasets.language``). This package holds the *runnable*
|
||||
implementations one file per tool, plus the registry that maps tool
|
||||
names to classes.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
|
||||
from .base import Tool
|
||||
from .registry import TOOL_REGISTRY, get_tools
|
||||
from .say import SayTool
|
||||
|
||||
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool protocol — the contract every runnable tool implementation honors.
|
||||
|
||||
Tools are the executable side of the OpenAI-style function-calling
|
||||
abstraction the v3.1 language schema (PR 1) carries on assistant
|
||||
messages: the schema describes *what can be called*, the tool
|
||||
implementation describes *how to call it*.
|
||||
|
||||
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
|
||||
``say.py`` for ``SayTool``) and are registered in
|
||||
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
|
||||
heavy dependencies (torch models, audio backends, network clients,
|
||||
hardware drivers) only load when the dataset actually declares the tool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
"""Minimum surface every tool must expose."""
|
||||
|
||||
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
|
||||
#: routes incoming ``tool_calls`` to the implementation by this key.
|
||||
name: str
|
||||
|
||||
#: OpenAI-style function-call schema. Same dict the dataset stores in
|
||||
#: ``meta/info.json["tools"]`` and the chat template renders into the
|
||||
#: prompt.
|
||||
schema: dict[str, Any]
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Execute the tool with the model-provided arguments.
|
||||
|
||||
``arguments`` is the parsed dict from
|
||||
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
|
||||
when the model emits a JSON-string by the chat-template
|
||||
convention). Implementations validate the dict against their own
|
||||
schema; the runtime only routes by name.
|
||||
|
||||
Return value is implementation-defined — typically a tensor
|
||||
(TTS audio), a Path (saved file), a dict (structured result), or
|
||||
``None`` (side-effect-only call).
|
||||
"""
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool registry — name → implementation class.
|
||||
|
||||
Adding a new tool:
|
||||
|
||||
1. Drop a file under ``src/lerobot/tools/`` that defines a class
|
||||
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
|
||||
``schema``, ``call(arguments)``).
|
||||
2. Register the class here under :data:`TOOL_REGISTRY`.
|
||||
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
|
||||
to advertise the schema to the chat-template + policy. The PR 2
|
||||
annotation pipeline preserves anything you put there.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the full authoring guide.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import Tool
|
||||
from .say import SayTool
|
||||
|
||||
#: Map from ``function.name`` to a class implementing :class:`Tool`.
|
||||
#: The runtime instantiates entries lazily — registering a tool here is
|
||||
#: essentially free (no model load happens until ``call`` runs).
|
||||
TOOL_REGISTRY: dict[str, type] = {
|
||||
"say": SayTool,
|
||||
}
|
||||
|
||||
|
||||
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
|
||||
"""Build name → tool-instance dict from a dataset's declared catalog.
|
||||
|
||||
``meta`` is anything with a ``.tools`` attribute returning the
|
||||
OpenAI-style schema list — typically a
|
||||
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
|
||||
Each entry whose ``function.name`` is registered here is
|
||||
instantiated with the schema dict; tools whose name is unknown to
|
||||
the registry are skipped (the schema still rides through the chat
|
||||
template, the model just can't actually invoke that tool at
|
||||
inference).
|
||||
|
||||
Extra keyword arguments are forwarded to every constructor — useful
|
||||
for runtime defaults like ``output_dir=Path("./tts_log")``.
|
||||
"""
|
||||
declared = list(meta.tools)
|
||||
instances: dict[str, Tool] = {}
|
||||
for schema in declared:
|
||||
try:
|
||||
name = schema["function"]["name"]
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
cls = TOOL_REGISTRY.get(name)
|
||||
if cls is None:
|
||||
continue
|
||||
instances[name] = cls(schema=schema, **kwargs)
|
||||
return instances
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
|
||||
|
||||
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
|
||||
runtime dispatchers consume this when the model emits an assistant
|
||||
message with ``tool_calls=[{function: {name: "say", arguments:
|
||||
{text: ...}}}]``.
|
||||
|
||||
Why pocket-tts:
|
||||
|
||||
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
|
||||
- ~100M parameters, ~200ms first-chunk latency
|
||||
- streamable, voice-cloneable
|
||||
- pip-installable, MIT-style permissive license
|
||||
|
||||
The pocket-tts model is loaded **lazily** the first time ``call(...)``
|
||||
runs (or eagerly via ``preload()``). Loading takes a few seconds and
|
||||
several hundred MB of RAM, so we don't pay the cost when the tool is
|
||||
merely *registered* — only when it's *invoked*.
|
||||
|
||||
Optional dependency. Install with::
|
||||
|
||||
pip install lerobot[tools]
|
||||
# or directly:
|
||||
pip install pocket-tts
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SayTool:
|
||||
"""Speak a short utterance via Kyutai's pocket-tts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema:
|
||||
Optional schema override; defaults to the canonical
|
||||
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
|
||||
argument shapes can pass in a modified schema, but the
|
||||
implementation only reads ``arguments["text"]``.
|
||||
voice:
|
||||
One of the pocket-tts catalog voices (``alba``, ``marius``,
|
||||
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
|
||||
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
|
||||
file for cloning. See the pocket-tts model card for licensing.
|
||||
output_dir:
|
||||
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
|
||||
file there in addition to returning the PCM tensor.
|
||||
``None`` (default) skips disk writes — useful for live
|
||||
playback paths that hand the tensor directly to a sounddevice
|
||||
/ WebAudio sink.
|
||||
"""
|
||||
|
||||
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
|
||||
voice: str = "alba"
|
||||
output_dir: Path | None = None
|
||||
|
||||
name: str = field(init=False, default="say")
|
||||
_model: Any = field(init=False, default=None, repr=False)
|
||||
_voice_state: Any = field(init=False, default=None, repr=False)
|
||||
_sample_rate: int = field(init=False, default=24000, repr=False)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lazy model load
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def preload(self) -> None:
|
||||
"""Load the pocket-tts model + voice state into memory.
|
||||
|
||||
Optional — ``call(...)`` triggers this automatically on first
|
||||
invocation. Useful when you want the multi-second load to
|
||||
happen at startup rather than on the first ``say`` the policy
|
||||
emits.
|
||||
"""
|
||||
if self._model is not None and self._voice_state is not None:
|
||||
return
|
||||
try:
|
||||
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
|
||||
except ImportError as exc: # pragma: no cover (env-dependent)
|
||||
raise ImportError(
|
||||
"SayTool requires pocket-tts. Install with `pip install "
|
||||
"lerobot[tools]` or `pip install pocket-tts`."
|
||||
) from exc
|
||||
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
|
||||
self._model = TTSModel.load_model()
|
||||
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
|
||||
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Speak ``arguments["text"]`` and return the PCM tensor.
|
||||
|
||||
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
|
||||
``self.output_dir`` is set. The returned tensor is a 1-D
|
||||
``torch.Tensor`` of float32 PCM samples at
|
||||
``self.sample_rate`` Hz — directly playable by
|
||||
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
|
||||
encodable by ``scipy.io.wavfile.write``.
|
||||
"""
|
||||
text = arguments.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError(
|
||||
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
|
||||
)
|
||||
self.preload()
|
||||
|
||||
audio = self._model.generate_audio(self._voice_state, text)
|
||||
|
||||
if self.output_dir is not None:
|
||||
self._write_wav(audio, text)
|
||||
|
||||
return audio
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""PCM sample rate of the returned tensor (Hz)."""
|
||||
return self._sample_rate
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _write_wav(self, audio: Any, text: str) -> Path:
|
||||
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
try:
|
||||
import scipy.io.wavfile # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SayTool.output_dir requires scipy. `pip install scipy`."
|
||||
) from exc
|
||||
|
||||
out_dir = Path(self.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# One file per call; suffix with a millisecond timestamp + a
|
||||
# short text snippet so a directory listing is informative.
|
||||
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
|
||||
ts_ms = int(_time.time() * 1000)
|
||||
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
|
||||
|
||||
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
|
||||
# ``.numpy()`` is safe.
|
||||
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
|
||||
return path
|
||||
Reference in New Issue
Block a user