Files
lerobot/tests/policies/pi052/test_pi052_vqa_loc.py
T
Pepijn 5bb2da4da6 fix(pi052): VQA target format = "label <loc><loc>" not "<loc><loc> label"
The trained model collapsed to spewing 40+ <loc> tokens for *every*
prompt — subtask, memory, anything — because VQA targets were supervised
to *start* with <loc>. With ~25% of all text samples beginning with a
<loc> token, the LM head learned "Assistant: → <loc>" as a strong
attractor; once one loc is emitted, autoregression chains the rest.

Flip the format so every text target — subtask, memory, speech, AND VQA
— starts with a regular word. The model still learns the <loc>
vocabulary for the spatial portion of the answer, but loc can no
longer be the first generation step out of a clean prompt.

Examples:
  point  : "green box <loc0162><loc0759>"
  bbox   : "cube <loc0082>…<loc0409>"
  multi  : "blue <locs> ; yellow <locs>"

The runtime parser (parse_loc_answer) strips loc tokens and uses the
remainder as label, so it's order-tolerant and works under either
format. Old loc-first checkpoints still parse cleanly at inference;
new training will use label-first.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 18:56:48 +02:00

188 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
reuses the detection prior instead of fighting it (the ``<loc>``-salad
bug). The dataset stores Qwen2.5-VL's grounding output — **0-1000
normalized** coordinates, *not* pixels. (Verified empirically on the
published datasets: x and y both span 0..1000 with ~30% of values
exceeding the camera's pixel dimensions.) The conversion is therefore
camera-resolution-independent. The dataset stays backbone-agnostic
JSON; the conversion lives in PI052's tokenizer. These tests pin the
JSON → ``<loc>`` rewrite.
"""
import pytest
pytest.importorskip("transformers")
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
_loc_token,
_messages_vqa_to_loc,
_vqa_answer_to_loc,
register_paligemma_loc_tokens,
)
class _FakeTokenizer:
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
def __init__(self, prepopulate: bool = False):
self.added_tokens_encoder: dict[str, int] = {}
self.calls: list[list[str]] = []
if prepopulate:
self.added_tokens_encoder["<loc0000>"] = 256000
def add_tokens(self, tokens: list[str]) -> int:
self.calls.append(list(tokens))
for t in tokens:
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
return len(tokens)
def test_register_loc_tokens_adds_full_1024_range():
    """Registering on an empty tokenizer adds ``<loc0000>``..``<loc1023>`` in a single call."""
    fake = _FakeTokenizer()
    returned = register_paligemma_loc_tokens(fake)

    # The helper hands back the very tokenizer it was given.
    assert returned is fake

    assert len(fake.calls) == 1
    batch = fake.calls[0]
    assert len(batch) == 1024

    # Both ends of the range, plus a couple of positions from the middle.
    expected_at = {
        0: "<loc0000>",
        -1: "<loc1023>",
        162: "<loc0162>",
        759: "<loc0759>",
    }
    for position, token in expected_at.items():
        assert batch[position] == token
def test_register_loc_tokens_is_idempotent():
    """A tokenizer that already holds a loc token is never asked to add tokens again."""
    fake = _FakeTokenizer(prepopulate=True)
    # Register twice: neither pass may reach ``add_tokens``.
    for _ in range(2):
        register_paligemma_loc_tokens(fake)
    assert fake.calls == []
def test_loc_token_normalizes_and_clamps():
    """``_loc_token`` scales the default 0-1000 range onto loc buckets and clamps outliers."""
    midpoint_bucket = round(500 / 1000 * 1023)
    expected_by_coord = {
        # Default scale is the 0-1000 Qwen convention.
        0: "<loc0000>",
        1000: "<loc1023>",
        500: f"<loc{midpoint_bucket:04d}>",
        # Out-of-range coordinates clamp into [0, 1023].
        9999: "<loc1023>",
        -5: "<loc0000>",
    }
    for coord, expected in expected_by_coord.items():
        assert _loc_token(coord) == expected
def test_vqa_answer_to_loc_keypoint_normalized():
    """A keypoint answer is rendered label-first, followed by its two loc tokens."""
    # Label-first keeps the training target from opening with ``<loc>``, which
    # is what created the "Assistant: → <loc>" attractor.
    keypoint = {"label": "blue cube", "point_format": "xy", "point": [500, 500]}
    rendered = _vqa_answer_to_loc(keypoint)
    assert rendered == "blue cube <loc0512><loc0512>"
def test_vqa_answer_to_loc_bbox_normalized():
    """A single full-frame bbox detection becomes the label plus four loc tokens."""
    detection = {"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 1000, 1000]}
    rendered = _vqa_answer_to_loc({"detections": [detection]})
    assert rendered == "cube <loc0000><loc0000><loc1023><loc1023>"
def test_vqa_answer_to_loc_multiple_detections_separator():
    """Several detections are each rendered as "label <locs>" and joined by " ; "."""
    boxes = [
        ("blue", [0, 0, 500, 500]),
        ("yellow", [500, 500, 1000, 1000]),
    ]
    answer = {
        "detections": [
            {"label": label, "bbox_format": "xyxy", "bbox": bbox} for label, bbox in boxes
        ]
    }
    expected_segments = [
        "blue <loc0000><loc0000><loc0512><loc0512>",
        "yellow <loc0512><loc0512><loc1023><loc1023>",
    ]
    assert _vqa_answer_to_loc(answer) == " ; ".join(expected_segments)
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
    """Payloads without a point or detections yield ``None`` rather than loc text."""
    non_spatial_payloads = (
        {"label": "cubes", "count": 2},
        {"weird": "payload"},
    )
    for payload in non_spatial_payloads:
        assert _vqa_answer_to_loc(payload) is None
def test_messages_vqa_to_loc_rewrites_target_turn():
    """A JSON keypoint answer in a target turn is replaced by its loc-text form."""
    question = {"role": "user", "content": [{"type": "text", "text": "where is the cube?"}]}
    json_answer = '{"label": "cube", "point_format": "xy", "point": [500, 500]}'
    messages = [question, {"role": "assistant", "content": json_answer}]

    rewritten = _messages_vqa_to_loc(messages, target_indices=[1])

    assert rewritten[1]["content"] == "cube <loc0512><loc0512>"
    # The caller's messages must still carry the original JSON string.
    assert messages[1]["content"].startswith("{")
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
    """A target turn holding ordinary text comes back with the same content."""
    user_turn = {"role": "user", "content": "pick the cube"}
    assistant_turn = {"role": "assistant", "content": "pick up the cube"}
    rewritten = _messages_vqa_to_loc([user_turn, assistant_turn], target_indices=[1])
    assert rewritten[1]["content"] == "pick up the cube"
def test_messages_vqa_to_loc_noop_without_target_indices():
    """With no target indices the same list object is returned."""
    json_answer = '{"label": "c", "point_format": "xy", "point": [1, 2]}'
    messages = [{"role": "assistant", "content": json_answer}]
    result = _messages_vqa_to_loc(messages, [])
    # Identity, not just equality: no copy is made on the no-op path.
    assert result is messages
# ---------------------------------------------------------------------------
# Round-trip: training-side JSON -> <loc> -> runtime-side parse back
#
# Pins that the conversion preserves coordinate *order* (JSON is x-first,
# PaliGemma <loc> is y-first) and the 0-1000 → [0, 1023] scaling. The
# only loss is quantization to the 1024-bucket <loc> grid, so a coord
# survives within half a bucket (~1000/2046 ≈ 0.49 on the 0-1000 scale).
# ---------------------------------------------------------------------------
def test_loc_round_trip_keypoint_preserves_normalized_coords():
    """A keypoint survives JSON -> loc text -> parse within the allowed tolerance."""
    from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer

    source = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
    loc_text = _vqa_answer_to_loc(source)
    payload = parse_vqa_answer(loc_text)["payload"]

    # parse_vqa_answer returns [0, 1] normalized; rescale back to 0-1000
    # before comparing against the source coordinates.
    for got, want in zip(payload["point"], (640, 480)):
        assert abs(got * 1000.0 - want) <= 1000.0 / 2046 + 1e-6
    assert payload["label"] == "blue cube"
def test_loc_round_trip_bbox_preserves_order_and_scale():
    """A bbox survives JSON -> loc text -> parse with corner order and scale intact."""
    from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer

    source_bbox = [100, 200, 800, 900]
    detection = {"label": "cube", "bbox_format": "xyxy", "bbox": source_bbox}
    loc_text = _vqa_answer_to_loc({"detections": [detection]})

    parsed = parse_vqa_answer(loc_text)
    x1, y1, x2, y2 = parsed["payload"]["detections"][0]["bbox"]

    # Parsed values are multiplied by 1000 to compare on the source scale.
    for got, want in zip((x1, y1, x2, y2), source_bbox):
        assert abs(got * 1000.0 - want) <= 1000.0 / 2046 + 1e-6