diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py
index a8e2103fa..440c5afa9 100644
--- a/src/lerobot/policies/pi052/modeling_pi052.py
+++ b/src/lerobot/policies/pi052/modeling_pi052.py
@@ -797,7 +797,11 @@ class PI052Policy(PI05Policy):
         if tokenizer is None:
             from transformers import AutoTokenizer  # noqa: PLC0415
 
-            tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+            from .text_processor_pi052 import register_paligemma_loc_tokens  # noqa: PLC0415
+
+            tokenizer = register_paligemma_loc_tokens(
+                AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+            )
         if eos_token_id is None:
             eos_token_id = tokenizer.eos_token_id
 
diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py
index 6e0e8bff6..ea5951b4b 100644
--- a/src/lerobot/policies/pi052/text_processor_pi052.py
+++ b/src/lerobot/policies/pi052/text_processor_pi052.py
@@ -254,6 +254,26 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
 _VQA_COORD_SCALE = 1000.0
 
 
+def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
+    """Make PaliGemma's ``<locNNNN>`` ids match on raw text — single tokens.
+
+    PaliGemma reserves vocab ids [256000, 257023] for ``<locNNNN>``
+    (detection / pointing) tokens, but the *stock* tokenizer does NOT
+    match them when encoding raw text — it BPE-splits ``<loc0162>`` into
+    7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
+    the LM head on a ``<locNNNN>`` target then supervises those 7 generic
+    BPE pieces instead of one detection-vocab id, the LM head learns to
+    emit the *character sequence*, and those pieces' logits dominate
+    other turns (the ``<locNNNN>``-salad on subtasks). Registering the loc
+    tokens once makes them tokenize as their single ids (256000+idx),
+    leveraging PaliGemma's detection prior properly. Idempotent.
+    """
+    if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
+        return tokenizer
+    tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
+    return tokenizer
+
+
 def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
     """PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
     idx = round(float(coord) / scale * 1023) if scale > 0 else 0
@@ -410,7 +430,9 @@ class PI052TextTokenizerStep(ProcessorStep):
             return self._tokenizer
         from transformers import AutoTokenizer  # noqa: PLC0415
 
-        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+        self._tokenizer = register_paligemma_loc_tokens(
+            AutoTokenizer.from_pretrained(self.tokenizer_name)
+        )
         return self._tokenizer
 
     # ------------------------------------------------------------------
diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py
index 403e10003..6366103e9 100644
--- a/src/lerobot/policies/smolvla2/inference/steps.py
+++ b/src/lerobot/policies/smolvla2/inference/steps.py
@@ -291,12 +291,15 @@ def _build_text_batch_pi052(
         _flatten_say_tool_calls,
         _format_messages,
         _strip_blocks,
+        register_paligemma_loc_tokens,
     )
 
     tok_name = (
         getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
     )
-    tokenizer = AutoTokenizer.from_pretrained(tok_name)
+    # Register PaliGemma's ``<locNNNN>`` tokens so inference encoding /
+    # decoding sees them as single vocab ids — must match training.
+    tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
 
     messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
     prompt, _spans = _format_messages(messages)
diff --git a/tests/policies/pi052/test_pi052_vqa_loc.py b/tests/policies/pi052/test_pi052_vqa_loc.py
index a51452a08..a1e145350 100644
--- a/tests/policies/pi052/test_pi052_vqa_loc.py
+++ b/tests/policies/pi052/test_pi052_vqa_loc.py
@@ -36,9 +36,48 @@ from lerobot.policies.pi052.text_processor_pi052 import (  # noqa: E402
     _loc_token,
     _messages_vqa_to_loc,
     _vqa_answer_to_loc,
+    register_paligemma_loc_tokens,
 )
 
 
+class _FakeTokenizer:
+    """Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
+
+    def __init__(self, prepopulate: bool = False):
+        self.added_tokens_encoder: dict[str, int] = {}
+        self.calls: list[list[str]] = []
+        if prepopulate:
+            self.added_tokens_encoder["<loc0000>"] = 256000
+
+    def add_tokens(self, tokens: list[str]) -> int:
+        self.calls.append(list(tokens))
+        for t in tokens:
+            self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
+        return len(tokens)
+
+
+def test_register_loc_tokens_adds_full_1024_range():
+    tok = _FakeTokenizer()
+    out = register_paligemma_loc_tokens(tok)
+    assert out is tok  # returns same instance
+    assert len(tok.calls) == 1
+    added = tok.calls[0]
+    assert len(added) == 1024
+    assert added[0] == "<loc0000>"
+    assert added[-1] == "<loc1023>"
+    # Spot check a few in the middle.
+    assert added[162] == "<loc0162>"
+    assert added[759] == "<loc0759>"
+
+
+def test_register_loc_tokens_is_idempotent():
+    """If the loc tokens are already present we skip re-adding them."""
+    tok = _FakeTokenizer(prepopulate=True)
+    register_paligemma_loc_tokens(tok)
+    register_paligemma_loc_tokens(tok)
+    assert tok.calls == []  # never called add_tokens
+
+
 def test_loc_token_normalizes_and_clamps():
     # Default scale is the 0–1000 Qwen convention.
     assert _loc_token(0) == "<loc0000>"