fix(pi052): supervise an EOS token at the end of each text target

PI052TextTokenizerStep masked text_labels over the assistant turn's
*content only* — the trailing newline was excluded and no EOS token was
ever a supervised label. So the LM head was never given a stop signal:
at inference select_message decoded to max_new_tokens, producing the
runaway subtask paragraphs and the "}"}"}-style VQA tails.

_format_messages now appends the tokenizer's EOS to each supervised
target turn and extends that turn's span to cover it, so the EOS lands
in text_labels. _shifted_ce then trains "<last content token> -> EOS"
and the model learns to terminate; select_message stops on it.

Inference callers (the runtime's _build_text_batch_pi052) pass no
target_indices / eos_token, so no EOS is baked into the prompt — the
model generates it. Verified end-to-end with the PaliGemma tokenizer:
the supervised span is `<content><eos>` and the trailing newline stays
unsupervised.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 17:22:22 +02:00
parent 725ac95b0d
commit 15f79b5e5e
2 changed files with 105 additions and 16 deletions
@@ -14,16 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
"""Tests for PI052's text tokenizer.
PaliGemma's flat prompt has no structured tool calls, so an assistant
``say`` tool call must be serialized into a ``<say>...</say>`` text
marker — otherwise the spoken reply is dropped and never supervised.
Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
structured tool calls, so a ``say`` call must be serialized into a
``<say>...</say>`` text marker) and EOS-termination supervision (the
supervised target span must end with an EOS token so the LM head learns
to stop instead of rambling to ``max_length`` at inference).
"""
import torch
from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
from lerobot.policies.pi052.text_processor_pi052 import (
PI052TextTokenizerStep,
_flatten_say_tool_calls,
_format_messages,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
@@ -64,8 +70,71 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
assert "tool_calls" not in out
# ---------------------------------------------------------------------------
# EOS-termination supervision
# ---------------------------------------------------------------------------
def test_format_messages_appends_eos_to_target_turns_only():
msgs = [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
]
prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
# EOS is appended to the supervised target (assistant) turn only.
assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
# The user span is unchanged; the target span covers content + EOS.
assert prompt[spans[0][0] : spans[0][1]] == "pick cube"
assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"
def test_format_messages_without_eos_args_is_unchanged():
"""Inference callers omit target_indices / eos_token — no EOS baked in."""
prompt, spans = _format_messages([{"role": "user", "content": "hi"}])
assert prompt == "User: hi\n"
assert prompt[spans[0][0] : spans[0][1]] == "hi"
def _eos_char_id() -> int:
"""Token id _CharTokenizer assigns to its 1-char EOS."""
return ord("\x1f") % 251 + 1
def test_pi052_text_tokenizer_supervises_eos_at_target_end():
"""The appended EOS is the last supervised label on a target turn —
that's the signal that teaches the LM head to stop. The trailing
newline right after it stays unsupervised (-100)."""
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
"target_message_indices": [1],
"message_streams": ["high_level", "high_level"],
"index": torch.tensor(10),
},
}
out = step(transition)
ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
supervised = (labels != -100).nonzero().flatten().tolist()
assert supervised, "target turn produced no supervised labels"
last = supervised[-1]
# The last supervised token is the appended EOS.
assert int(ids[last]) == _eos_char_id()
assert int(labels[last]) == _eos_char_id()
# The token right after the EOS (the trailing newline) is NOT supervised.
assert int(labels[last + 1]) == -100
class _CharTokenizer:
pad_token_id = 0
eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
def __call__(
self,