From bfb8cfb4322a2d7842d4152918eedc8e59d41841 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 18 May 2026 10:47:31 +0200 Subject: [PATCH] fix(smolvla2): flatten say tool_calls into marker before tokenizing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The chat tokenizer passed assistant `tool_calls` straight to `apply_chat_template`, which renders them as a structured JSON `` block — so the LM head was trained to emit JSON. But the inference parser `_split_plan_and_say` looks for a `...` marker, which the model never saw in training, so the `say` tool never fired at inference. `_flatten_say_tool_calls` is the missing training-time serializer (the one `_split_plan_and_say`'s docstring already assumed existed): it rewrites a `say` tool call into a `...` marker inside the content text before the chat template runs, so the template only tokenizes plain text and the supervised target span trains the model to emit exactly the marker the runtime parses back (Pi 0.5-style flat tool-call serialization). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../smolvla2/chat_processor_smolvla2.py | 75 ++++++++++++++++++ .../smolvla/test_smolvla2_chat_processor.py | 77 +++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 tests/policies/smolvla/test_smolvla2_chat_processor.py diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index 36fc02dca..2012c895d 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -347,6 +347,11 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): messages, target_indices = self._apply_prompt_dropout( messages, target_indices, sample_idx ) + # Flatten ``tool_calls`` into a textual ``...`` marker + # *before* the chat template sees them, so the model is trained + # to emit the same marker the inference parser + # (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``. + messages = [_flatten_say_tool_calls(m) for m in messages] text_messages = [_strip_lerobot_blocks(m) for m in messages] full_ids = tokenizer.apply_chat_template( @@ -508,6 +513,75 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]: return new +def _content_to_text(content: Any) -> str: + """Collapse a message's ``content`` (string or multimodal blocks) to plain text.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + t = block.get("text") + if isinstance(t, str): + parts.append(t) + return "\n".join(parts) + return "" + + +def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]: + """Serialize assistant ``say`` tool calls into a textual ``...`` + marker inside the message content (Pi 0.5-style flat tool-call + serialization). + + SmolVLM's chat template would otherwise render ``tool_calls`` as a + structured JSON ```` block, so the LM head learns to emit + JSON — but the inference parser ``_split_plan_and_say`` looks for a + ``...`` marker (``_SAY_RE``). Rewriting the call into the + content text *before* ``apply_chat_template`` aligns the two: the + template only ever tokenizes plain text, and the supervised target + span trains the model to produce the exact marker the runtime reads. + + Messages without ``say`` tool calls are returned unchanged. + """ + tool_calls = message.get("tool_calls") + if not tool_calls: + return message + + say_texts: list[str] = [] + for call in tool_calls: + if not isinstance(call, dict): + continue + fn = call.get("function") or {} + if fn.get("name") != "say": + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + import json # noqa: PLC0415 + + args = json.loads(args) + except (ValueError, TypeError): + args = {} + text = args.get("text", "") if isinstance(args, dict) else "" + if text: + say_texts.append(str(text)) + + if not say_texts: + # No ``say`` calls (or empty text) — drop the structured calls so + # the template doesn't render a stray JSON block, but leave the + # content alone. + new = dict(message) + new.pop("tool_calls", None) + return new + + new = dict(message) + base = _content_to_text(new.get("content")).strip() + marker = "".join(f"{t}" for t in say_texts) + new["content"] = f"{base}\n{marker}" if base else marker + new.pop("tool_calls", None) + return new + + def _is_batched_messages(messages: Any) -> bool: return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) @@ -572,3 +646,4 @@ def _as_token_ids(value: Any) -> list[int]: # Re-export for tests / introspection strip_lerobot_blocks = _strip_lerobot_blocks +flatten_say_tool_calls = _flatten_say_tool_calls diff --git a/tests/policies/smolvla/test_smolvla2_chat_processor.py b/tests/policies/smolvla/test_smolvla2_chat_processor.py new file mode 100644 index 000000000..26735affe --- /dev/null +++ b/tests/policies/smolvla/test_smolvla2_chat_processor.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening. + +``_split_plan_and_say`` (inference) expects the model to emit a textual +``...`` marker. ``_flatten_say_tool_calls`` is the training-time +serializer that produces it: it rewrites an assistant turn's structured +``say`` tool call into that marker *inside the content text*, before +``apply_chat_template`` runs — so the chat template only tokenizes plain +text and the supervised target span trains the model to emit the marker +the runtime parses back. These tests pin the round-trip. +""" + +from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls +from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say + + +def _say_call(text): + return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}} + + +def test_flatten_appends_say_marker_and_drops_tool_calls(): + msg = {"role": "assistant", "content": "Pick up the blue cube.", "tool_calls": [_say_call("On it!")]} + out = flatten_say_tool_calls(msg) + assert "tool_calls" not in out + assert out["content"] == "Pick up the blue cube.\nOn it!" + + +def test_flatten_roundtrips_through_inference_parser(): + """The marker the serializer writes must be exactly what the inference + parser reads back — this is the train/inference contract.""" + msg = {"role": "assistant", "content": "Move toward the cube.", "tool_calls": [_say_call("Working on it")]} + flat = flatten_say_tool_calls(msg)["content"] + plan, speech = _split_plan_and_say(flat) + assert plan == "Move toward the cube." + assert speech == "Working on it" + + +def test_flatten_accepts_json_string_arguments(): + """``arguments`` may arrive as a JSON string rather than a dict.""" + call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}} + out = flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]}) + assert out["content"] == "p\nhello there" + + +def test_flatten_leaves_messages_without_tool_calls_untouched(): + msg = {"role": "assistant", "content": "just a plan"} + assert flatten_say_tool_calls(msg) == msg + + +def test_flatten_drops_empty_or_non_say_tool_calls(): + """A non-``say`` call (or empty text) leaves content alone but still + strips the structured calls so the template renders no JSON block.""" + weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}} + out = flatten_say_tool_calls({"role": "assistant", "content": "plan only", "tool_calls": [weather]}) + assert out["content"] == "plan only" + assert "tool_calls" not in out + + +def test_flatten_marker_only_when_content_empty(): + msg = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]} + out = flatten_say_tool_calls(msg) + assert out["content"] == "hi"