mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(tools): src/lerobot/tools/ — runnable tool registry + SayTool
Ships the runtime side of the OpenAI-style function-calling stack
introduced in PR 1 (catalog in ``meta/info.json["tools"]``) and PR 2
(annotation pipeline writes the catalog after a run). One file per
tool — heavy deps stay isolated.
Layout:
- ``base.py`` — :class:`Tool` Protocol: ``name``, ``schema``,
``call(arguments)``. Runtime-checkable so tests can use
``isinstance(...)``.
- ``registry.py`` — :data:`TOOL_REGISTRY` (name → class) plus
``get_tools(meta, **kwargs)`` that instantiates every entry whose
``function.name`` is registered. Tools whose name is unknown are
silently skipped — the schema still rides through the chat
template, the model just can't actually invoke that tool at
inference.
- ``say.py`` — :class:`SayTool` wrapping Kyutai's pocket-tts
(CPU-only, ~100M params, ~6× real-time on a MacBook Air M4).
Lazy model load: pocket-tts is imported and the voice state
computed on first ``call(...)`` (or eagerly via ``preload()``).
Returns the PCM tensor; optionally writes a ``.wav`` to
``output_dir`` for offline inspection.
- ``__init__.py`` — re-exports the public surface.
Optional install:
pip install lerobot[tools]
The ``[tools]`` extra in ``pyproject.toml`` pulls in ``pocket-tts`` +
``scipy`` (for the wav writer). Adding more tools later means a new
file + a registry entry — no new extras unless the tool brings new
deps.
To add your own tool, follow the three-step guide in
``docs/source/tools.mdx`` (PR 1):
1. Drop ``src/lerobot/tools/<my_tool>.py`` with a ``Tool``-conforming
class.
2. Register the class in ``TOOL_REGISTRY`` (this file).
3. Pre-populate ``meta/info.json["tools"]`` with the schema (or let
``lerobot-annotate`` add it on the next run).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -209,6 +209,14 @@ annotations = [
|
|||||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||||
|
# are isolated so adding a new tool doesn't bloat the base install.
|
||||||
|
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||||
|
tools = [
|
||||||
|
"pocket-tts>=0.1.0,<1.0.0",
|
||||||
|
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||||
|
]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""LeRobot tool implementations.
|
||||||
|
|
||||||
|
Storage of the tool catalog (``meta/info.json["tools"]``) and the
|
||||||
|
``SAY_TOOL_SCHEMA`` constant live in PR 1
|
||||||
|
(``lerobot.datasets.language``). This package holds the *runnable*
|
||||||
|
implementations one file per tool, plus the registry that maps tool
|
||||||
|
names to classes.
|
||||||
|
|
||||||
|
See ``docs/source/tools.mdx`` for the authoring guide.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
from .registry import TOOL_REGISTRY, get_tools
|
||||||
|
from .say import SayTool
|
||||||
|
|
||||||
|
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tool protocol — the contract every runnable tool implementation honors.
|
||||||
|
|
||||||
|
Tools are the executable side of the OpenAI-style function-calling
|
||||||
|
abstraction the v3.1 language schema (PR 1) carries on assistant
|
||||||
|
messages: the schema describes *what can be called*, the tool
|
||||||
|
implementation describes *how to call it*.
|
||||||
|
|
||||||
|
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
|
||||||
|
``say.py`` for ``SayTool``) and are registered in
|
||||||
|
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
|
||||||
|
heavy dependencies (torch models, audio backends, network clients,
|
||||||
|
hardware drivers) only load when the dataset actually declares the tool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Tool(Protocol):
|
||||||
|
"""Minimum surface every tool must expose."""
|
||||||
|
|
||||||
|
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
|
||||||
|
#: routes incoming ``tool_calls`` to the implementation by this key.
|
||||||
|
name: str
|
||||||
|
|
||||||
|
#: OpenAI-style function-call schema. Same dict the dataset stores in
|
||||||
|
#: ``meta/info.json["tools"]`` and the chat template renders into the
|
||||||
|
#: prompt.
|
||||||
|
schema: dict[str, Any]
|
||||||
|
|
||||||
|
def call(self, arguments: dict[str, Any]) -> Any:
|
||||||
|
"""Execute the tool with the model-provided arguments.
|
||||||
|
|
||||||
|
``arguments`` is the parsed dict from
|
||||||
|
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
|
||||||
|
when the model emits a JSON-string by the chat-template
|
||||||
|
convention). Implementations validate the dict against their own
|
||||||
|
schema; the runtime only routes by name.
|
||||||
|
|
||||||
|
Return value is implementation-defined — typically a tensor
|
||||||
|
(TTS audio), a Path (saved file), a dict (structured result), or
|
||||||
|
``None`` (side-effect-only call).
|
||||||
|
"""
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tool registry — name → implementation class.
|
||||||
|
|
||||||
|
Adding a new tool:
|
||||||
|
|
||||||
|
1. Drop a file under ``src/lerobot/tools/`` that defines a class
|
||||||
|
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
|
||||||
|
``schema``, ``call(arguments)``).
|
||||||
|
2. Register the class here under :data:`TOOL_REGISTRY`.
|
||||||
|
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
|
||||||
|
to advertise the schema to the chat-template + policy. The PR 2
|
||||||
|
annotation pipeline preserves anything you put there.
|
||||||
|
|
||||||
|
See ``docs/source/tools.mdx`` for the full authoring guide.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
from .say import SayTool
|
||||||
|
|
||||||
|
#: Map from ``function.name`` to a class implementing :class:`Tool`.
|
||||||
|
#: The runtime instantiates entries lazily — registering a tool here is
|
||||||
|
#: essentially free (no model load happens until ``call`` runs).
|
||||||
|
TOOL_REGISTRY: dict[str, type] = {
|
||||||
|
"say": SayTool,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
|
||||||
|
"""Build name → tool-instance dict from a dataset's declared catalog.
|
||||||
|
|
||||||
|
``meta`` is anything with a ``.tools`` attribute returning the
|
||||||
|
OpenAI-style schema list — typically a
|
||||||
|
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
|
||||||
|
Each entry whose ``function.name`` is registered here is
|
||||||
|
instantiated with the schema dict; tools whose name is unknown to
|
||||||
|
the registry are skipped (the schema still rides through the chat
|
||||||
|
template, the model just can't actually invoke that tool at
|
||||||
|
inference).
|
||||||
|
|
||||||
|
Extra keyword arguments are forwarded to every constructor — useful
|
||||||
|
for runtime defaults like ``output_dir=Path("./tts_log")``.
|
||||||
|
"""
|
||||||
|
declared = list(meta.tools)
|
||||||
|
instances: dict[str, Tool] = {}
|
||||||
|
for schema in declared:
|
||||||
|
try:
|
||||||
|
name = schema["function"]["name"]
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
continue
|
||||||
|
cls = TOOL_REGISTRY.get(name)
|
||||||
|
if cls is None:
|
||||||
|
continue
|
||||||
|
instances[name] = cls(schema=schema, **kwargs)
|
||||||
|
return instances
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
|
||||||
|
|
||||||
|
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
|
||||||
|
runtime dispatchers consume this when the model emits an assistant
|
||||||
|
message with ``tool_calls=[{function: {name: "say", arguments:
|
||||||
|
{text: ...}}}]``.
|
||||||
|
|
||||||
|
Why pocket-tts:
|
||||||
|
|
||||||
|
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
|
||||||
|
- ~100M parameters, ~200ms first-chunk latency
|
||||||
|
- streamable, voice-cloneable
|
||||||
|
- pip-installable, MIT-style permissive license
|
||||||
|
|
||||||
|
The pocket-tts model is loaded **lazily** the first time ``call(...)``
|
||||||
|
runs (or eagerly via ``preload()``). Loading takes a few seconds and
|
||||||
|
several hundred MB of RAM, so we don't pay the cost when the tool is
|
||||||
|
merely *registered* — only when it's *invoked*.
|
||||||
|
|
||||||
|
Optional dependency. Install with::
|
||||||
|
|
||||||
|
pip install lerobot[tools]
|
||||||
|
# or directly:
|
||||||
|
pip install pocket-tts
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.datasets.language import SAY_TOOL_SCHEMA
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SayTool:
|
||||||
|
"""Speak a short utterance via Kyutai's pocket-tts.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
schema:
|
||||||
|
Optional schema override; defaults to the canonical
|
||||||
|
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
|
||||||
|
argument shapes can pass in a modified schema, but the
|
||||||
|
implementation only reads ``arguments["text"]``.
|
||||||
|
voice:
|
||||||
|
One of the pocket-tts catalog voices (``alba``, ``marius``,
|
||||||
|
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
|
||||||
|
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
|
||||||
|
file for cloning. See the pocket-tts model card for licensing.
|
||||||
|
output_dir:
|
||||||
|
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
|
||||||
|
file there in addition to returning the PCM tensor.
|
||||||
|
``None`` (default) skips disk writes — useful for live
|
||||||
|
playback paths that hand the tensor directly to a sounddevice
|
||||||
|
/ WebAudio sink.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
|
||||||
|
voice: str = "alba"
|
||||||
|
output_dir: Path | None = None
|
||||||
|
|
||||||
|
name: str = field(init=False, default="say")
|
||||||
|
_model: Any = field(init=False, default=None, repr=False)
|
||||||
|
_voice_state: Any = field(init=False, default=None, repr=False)
|
||||||
|
_sample_rate: int = field(init=False, default=24000, repr=False)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Lazy model load
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def preload(self) -> None:
|
||||||
|
"""Load the pocket-tts model + voice state into memory.
|
||||||
|
|
||||||
|
Optional — ``call(...)`` triggers this automatically on first
|
||||||
|
invocation. Useful when you want the multi-second load to
|
||||||
|
happen at startup rather than on the first ``say`` the policy
|
||||||
|
emits.
|
||||||
|
"""
|
||||||
|
if self._model is not None and self._voice_state is not None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
|
||||||
|
except ImportError as exc: # pragma: no cover (env-dependent)
|
||||||
|
raise ImportError(
|
||||||
|
"SayTool requires pocket-tts. Install with `pip install "
|
||||||
|
"lerobot[tools]` or `pip install pocket-tts`."
|
||||||
|
) from exc
|
||||||
|
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
|
||||||
|
self._model = TTSModel.load_model()
|
||||||
|
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
|
||||||
|
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tool protocol
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def call(self, arguments: dict[str, Any]) -> Any:
|
||||||
|
"""Speak ``arguments["text"]`` and return the PCM tensor.
|
||||||
|
|
||||||
|
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
|
||||||
|
``self.output_dir`` is set. The returned tensor is a 1-D
|
||||||
|
``torch.Tensor`` of float32 PCM samples at
|
||||||
|
``self.sample_rate`` Hz — directly playable by
|
||||||
|
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
|
||||||
|
encodable by ``scipy.io.wavfile.write``.
|
||||||
|
"""
|
||||||
|
text = arguments.get("text")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
raise ValueError(
|
||||||
|
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
|
||||||
|
)
|
||||||
|
self.preload()
|
||||||
|
|
||||||
|
audio = self._model.generate_audio(self._voice_state, text)
|
||||||
|
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self._write_wav(audio, text)
|
||||||
|
|
||||||
|
return audio
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sample_rate(self) -> int:
|
||||||
|
"""PCM sample rate of the returned tensor (Hz)."""
|
||||||
|
return self._sample_rate
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _write_wav(self, audio: Any, text: str) -> Path:
|
||||||
|
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
|
||||||
|
import time as _time # noqa: PLC0415
|
||||||
|
|
||||||
|
try:
|
||||||
|
import scipy.io.wavfile # noqa: PLC0415
|
||||||
|
except ImportError as exc: # pragma: no cover
|
||||||
|
raise ImportError(
|
||||||
|
"SayTool.output_dir requires scipy. `pip install scipy`."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
out_dir = Path(self.output_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# One file per call; suffix with a millisecond timestamp + a
|
||||||
|
# short text snippet so a directory listing is informative.
|
||||||
|
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
|
||||||
|
ts_ms = int(_time.time() * 1000)
|
||||||
|
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
|
||||||
|
|
||||||
|
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
|
||||||
|
# ``.numpy()`` is safe.
|
||||||
|
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
|
||||||
|
return path
|
||||||
Reference in New Issue
Block a user