fix(pi052): register PaliGemma <loc> tokens so they tokenize as single ids

THE bug behind the <loc>-salad. PaliGemma's vocab reserves ids
[256000, 257023] for <locDDDD> detection / pointing tokens, but the
stock AutoTokenizer does NOT match them on raw text — it BPE-splits
<loc0162> into SEVEN pieces (<, loc, 0, 1, 6, 2, >). So a VQA target
like "<loc0162><loc0759> green box<eos>" tokenized to 16 pieces, not
5, and training the LM head supervised those generic BPE pieces
instead of one detection-vocab id. The piece logits got pumped up
across ~25% of supervised positions; at inference they dominated
every turn — even subtask prompts produced <loc>-salad followed by
the actual answer.

Register the 1024 <locDDDD> tokens via tokenizer.add_tokens once on
load, in every path the policy uses: PI052TextTokenizerStep (training
encode), _build_text_batch_pi052 (runtime encode), and
select_message's default tokenizer (runtime decode). Verified
empirically with the real PaliGemma tokenizer: VQA target now
tokenizes to 5 ids matching the loc-vocab range (256162, 256759, ...)
with correct offset_mapping.

This unlocks PaliGemma's actual detection prior; <loc>-salad cannot
recur because each <locDDDD> is a single class on the LM head, not a
character sequence the head accidentally learns to extend.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-20 11:41:41 +02:00
parent 75507491bf
commit 34269a5d78
4 changed files with 71 additions and 3 deletions
+5 -1
View File
@@ -797,7 +797,11 @@ class PI052Policy(PI05Policy):
if tokenizer is None:
from transformers import AutoTokenizer # noqa: PLC0415
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
from .text_processor_pi052 import register_paligemma_loc_tokens # noqa: PLC0415
tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
)
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
@@ -254,6 +254,26 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
_VQA_COORD_SCALE = 1000.0
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
(detection / pointing) tokens, but the *stock* tokenizer does NOT
match them when encoding raw text it BPE-splits ``<loc0162>`` into
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
the LM head on a ``<loc>`` target then supervises those 7 generic
BPE pieces instead of one detection-vocab id, the LM head learns to
emit the *character sequence*, and those pieces' logits dominate
other turns (the ``<loc>``-salad on subtasks). Registering the loc
tokens once makes them tokenize as their single ids (256000+idx),
leveraging PaliGemma's detection prior properly. Idempotent.
"""
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
return tokenizer
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
return tokenizer
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
@@ -410,7 +430,9 @@ class PI052TextTokenizerStep(ProcessorStep):
return self._tokenizer
from transformers import AutoTokenizer # noqa: PLC0415
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
self._tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained(self.tokenizer_name)
)
return self._tokenizer
# ------------------------------------------------------------------
@@ -291,12 +291,15 @@ def _build_text_batch_pi052(
_flatten_say_tool_calls,
_format_messages,
_strip_blocks,
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
tokenizer = AutoTokenizer.from_pretrained(tok_name)
# Register PaliGemma's <locDDDD> tokens so inference encoding /
# decoding sees them as single vocab ids — must match training.
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
@@ -36,9 +36,48 @@ from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
_loc_token,
_messages_vqa_to_loc,
_vqa_answer_to_loc,
register_paligemma_loc_tokens,
)
class _FakeTokenizer:
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
def __init__(self, prepopulate: bool = False):
self.added_tokens_encoder: dict[str, int] = {}
self.calls: list[list[str]] = []
if prepopulate:
self.added_tokens_encoder["<loc0000>"] = 256000
def add_tokens(self, tokens: list[str]) -> int:
self.calls.append(list(tokens))
for t in tokens:
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
return len(tokens)
def test_register_loc_tokens_adds_full_1024_range():
tok = _FakeTokenizer()
out = register_paligemma_loc_tokens(tok)
assert out is tok # returns same instance
assert len(tok.calls) == 1
added = tok.calls[0]
assert len(added) == 1024
assert added[0] == "<loc0000>"
assert added[-1] == "<loc1023>"
# Spot check a few in the middle.
assert added[162] == "<loc0162>"
assert added[759] == "<loc0759>"
def test_register_loc_tokens_is_idempotent():
"""If the loc tokens are already present we skip re-adding them."""
tok = _FakeTokenizer(prepopulate=True)
register_paligemma_loc_tokens(tok)
register_paligemma_loc_tokens(tok)
assert tok.calls == [] # never called add_tokens
def test_loc_token_normalizes_and_clamps():
# Default scale is the 01000 Qwen convention.
assert _loc_token(0) == "<loc0000>"