lerobot/tests/policies/pi052/test_pi052_text_processor.py
Pepijn 426d48dbbf fix(pi052): port the smolvla2 text-head fixes to pi052
pi052 had the same text-CE collapse bug as smolvla2: PaliGemma's
embed_prefix flags the language block with att=0, so make_att_2d_masks
makes it fully bidirectional. Each position can then attend to the token
it is supposed to predict, and the text cross-entropy degenerates into a
copy task. Ported the three model-specific fixes:

- _mark_target_span_causal: set att=1 on supervised target language
  positions so the text-CE is genuine causal next-token prediction.
  Applied in both _compute_all_losses_fused and _compute_text_and_fast_loss.
- flow_loss_weight 10.0 -> 5.0: the paper's a=10 swamps the LM head once
  the flow-only low_level recipe fires often; 5.0 matches SmolVLA2Config.
- _flatten_say_tool_calls in the text tokenizer: serialize `say` tool
  calls into a <say>...</say> marker so the spoken reply is tokenized
  and supervised (PaliGemma's flat prompt has no structured calls, so
  they were dropped entirely).
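A minimal pure-Python sketch of the masking problem and the causal-span fix. It assumes make_att_2d_masks follows the cumulative-sum rule of the PaliGemma/pi0 reference code (position i may attend to position j when cumsum(att)[j] <= cumsum(att)[i] and both positions are non-padding). `mark_target_span_causal` below is a hypothetical stand-in for the private `_mark_target_span_causal`, not its actual implementation.

```python
from itertools import accumulate


def make_att_2d_masks(pad_mask: list[bool], att_mask: list[int]) -> list[list[bool]]:
    """Row i, column j is True when position i may attend to position j."""
    cumsum = list(accumulate(att_mask))
    n = len(att_mask)
    return [
        [bool(cumsum[j] <= cumsum[i] and pad_mask[i] and pad_mask[j]) for j in range(n)]
        for i in range(n)
    ]


def mark_target_span_causal(att_mask: list[int], target_mask: list[bool]) -> list[int]:
    """Hypothetical stand-in: set att=1 on every supervised target position."""
    return [1 if is_target else att for att, is_target in zip(att_mask, target_mask)]


# Assumed prefix layout: 2 image tokens, 1 prompt token, 2 supervised target tokens.
pad = [True] * 5
att = [0, 0, 0, 0, 0]  # embed_prefix-style: the whole prefix is flagged att=0
target = [False, False, False, True, True]

before = make_att_2d_masks(pad, att)
after = make_att_2d_masks(pad, mark_target_span_causal(att, target))

# The logits at position 2 predict token 3. Before the fix position 2 can see
# token 3 (a copy task); after the fix it cannot.
print(before[2][3], after[2][3])
# Targets still attend to the image and prompt tokens that precede them.
print(after[4][0], after[3][2])
```

Setting att=1 on each target opens a new causal block per token, so a target sees the images, the prompt and earlier targets, while image and prompt positions no longer see any target.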

select_message needed no change: pi052's prefix is [images, language]
with no trailing state token, so it already decodes from the last
language token.
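A sketch of why reading next-token logits at the last valid prefix position works for pi052, assuming right-padded prefixes. `decode_position` and its `num_trailing_non_language` parameter are hypothetical; the trailing-token case is presumably what smolvla2's select_message had to account for.

```python
def decode_position(pad_mask: list[bool], num_trailing_non_language: int = 0) -> int:
    """Index whose logits predict the next text token.

    Assumes right padding: valid positions come first, padding last.
    """
    last_valid = max(i for i, valid in enumerate(pad_mask) if valid)
    # Step back over any non-language tokens (e.g. a state token) appended after the text.
    return last_valid - num_trailing_non_language


# pi052-style prefix: [img, img, lang, lang, pad] -> last language token
print(decode_position([True, True, True, True, False]))
# Layout with one trailing state token: [img, lang, lang, state, pad]
print(decode_position([True, True, True, True, False], num_trailing_non_language=1))
```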

Regression tests mirror the smolvla2 attention-masking + tool-call suite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:42:19 +02:00


#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
PaliGemma's flat prompt has no structured tool calls, so an assistant
``say`` tool call must be serialized into a ``<say>...</say>`` text
marker — otherwise the spoken reply is dropped and never supervised.
"""
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls


def _say_call(text):
    return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}


def test_flatten_appends_say_marker_and_drops_tool_calls():
    msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]}
    out = _flatten_say_tool_calls(msg)
    assert "tool_calls" not in out
    assert out["content"] == "Heading to the cube.\n<say>On it!</say>"


def test_flatten_marker_only_when_content_empty_or_none():
    out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]})
    assert out["content"] == "<say>hi</say>"


def test_flatten_accepts_json_string_arguments():
    call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
    out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
    assert out["content"] == "p\n<say>hello there</say>"


def test_flatten_leaves_messages_without_tool_calls_untouched():
    msg = {"role": "assistant", "content": "just a plan"}
    assert _flatten_say_tool_calls(msg) == msg


def test_flatten_drops_non_say_tool_calls_but_keeps_content():
    weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
    out = _flatten_say_tool_calls(
        {"role": "assistant", "content": "plan only", "tool_calls": [weather]}
    )
    assert out["content"] == "plan only"
    assert "tool_calls" not in out