fix(smolvla2): flatten say tool_calls into <say> marker before tokenizing

The chat tokenizer passed assistant `tool_calls` straight to
`apply_chat_template`, which renders them as a structured JSON
`<tool_call>` block — so the LM head was trained to emit JSON. But the
inference parser `_split_plan_and_say` looks for a `<say>...</say>`
marker, which the model never saw in training, so the `say` tool never
fired at inference.

`_flatten_say_tool_calls` is the missing training-time serializer (the
one `_split_plan_and_say`'s docstring already assumed existed): it
rewrites a `say` tool call into a `<say>...</say>` marker inside the
content text before the chat template runs, so the template only
tokenizes plain text and the supervised target span trains the model to
emit exactly the marker the runtime parses back (Pi 0.5-style flat
tool-call serialization).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 10:47:31 +02:00
parent 5e3b9ba82c
commit bfb8cfb432
2 changed files with 152 additions and 0 deletions
@@ -347,6 +347,11 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
messages, target_indices = self._apply_prompt_dropout( messages, target_indices = self._apply_prompt_dropout(
messages, target_indices, sample_idx messages, target_indices, sample_idx
) )
# Flatten ``tool_calls`` into a textual ``<say>...</say>`` marker
# *before* the chat template sees them, so the model is trained
# to emit the same marker the inference parser
# (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``.
messages = [_flatten_say_tool_calls(m) for m in messages]
text_messages = [_strip_lerobot_blocks(m) for m in messages] text_messages = [_strip_lerobot_blocks(m) for m in messages]
full_ids = tokenizer.apply_chat_template( full_ids = tokenizer.apply_chat_template(
@@ -508,6 +513,75 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
return new return new
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to plain text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
t = block.get("text")
if isinstance(t, str):
parts.append(t)
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a textual ``<say>...</say>``
marker inside the message content (Pi 0.5-style flat tool-call
serialization).
SmolVLM's chat template would otherwise render ``tool_calls`` as a
structured JSON ``<tool_call>`` block, so the LM head learns to emit
JSON — but the inference parser ``_split_plan_and_say`` looks for a
``<say>...</say>`` marker (``_SAY_RE``). Rewriting the call into the
content text *before* ``apply_chat_template`` aligns the two: the
template only ever tokenizes plain text, and the supervised target
span trains the model to produce the exact marker the runtime reads.
Messages without ``say`` tool calls are returned unchanged.
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
if not say_texts:
# No ``say`` calls (or empty text) — drop the structured calls so
# the template doesn't render a stray JSON block, but leave the
# content alone.
new = dict(message)
new.pop("tool_calls", None)
return new
new = dict(message)
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
new.pop("tool_calls", None)
return new
def _is_batched_messages(messages: Any) -> bool: def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
@@ -572,3 +646,4 @@ def _as_token_ids(value: Any) -> list[int]:
# Re-export for tests / introspection # Re-export for tests / introspection
strip_lerobot_blocks = _strip_lerobot_blocks strip_lerobot_blocks = _strip_lerobot_blocks
flatten_say_tool_calls = _flatten_say_tool_calls
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening.
``_split_plan_and_say`` (inference) expects the model to emit a textual
``<say>...</say>`` marker. ``_flatten_say_tool_calls`` is the training-time
serializer that produces it: it rewrites an assistant turn's structured
``say`` tool call into that marker *inside the content text*, before
``apply_chat_template`` runs so the chat template only tokenizes plain
text and the supervised target span trains the model to emit the marker
the runtime parses back. These tests pin the round-trip.
"""
from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls
from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say
def _say_call(text):
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
def test_flatten_appends_say_marker_and_drops_tool_calls():
msg = {"role": "assistant", "content": "Pick up the blue cube.", "tool_calls": [_say_call("On it!")]}
out = flatten_say_tool_calls(msg)
assert "tool_calls" not in out
assert out["content"] == "Pick up the blue cube.\n<say>On it!</say>"
def test_flatten_roundtrips_through_inference_parser():
"""The marker the serializer writes must be exactly what the inference
parser reads back this is the train/inference contract."""
msg = {"role": "assistant", "content": "Move toward the cube.", "tool_calls": [_say_call("Working on it")]}
flat = flatten_say_tool_calls(msg)["content"]
plan, speech = _split_plan_and_say(flat)
assert plan == "Move toward the cube."
assert speech == "Working on it"
def test_flatten_accepts_json_string_arguments():
"""``arguments`` may arrive as a JSON string rather than a dict."""
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
out = flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
assert out["content"] == "p\n<say>hello there</say>"
def test_flatten_leaves_messages_without_tool_calls_untouched():
msg = {"role": "assistant", "content": "just a plan"}
assert flatten_say_tool_calls(msg) == msg
def test_flatten_drops_empty_or_non_say_tool_calls():
"""A non-``say`` call (or empty text) leaves content alone but still
strips the structured calls so the template renders no JSON block."""
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
out = flatten_say_tool_calls({"role": "assistant", "content": "plan only", "tool_calls": [weather]})
assert out["content"] == "plan only"
assert "tool_calls" not in out
def test_flatten_marker_only_when_content_empty():
msg = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]}
out = flatten_say_tool_calls(msg)
assert out["content"] == "<say>hi</say>"