mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
426d48dbbf
pi052 had the same text-CE collapse bug smolvla2 had — PaliGemma's embed_prefix flags the language block att=0, so make_att_2d_masks makes it fully bidirectional and the text cross-entropy degenerates into a copy task. Ported the three model-specific fixes: - _mark_target_span_causal: set att=1 on supervised target language positions so the text-CE is genuine causal next-token prediction. Applied in both _compute_all_losses_fused and _compute_text_and_fast_loss. - flow_loss_weight 10.0 -> 5.0: the paper's a=10 swamps the LM head once the flow-only low_level recipe fires often (matches SmolVLA2Config). - _flatten_say_tool_calls in the text tokenizer: serialize `say` tool calls into a <say>...</say> marker so the spoken reply is tokenized and supervised (PaliGemma's flat prompt has no structured calls, so they were dropped entirely). select_message needed no change: pi052's prefix is [images, language] with no trailing state token, so it already decodes from the last language token. Regression tests mirror the smolvla2 attention-masking + tool-call suite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
61 lines
2.4 KiB
Python
61 lines
2.4 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
|
|
|
|
PaliGemma's flat prompt has no structured tool calls, so an assistant
|
|
``say`` tool call must be serialized into a ``<say>...</say>`` text
|
|
marker — otherwise the spoken reply is dropped and never supervised.
|
|
"""
|
|
|
|
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls
|
|
|
|
|
|
def _say_call(text):
|
|
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
|
|
|
|
|
|
def test_flatten_appends_say_marker_and_drops_tool_calls():
    """Existing content keeps its text and gains a trailing ``<say>`` marker.

    The ``tool_calls`` key must be removed from the flattened message.
    """
    message = {
        "role": "assistant",
        "content": "Heading to the cube.",
        "tool_calls": [_say_call("On it!")],
    }
    flattened = _flatten_say_tool_calls(message)
    assert "tool_calls" not in flattened
    expected = "Heading to the cube.\n<say>On it!</say>"
    assert flattened["content"] == expected
|
|
|
|
|
|
def test_flatten_marker_only_when_content_empty_or_none():
    """With no usable content, the flattened content is just the ``<say>`` marker.

    Exercises every "no content" shape the test name promises:
    - the ``content`` key absent;
    - ``content`` explicitly ``None``;
    - ``content`` as the empty string.

    None of them may leak a ``"None"`` string or a stray leading newline in
    front of the marker.
    """
    for extra in ({}, {"content": None}, {"content": ""}):
        msg = {"role": "assistant", **extra, "tool_calls": [_say_call("hi")]}
        out = _flatten_say_tool_calls(msg)
        # Include the offending input shape in the failure message.
        assert out["content"] == "<say>hi</say>", extra
|
|
|
|
|
|
def test_flatten_accepts_json_string_arguments():
    """``arguments`` given as a JSON-encoded string still yields the marker."""
    raw_args = '{"text": "hello there"}'
    json_call = {"type": "function", "function": {"name": "say", "arguments": raw_args}}
    flattened = _flatten_say_tool_calls(
        {"role": "assistant", "content": "p", "tool_calls": [json_call]}
    )
    assert flattened["content"] == "p\n<say>hello there</say>"
|
|
|
|
|
|
def test_flatten_leaves_messages_without_tool_calls_untouched():
    """A message with no ``tool_calls`` passes through unchanged.

    The result is compared against a snapshot taken *before* the call.
    Comparing it to ``msg`` itself would pass vacuously if the function
    mutated ``msg`` in place and returned it. The input is also checked
    for being left unmodified.
    """
    msg = {"role": "assistant", "content": "just a plan"}
    snapshot = dict(msg)
    out = _flatten_say_tool_calls(msg)
    assert out == snapshot
    assert msg == snapshot
|
|
|
|
|
|
def test_flatten_drops_non_say_tool_calls_but_keeps_content():
    """Tool calls other than ``say`` are discarded; plain content is kept as-is."""
    other_call = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
    message = {"role": "assistant", "content": "plan only", "tool_calls": [other_call]}
    result = _flatten_say_tool_calls(message)
    assert result["content"] == "plan only"
    assert "tool_calls" not in result
|