diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index 36fc02dca..2012c895d 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -347,6 +347,11 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): messages, target_indices = self._apply_prompt_dropout( messages, target_indices, sample_idx ) + # Flatten ``tool_calls`` into a textual ``...`` marker + # *before* the chat template sees them, so the model is trained + # to emit the same marker the inference parser + # (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``. + messages = [_flatten_say_tool_calls(m) for m in messages] text_messages = [_strip_lerobot_blocks(m) for m in messages] full_ids = tokenizer.apply_chat_template( @@ -508,6 +513,75 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]: return new +def _content_to_text(content: Any) -> str: + """Collapse a message's ``content`` (string or multimodal blocks) to plain text.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + t = block.get("text") + if isinstance(t, str): + parts.append(t) + return "\n".join(parts) + return "" + + +def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]: + """Serialize assistant ``say`` tool calls into a textual ``...`` + marker inside the message content (Pi 0.5-style flat tool-call + serialization). + + SmolVLM's chat template would otherwise render ``tool_calls`` as a + structured JSON ```` block, so the LM head learns to emit + JSON — but the inference parser ``_split_plan_and_say`` looks for a + ``...`` marker (``_SAY_RE``). 
Rewriting the call into the + content text *before* ``apply_chat_template`` aligns the two: the + template only ever tokenizes plain text, and the supervised target + span trains the model to produce the exact marker the runtime reads. + + Messages without ``say`` tool calls are returned unchanged. + """ + tool_calls = message.get("tool_calls") + if not tool_calls: + return message + + say_texts: list[str] = [] + for call in tool_calls: + if not isinstance(call, dict): + continue + fn = call.get("function") or {} + if fn.get("name") != "say": + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + import json # noqa: PLC0415 + + args = json.loads(args) + except (ValueError, TypeError): + args = {} + text = args.get("text", "") if isinstance(args, dict) else "" + if text: + say_texts.append(str(text)) + + if not say_texts: + # No ``say`` calls (or empty text) — drop the structured calls so + # the template doesn't render a stray JSON block, but leave the + # content alone. + new = dict(message) + new.pop("tool_calls", None) + return new + + new = dict(message) + base = _content_to_text(new.get("content")).strip() + marker = "".join(f"{t}" for t in say_texts) + new["content"] = f"{base}\n{marker}" if base else marker + new.pop("tool_calls", None) + return new + + def _is_batched_messages(messages: Any) -> bool: return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) @@ -572,3 +646,4 @@ def _as_token_ids(value: Any) -> list[int]: # Re-export for tests / introspection strip_lerobot_blocks = _strip_lerobot_blocks +flatten_say_tool_calls = _flatten_say_tool_calls diff --git a/tests/policies/smolvla/test_smolvla2_chat_processor.py b/tests/policies/smolvla/test_smolvla2_chat_processor.py new file mode 100644 index 000000000..26735affe --- /dev/null +++ b/tests/policies/smolvla/test_smolvla2_chat_processor.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening.

``flatten_say_tool_calls`` is the training-time serializer: it rewrites an
assistant turn's structured ``say`` tool call into text *inside the message
content*, before ``apply_chat_template`` runs, so the chat template only
tokenizes plain text. ``_split_plan_and_say`` is the inference-side parser
that reads the say text back. These tests pin the round-trip between the two.
"""

from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls
from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say


def _say_call(text):
    """Build a function-style tool-call dict that invokes ``say`` with ``text``."""
    function = {"name": "say", "arguments": {"text": text}}
    return {"type": "function", "function": function}


def test_flatten_appends_say_marker_and_drops_tool_calls():
    """The say text lands on a new line after the content; ``tool_calls`` is removed."""
    message = {
        "role": "assistant",
        "content": "Pick up the blue cube.",
        "tool_calls": [_say_call("On it!")],
    }
    flattened = flatten_say_tool_calls(message)
    assert "tool_calls" not in flattened
    assert flattened["content"] == "Pick up the blue cube.\nOn it!"
def test_flatten_roundtrips_through_inference_parser():
    """What the serializer writes must be exactly what the inference parser
    reads back — this is the train/inference contract."""
    message = {
        "role": "assistant",
        "content": "Move toward the cube.",
        "tool_calls": [_say_call("Working on it")],
    }
    flat_content = flatten_say_tool_calls(message)["content"]
    plan, speech = _split_plan_and_say(flat_content)
    assert plan == "Move toward the cube."
    assert speech == "Working on it"


def test_flatten_accepts_json_string_arguments():
    """``arguments`` may arrive as a JSON string rather than a dict."""
    json_call = {
        "type": "function",
        "function": {"name": "say", "arguments": '{"text": "hello there"}'},
    }
    message = {"role": "assistant", "content": "p", "tool_calls": [json_call]}
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "p\nhello there"


def test_flatten_leaves_messages_without_tool_calls_untouched():
    """A message with no ``tool_calls`` key comes back equal to the input."""
    message = {"role": "assistant", "content": "just a plan"}
    assert flatten_say_tool_calls(message) == message


def test_flatten_drops_empty_or_non_say_tool_calls():
    """A non-``say`` call (or empty text) leaves content alone but still
    strips the structured calls so the template renders no JSON block."""
    weather_call = {
        "type": "function",
        "function": {"name": "check_weather", "arguments": {}},
    }
    message = {"role": "assistant", "content": "plan only", "tool_calls": [weather_call]}
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "plan only"
    assert "tool_calls" not in flattened


def test_flatten_marker_only_when_content_empty():
    """With empty content the say text becomes the entire content."""
    message = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]}
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "hi"