diff --git a/pyproject.toml b/pyproject.toml index 3ca7113fe..b77abb367 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,6 +209,14 @@ annotations = [ "vllm>=0.6.0,<1.0.0; sys_platform == 'linux'", ] +# Tool implementations under src/lerobot/tools/. Each tool's dependencies +# are isolated so adding a new tool doesn't bloat the base install. +# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params). +tools = [ + "pocket-tts>=0.1.0,<1.0.0", + "scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile +] + # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"] diff --git a/src/lerobot/tools/__init__.py b/src/lerobot/tools/__init__.py new file mode 100644 index 000000000..ebd4524a7 --- /dev/null +++ b/src/lerobot/tools/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LeRobot tool implementations. + +Storage of the tool catalog (``meta/info.json["tools"]``) and the +``SAY_TOOL_SCHEMA`` constant live in PR 1 +(``lerobot.datasets.language``). This package holds the *runnable* +implementations one file per tool, plus the registry that maps tool +names to classes. + +See ``docs/source/tools.mdx`` for the authoring guide. 
+""" + +from .base import Tool +from .registry import TOOL_REGISTRY, get_tools +from .say import SayTool + +__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"] diff --git a/src/lerobot/tools/base.py b/src/lerobot/tools/base.py new file mode 100644 index 000000000..2c6cc5295 --- /dev/null +++ b/src/lerobot/tools/base.py @@ -0,0 +1,58 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tool protocol — the contract every runnable tool implementation honors. + +Tools are the executable side of the OpenAI-style function-calling +abstraction the v3.1 language schema (PR 1) carries on assistant +messages: the schema describes *what can be called*, the tool +implementation describes *how to call it*. + +Implementations live one-per-file under :mod:`lerobot.tools` (e.g. +``say.py`` for ``SayTool``) and are registered in +:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so +heavy dependencies (torch models, audio backends, network clients, +hardware drivers) only load when the dataset actually declares the tool. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class Tool(Protocol): + """Minimum surface every tool must expose.""" + + #: Name matching ``schema["function"]["name"]``. The runtime dispatcher + #: routes incoming ``tool_calls`` to the implementation by this key. 
+ name: str + + #: OpenAI-style function-call schema. Same dict the dataset stores in + #: ``meta/info.json["tools"]`` and the chat template renders into the + #: prompt. + schema: dict[str, Any] + + def call(self, arguments: dict[str, Any]) -> Any: + """Execute the tool with the model-provided arguments. + + ``arguments`` is the parsed dict from + ``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded + when the model emits a JSON-string by the chat-template + convention). Implementations validate the dict against their own + schema; the runtime only routes by name. + + Return value is implementation-defined — typically a tensor + (TTS audio), a Path (saved file), a dict (structured result), or + ``None`` (side-effect-only call). + """ diff --git a/src/lerobot/tools/registry.py b/src/lerobot/tools/registry.py new file mode 100644 index 000000000..7908d328c --- /dev/null +++ b/src/lerobot/tools/registry.py @@ -0,0 +1,70 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tool registry — name → implementation class. + +Adding a new tool: + +1. Drop a file under ``src/lerobot/tools/`` that defines a class + conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``, + ``schema``, ``call(arguments)``). +2. Register the class here under :data:`TOOL_REGISTRY`. +3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset + to advertise the schema to the chat-template + policy. 
The PR 2 + annotation pipeline preserves anything you put there. + +See ``docs/source/tools.mdx`` for the full authoring guide. +""" + +from __future__ import annotations + +from typing import Any + +from .base import Tool +from .say import SayTool + +#: Map from ``function.name`` to a class implementing :class:`Tool`. +#: The runtime instantiates entries lazily — registering a tool here is +#: essentially free (no model load happens until ``call`` runs). +TOOL_REGISTRY: dict[str, type] = { + "say": SayTool, +} + + +def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]: + """Build name → tool-instance dict from a dataset's declared catalog. + + ``meta`` is anything with a ``.tools`` attribute returning the + OpenAI-style schema list — typically a + :class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`. + Each entry whose ``function.name`` is registered here is + instantiated with the schema dict; tools whose name is unknown to + the registry are skipped (the schema still rides through the chat + template, the model just can't actually invoke that tool at + inference). + + Extra keyword arguments are forwarded to every constructor — useful + for runtime defaults like ``output_dir=Path("./tts_log")``. + """ + declared = list(meta.tools) + instances: dict[str, Tool] = {} + for schema in declared: + try: + name = schema["function"]["name"] + except (KeyError, TypeError): + continue + cls = TOOL_REGISTRY.get(name) + if cls is None: + continue + instances[name] = cls(schema=schema, **kwargs) + return instances diff --git a/src/lerobot/tools/say.py b/src/lerobot/tools/say.py new file mode 100644 index 000000000..a5f2c5f89 --- /dev/null +++ b/src/lerobot/tools/say.py @@ -0,0 +1,170 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts. + +The first concrete tool implementation. SmolVLA2 (PR 3) and downstream +runtime dispatchers consume this when the model emits an assistant +message with ``tool_calls=[{function: {name: "say", arguments: +{text: ...}}}]``. + +Why pocket-tts: + +- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4 +- ~100M parameters, ~200ms first-chunk latency +- streamable, voice-cloneable +- pip-installable, MIT-style permissive license + +The pocket-tts model is loaded **lazily** the first time ``call(...)`` +runs (or eagerly via ``preload()``). Loading takes a few seconds and +several hundred MB of RAM, so we don't pay the cost when the tool is +merely *registered* — only when it's *invoked*. + +Optional dependency. Install with:: + + pip install lerobot[tools] + # or directly: + pip install pocket-tts +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from lerobot.datasets.language import SAY_TOOL_SCHEMA + +logger = logging.getLogger(__name__) + + +@dataclass +class SayTool: + """Speak a short utterance via Kyutai's pocket-tts. + + Parameters + ---------- + schema: + Optional schema override; defaults to the canonical + ``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended + argument shapes can pass in a modified schema, but the + implementation only reads ``arguments["text"]``. 
+ voice: + One of the pocket-tts catalog voices (``alba``, ``marius``, + ``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``, + ``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice + file for cloning. See the pocket-tts model card for licensing. + output_dir: + If set, every ``call(...)`` writes a ``.wav`` audio + file there in addition to returning the PCM tensor. + ``None`` (default) skips disk writes — useful for live + playback paths that hand the tensor directly to a sounddevice + / WebAudio sink. + """ + + schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA)) + voice: str = "alba" + output_dir: Path | None = None + + name: str = field(init=False, default="say") + _model: Any = field(init=False, default=None, repr=False) + _voice_state: Any = field(init=False, default=None, repr=False) + _sample_rate: int = field(init=False, default=24000, repr=False) + + # ------------------------------------------------------------------ + # Lazy model load + # ------------------------------------------------------------------ + + def preload(self) -> None: + """Load the pocket-tts model + voice state into memory. + + Optional — ``call(...)`` triggers this automatically on first + invocation. Useful when you want the multi-second load to + happen at startup rather than on the first ``say`` the policy + emits. + """ + if self._model is not None and self._voice_state is not None: + return + try: + from pocket_tts import TTSModel # noqa: PLC0415 (optional dep) + except ImportError as exc: # pragma: no cover (env-dependent) + raise ImportError( + "SayTool requires pocket-tts. Install with `pip install " + "lerobot[tools]` or `pip install pocket-tts`." 
+            ) from exc +        logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice) +        self._model = TTSModel.load_model() +        self._voice_state = self._model.get_state_for_audio_prompt(self.voice) +        self._sample_rate = int(getattr(self._model, "sample_rate", 24000)) + +    # ------------------------------------------------------------------ +    # Tool protocol +    # ------------------------------------------------------------------ + +    def call(self, arguments: dict[str, Any]) -> Any: +        """Speak ``arguments["text"]`` and return the PCM tensor. + +        Optionally also writes ``output_dir/say_<ts_ms>_<snippet>.wav`` when +        ``self.output_dir`` is set. The returned tensor is a 1-D +        ``torch.Tensor`` of float32 PCM samples at +        ``self.sample_rate`` Hz — directly playable by +        ``sounddevice.play(audio.numpy(), self.sample_rate)`` or +        encodable by ``scipy.io.wavfile.write``. +        """ +        text = arguments.get("text") +        if not isinstance(text, str) or not text.strip(): +            raise ValueError( +                f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}" +            ) +        self.preload() + +        audio = self._model.generate_audio(self._voice_state, text) + +        if self.output_dir is not None: +            self._write_wav(audio, text) + +        return audio + +    @property +    def sample_rate(self) -> int: +        """PCM sample rate of the returned tensor (Hz).""" +        return self._sample_rate + +    # ------------------------------------------------------------------ +    # Helpers +    # ------------------------------------------------------------------ + +    def _write_wav(self, audio: Any, text: str) -> Path: +        """Write a ``.wav`` under ``output_dir`` for offline inspection.""" +        import time as _time  # noqa: PLC0415 + +        try: +            import scipy.io.wavfile  # noqa: PLC0415 +        except ImportError as exc:  # pragma: no cover +            raise ImportError( +                "SayTool.output_dir requires scipy. `pip install scipy`." 
+ ) from exc + + out_dir = Path(self.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + # One file per call; suffix with a millisecond timestamp + a + # short text snippet so a directory listing is informative. + snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_") + ts_ms = int(_time.time() * 1000) + path = out_dir / f"say_{ts_ms}_{snippet}.wav" + + # ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain + # ``.numpy()`` is safe. + scipy.io.wavfile.write(path, self.sample_rate, audio.numpy()) + return path