mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
fix(pi052): port the smolvla2 text-head fixes to pi052
pi052 had the same text-CE collapse bug smolvla2 had — PaliGemma's embed_prefix flags the language block att=0, so make_att_2d_masks makes it fully bidirectional and the text cross-entropy degenerates into a copy task. Ported the three model-specific fixes: - _mark_target_span_causal: set att=1 on supervised target language positions so the text-CE is genuine causal next-token prediction. Applied in both _compute_all_losses_fused and _compute_text_and_fast_loss. - flow_loss_weight 10.0 -> 5.0: the paper's a=10 swamps the LM head once the flow-only low_level recipe fires often (matches SmolVLA2Config). - _flatten_say_tool_calls in the text tokenizer: serialize `say` tool calls into a <say>...</say> marker so the spoken reply is tokenized and supervised (PaliGemma's flat prompt has no structured calls, so they were dropped entirely). select_message needed no change: pi052's prefix is [images, language] with no trailing state token, so it already decodes from the last language token. Regression tests mirror the smolvla2 attention-masking + tool-call suite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attention-masking tests for the PI052 (π0.5 v2) text head.
|
||||
|
||||
Regression coverage for the text-CE collapse bug: PaliGemma's
|
||||
``embed_prefix`` flags every language token ``att=0``, which
|
||||
``make_att_2d_masks`` turns into one fully *bidirectional* block. Under
|
||||
that mask the text cross-entropy degenerates into a copy task — a
|
||||
supervised target token attends to the tokens it is trained to predict —
|
||||
and the LM head never learns causal generation, so ``select_message``
|
||||
collapses at inference.
|
||||
|
||||
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
|
||||
language positions so each target token attends causally among the
|
||||
targets while staying bidirectional to images + the user prompt. These
|
||||
tests pin that behaviour for the PaliGemma prefix layout.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# modeling_pi052 / modeling_pi05 import transformers transitively.
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
|
||||
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
|
||||
#
|
||||
# indices 0-1 : 2 image tokens (att = 0)
|
||||
# indices 2-4 : 3 user-prompt lang (att = 0)
|
||||
#   indices 5-8   : 4 supervised target lang (att = 0 from embed_prefix)
|
||||
#
|
||||
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
|
||||
# real ids on the 4-token target span. PaliGemma's prefix has no state
|
||||
# token (unlike SmolVLA), so the lang span ends at the prefix end.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token counts for each span of the synthetic prefix.
N_IMAGE = 2
N_PROMPT = 3
N_TARGET = 4

# The language span begins right after the image tokens and runs up to
# the end of the prefix (no trailing state token).
LANG_START = N_IMAGE
LANG_END = LANG_START + N_PROMPT + N_TARGET  # = prefix length
PREFIX_LEN = LANG_END
|
||||
|
||||
|
||||
def _embed_prefix_att_masks() -> torch.Tensor:
    """Return an all-False ``(1, PREFIX_LEN)`` boolean mask.

    Stands in for PaliGemma's ``embed_prefix`` output, where the image and
    language positions all carry att=0.
    """
    shape = (1, PREFIX_LEN)
    return torch.zeros(shape, dtype=torch.bool)
|
||||
|
||||
|
||||
def _text_labels() -> torch.Tensor:
    """Build ``(1, N_PROMPT + N_TARGET)`` labels for the language span.

    The prompt positions hold the ignore index -100; the target positions
    hold the consecutive ids ``10 .. 10 + N_TARGET - 1``.
    """
    ignored_prompt = torch.full((1, N_PROMPT), -100, dtype=torch.long)
    target_ids = torch.arange(10, 10 + N_TARGET, dtype=torch.long).unsqueeze(0)
    return torch.cat([ignored_prompt, target_ids], dim=1)
|
||||
|
||||
|
||||
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
    """Return the 2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j.

    Args:
        prefix_att_masks: Per-position attention flags with a leading batch
            dimension, as passed to ``make_att_2d_masks``.

    Returns:
        The first batch element of ``make_att_2d_masks``' result, computed
        with every position treated as valid (no padding).
    """
    # Derive the padding mask from the input rather than hard-coding
    # (1, PREFIX_LEN), so its shape and device always match the argument.
    pad = torch.ones_like(prefix_att_masks, dtype=torch.bool)
    return make_att_2d_masks(pad, prefix_att_masks)[0]
|
||||
|
||||
|
||||
def test_mark_sets_att_on_targets_only():
    """Marking flips exactly the supervised target positions to att=1."""
    marked = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    first_target = LANG_START + N_PROMPT
    # True inside the target span, False for images and the prompt.
    expected = [first_target <= idx < LANG_END for idx in range(PREFIX_LEN)]
    assert marked[0].tolist() == expected
|
||||
|
||||
|
||||
def test_target_tokens_attend_causally_among_themselves():
    """Within the target span attention is causal: each target sees itself
    and earlier targets, never later ones."""
    marked = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    attends = _attends(marked)
    target_positions = list(range(LANG_START + N_PROMPT, LANG_END))
    for query in target_positions:
        for key in target_positions:
            visible = attends[query, key]
            if key <= query:
                assert visible, f"target {query} must see earlier/self target {key}"
            else:
                assert not visible, f"target {query} must not see future target {key}"
|
||||
|
||||
|
||||
def test_target_tokens_attend_prompt_and_images_bidirectionally():
    """Every target position still sees all image and user-prompt positions."""
    marked = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    attends = _attends(marked)
    first_target = LANG_START + N_PROMPT
    # Positions [0, first_target) are the images followed by the prompt.
    for tgt in range(first_target, LANG_END):
        for ctx in range(first_target):
            assert attends[tgt, ctx], f"target {tgt} must attend context {ctx}"
|
||||
|
||||
|
||||
def test_non_target_subtask_stays_bidirectional():
    """With every label at -100 (no supervised target) the returned mask
    equals the raw ``embed_prefix`` mask, i.e. nothing is marked causal."""
    reference = _embed_prefix_att_masks()
    no_targets = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
    # Pass a separate fresh mask so the reference stays independent of the call.
    marked = _mark_target_span_causal(
        _embed_prefix_att_masks(), no_targets, LANG_START, LANG_END
    )
    assert torch.equal(marked, reference)
|
||||
|
||||
|
||||
def test_unmarked_mask_is_bidirectional_the_bug():
    """Pin the pre-fix behaviour: with the raw ``embed_prefix`` mask the
    first target position attends to the last one, so the text-CE could
    copy the answer it is trained to predict."""
    raw_attends = _attends(_embed_prefix_att_masks())
    first_tgt, last_tgt = LANG_START + N_PROMPT, LANG_END - 1
    sees_future = raw_attends[first_tgt, last_tgt]
    assert sees_future, (
        "raw embed_prefix mask is bidirectional over language — the first "
        "target token can see the last, which is the collapse bug"
    )
|
||||
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
|
||||
|
||||
PaliGemma's flat prompt has no structured tool calls, so an assistant
|
||||
``say`` tool call must be serialized into a ``<say>...</say>`` text
|
||||
marker — otherwise the spoken reply is dropped and never supervised.
|
||||
"""
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls
|
||||
|
||||
|
||||
def _say_call(text):
|
||||
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
|
||||
|
||||
|
||||
def test_flatten_appends_say_marker_and_drops_tool_calls():
    """A ``say`` call is appended to existing content as a ``<say>`` marker
    on a new line, and ``tool_calls`` is removed from the result."""
    message = {
        "role": "assistant",
        "content": "Heading to the cube.",
        "tool_calls": [_say_call("On it!")],
    }
    flattened = _flatten_say_tool_calls(message)
    assert "tool_calls" not in flattened
    assert flattened["content"] == "Heading to the cube.\n<say>On it!</say>"
|
||||
|
||||
|
||||
def test_flatten_marker_only_when_content_empty_or_none():
    """With no usable content, the result is the bare ``<say>`` marker.

    Covers the three "no content" shapes the test name promises: the
    ``content`` key missing, ``content=None`` and ``content=""``.
    NOTE(review): the ``None`` / ``""`` expectations follow from this test's
    name — confirm them against ``_flatten_say_tool_calls``.
    """
    for extra in ({}, {"content": None}, {"content": ""}):
        # Build each message from scratch so the cases share no mutable state.
        msg = {"role": "assistant", "tool_calls": [_say_call("hi")], **extra}
        out = _flatten_say_tool_calls(msg)
        assert out["content"] == "<say>hi</say>", f"unexpected content for case {extra!r}"
|
||||
|
||||
|
||||
def test_flatten_accepts_json_string_arguments():
    """``arguments`` given as a JSON-encoded string is handled like a dict."""
    json_args_call = {
        "type": "function",
        "function": {"name": "say", "arguments": '{"text": "hello there"}'},
    }
    message = {"role": "assistant", "content": "p", "tool_calls": [json_args_call]}
    flattened = _flatten_say_tool_calls(message)
    assert flattened["content"] == "p\n<say>hello there</say>"
|
||||
|
||||
|
||||
def test_flatten_leaves_messages_without_tool_calls_untouched():
    """A message without ``tool_calls`` comes back equal to its input.

    The expectation is a snapshot taken *before* the call: comparing the
    result against ``msg`` itself would pass vacuously if the function
    mutated ``msg`` in place and returned the same object.
    """
    msg = {"role": "assistant", "content": "just a plan"}
    snapshot = dict(msg)
    out = _flatten_say_tool_calls(msg)
    assert out == snapshot
    # The caller's dict must not have been altered either.
    assert msg == snapshot
|
||||
|
||||
|
||||
def test_flatten_drops_non_say_tool_calls_but_keeps_content():
    """A non-``say`` tool call adds no marker: content is unchanged and
    ``tool_calls`` is removed from the result."""
    weather_call = {
        "type": "function",
        "function": {"name": "check_weather", "arguments": {}},
    }
    message = {"role": "assistant", "content": "plan only", "tool_calls": [weather_call]}
    flattened = _flatten_say_tool_calls(message)
    assert flattened["content"] == "plan only"
    assert "tool_calls" not in flattened
|
||||
Reference in New Issue
Block a user