mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix(pi052): register PaliGemma <loc> tokens so they tokenize as single ids
THE bug behind the <loc>-salad. PaliGemma's vocab reserves ids [256000, 257023] for <locDDDD> detection / pointing tokens, but the stock AutoTokenizer does NOT match them on raw text — it BPE-splits <loc0162> into SEVEN pieces (<, loc, 0, 1, 6, 2, >). So a VQA target like "<loc0162><loc0759> green box<eos>" tokenized to 16 pieces, not 5, and training the LM head supervised those generic BPE pieces instead of one detection-vocab id. The piece logits got pumped up across ~25% of supervised positions; at inference they dominated every turn — even subtask prompts produced <loc>-salad followed by the actual answer. Register the 1024 <locDDDD> tokens via tokenizer.add_tokens once on load, in every path the policy uses: PI052TextTokenizerStep (training encode), _build_text_batch_pi052 (runtime encode), and select_message's default tokenizer (runtime decode). Verified empirically with the real PaliGemma tokenizer: VQA target now tokenizes to 5 ids matching the loc-vocab range (256162, 256759, ...) with correct offset_mapping. This unlocks PaliGemma's actual detection prior; <loc>-salad cannot recur because each <locDDDD> is a single class on the LM head, not a character sequence the head accidentally learns to extend. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -797,7 +797,11 @@ class PI052Policy(PI05Policy):
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
from .text_processor_pi052 import register_paligemma_loc_tokens # noqa: PLC0415
|
||||
|
||||
tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
)
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
|
||||
@@ -254,6 +254,26 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
_VQA_COORD_SCALE = 1000.0
|
||||
|
||||
|
||||
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
|
||||
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
|
||||
|
||||
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
|
||||
(detection / pointing) tokens, but the *stock* tokenizer does NOT
|
||||
match them when encoding raw text — it BPE-splits ``<loc0162>`` into
|
||||
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
|
||||
the LM head on a ``<loc>`` target then supervises those 7 generic
|
||||
BPE pieces instead of one detection-vocab id, the LM head learns to
|
||||
emit the *character sequence*, and those pieces' logits dominate
|
||||
other turns (the ``<loc>``-salad on subtasks). Registering the loc
|
||||
tokens once makes them tokenize as their single ids (256000+idx),
|
||||
leveraging PaliGemma's detection prior properly. Idempotent.
|
||||
"""
|
||||
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
return tokenizer
|
||||
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
|
||||
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
|
||||
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
|
||||
@@ -410,7 +430,9 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
return self._tokenizer
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
self._tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
)
|
||||
return self._tokenizer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -291,12 +291,15 @@ def _build_text_batch_pi052(
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
_strip_blocks,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
# Register PaliGemma's <locDDDD> tokens so inference encoding /
|
||||
# decoding sees them as single vocab ids — must match training.
|
||||
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
|
||||
@@ -36,9 +36,48 @@ from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
|
||||
_loc_token,
|
||||
_messages_vqa_to_loc,
|
||||
_vqa_answer_to_loc,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
|
||||
|
||||
def __init__(self, prepopulate: bool = False):
|
||||
self.added_tokens_encoder: dict[str, int] = {}
|
||||
self.calls: list[list[str]] = []
|
||||
if prepopulate:
|
||||
self.added_tokens_encoder["<loc0000>"] = 256000
|
||||
|
||||
def add_tokens(self, tokens: list[str]) -> int:
|
||||
self.calls.append(list(tokens))
|
||||
for t in tokens:
|
||||
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def test_register_loc_tokens_adds_full_1024_range():
|
||||
tok = _FakeTokenizer()
|
||||
out = register_paligemma_loc_tokens(tok)
|
||||
assert out is tok # returns same instance
|
||||
assert len(tok.calls) == 1
|
||||
added = tok.calls[0]
|
||||
assert len(added) == 1024
|
||||
assert added[0] == "<loc0000>"
|
||||
assert added[-1] == "<loc1023>"
|
||||
# Spot check a few in the middle.
|
||||
assert added[162] == "<loc0162>"
|
||||
assert added[759] == "<loc0759>"
|
||||
|
||||
|
||||
def test_register_loc_tokens_is_idempotent():
|
||||
"""If the loc tokens are already present we skip re-adding them."""
|
||||
tok = _FakeTokenizer(prepopulate=True)
|
||||
register_paligemma_loc_tokens(tok)
|
||||
register_paligemma_loc_tokens(tok)
|
||||
assert tok.calls == [] # never called add_tokens
|
||||
|
||||
|
||||
def test_loc_token_normalizes_and_clamps():
|
||||
# Default scale is the 0–1000 Qwen convention.
|
||||
assert _loc_token(0) == "<loc0000>"
|
||||
|
||||
Reference in New Issue
Block a user