mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
5bb2da4da6
The trained model collapsed to spewing 40+ <loc> tokens for *every* prompt — subtask, memory, anything — because VQA targets were supervised to *start* with <loc>. With ~25% of all text samples beginning with a <loc> token, the LM head learned "Assistant: → <loc>" as a strong attractor; once one loc is emitted, autoregression chains the rest. Flip the format so every text target — subtask, memory, speech, AND VQA — starts with a regular word. The model still learns the <loc> vocabulary for the spatial portion of the answer, but loc can no longer be the first generation step out of a clean prompt. Examples: point : "green box <loc0162><loc0759>" bbox : "cube <loc0082>…<loc0409>" multi : "blue <locs> ; yellow <locs>" The runtime parser (parse_loc_answer) strips loc tokens and uses the remainder as label, so it's order-tolerant and works under either format. Old loc-first checkpoints still parse cleanly at inference; new training will use label-first. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
188 lines
7.0 KiB
Python
188 lines
7.0 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
|
||
|
||
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
|
||
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
|
||
reuses the detection prior instead of fighting it (the ``<loc>``-salad
|
||
bug). The dataset stores Qwen2.5-VL's grounding output — **0–1000
|
||
normalized** coordinates, *not* pixels. (Verified empirically on the
|
||
published datasets: x and y both span 0..1000 with ~30% of values
|
||
exceeding the camera's pixel dimensions.) The conversion is therefore
|
||
camera-resolution-independent. The dataset stays backbone-agnostic
|
||
JSON; the conversion lives in PI052's tokenizer. These tests pin the
|
||
JSON → ``<loc>`` rewrite.
|
||
"""
|
||
|
||
import pytest
|
||
|
||
pytest.importorskip("transformers")
|
||
|
||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
|
||
_loc_token,
|
||
_messages_vqa_to_loc,
|
||
_vqa_answer_to_loc,
|
||
register_paligemma_loc_tokens,
|
||
)
|
||
|
||
|
||
class _FakeTokenizer:
    """Minimal stand-in for the tokenizer passed to ``register_paligemma_loc_tokens``.

    Only ``added_tokens_encoder`` (token -> id map) and ``add_tokens`` are
    modelled. Every ``add_tokens`` call is recorded in ``calls`` so tests can
    assert on how, and whether, tokens were registered.
    """

    def __init__(self, prepopulate: bool = False):
        """Create an empty fake tokenizer.

        Args:
            prepopulate: If True, ``<loc0000>`` is pre-registered with id
                256000. The idempotency test uses this to simulate loc tokens
                that are already present.
        """
        self.added_tokens_encoder: dict[str, int] = {}
        self.calls: list[list[str]] = []
        if prepopulate:
            self.added_tokens_encoder["<loc0000>"] = 256000

    def add_tokens(self, tokens: list[str]) -> int:
        """Record the call and assign fresh ids (from 256000 upward) to unseen tokens.

        Args:
            tokens: Token strings to register.

        Returns:
            The number of tokens that were actually new. This mirrors the real
            tokenizer ``add_tokens``, which does not count tokens already in
            the vocabulary.
        """
        # Copy so later mutation of the caller's list cannot rewrite history.
        self.calls.append(list(tokens))
        newly_added = 0
        for token in tokens:
            if token not in self.added_tokens_encoder:
                self.added_tokens_encoder[token] = len(self.added_tokens_encoder) + 256000
                newly_added += 1
        return newly_added
|
||
|
||
|
||
def test_register_loc_tokens_adds_full_1024_range():
    """All 1024 ``<locNNNN>`` tokens are registered, in order, in a single call."""
    tok = _FakeTokenizer()
    out = register_paligemma_loc_tokens(tok)
    assert out is tok  # returns same instance
    assert len(tok.calls) == 1  # one batched add_tokens call
    added = tok.calls[0]
    # Pin the whole vocabulary rather than a few spot checks, so gaps,
    # duplicates, misordering or padding mistakes anywhere in the range fail.
    assert added == [f"<loc{i:04d}>" for i in range(1024)]
|
||
|
||
|
||
def test_register_loc_tokens_is_idempotent():
    """Registering against a vocabulary that already has loc tokens adds nothing."""
    tokenizer = _FakeTokenizer(prepopulate=True)
    for _ in range(2):
        register_paligemma_loc_tokens(tokenizer)
    # add_tokens must never have been invoked, on either pass.
    assert tokenizer.calls == []
|
||
|
||
|
||
def test_loc_token_normalizes_and_clamps():
    """``_loc_token`` maps 0–1000 coords onto the 1024-bin grid and clamps outliers."""
    # Default scale is the 0–1000 Qwen convention.
    cases = [
        (0, "<loc0000>"),
        (1000, "<loc1023>"),
        (500, f"<loc{round(500 / 1000 * 1023):04d}>"),
        # out-of-range coordinates clamp into [0, 1023]
        (9999, "<loc1023>"),
        (-5, "<loc0000>"),
    ]
    for coord, expected in cases:
        assert _loc_token(coord) == expected, coord
|
||
|
||
|
||
def test_vqa_answer_to_loc_keypoint_normalized():
    """A keypoint answer renders as its label followed by the point's loc tokens."""
    payload = {"label": "blue cube", "point_format": "xy", "point": [500, 500]}
    # Label-first: avoids the "Assistant: → <loc>" attractor at training.
    expected = "blue cube <loc0512><loc0512>"
    assert _vqa_answer_to_loc(payload) == expected
|
||
|
||
|
||
def test_vqa_answer_to_loc_bbox_normalized():
    """A full-frame xyxy box maps onto the extreme bins of the loc grid, label first."""
    detection = {"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 1000, 1000]}
    rendered = _vqa_answer_to_loc({"detections": [detection]})
    assert rendered == "cube <loc0000><loc0000><loc1023><loc1023>"
|
||
|
||
|
||
def test_vqa_answer_to_loc_multiple_detections_separator():
    """Several detections render as ``"label <locs>"`` segments joined by ``" ; "``."""
    blue = {"label": "blue", "bbox_format": "xyxy", "bbox": [0, 0, 500, 500]}
    yellow = {"label": "yellow", "bbox_format": "xyxy", "bbox": [500, 500, 1000, 1000]}
    expected_segments = [
        "blue <loc0000><loc0000><loc0512><loc0512>",
        "yellow <loc0512><loc0512><loc1023><loc1023>",
    ]
    assert _vqa_answer_to_loc({"detections": [blue, yellow]}) == " ; ".join(expected_segments)
|
||
|
||
|
||
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
    """Non-spatial answers (a count, an unrecognised payload) are not converted."""
    non_spatial = (
        {"label": "cubes", "count": 2},
        {"weird": "payload"},
    )
    for payload in non_spatial:
        assert _vqa_answer_to_loc(payload) is None, payload
|
||
|
||
|
||
def test_messages_vqa_to_loc_rewrites_target_turn():
    """A JSON keypoint answer in a target turn is rewritten to label-first loc text."""
    import copy

    messages = [
        {"role": "user", "content": [{"type": "text", "text": "where is the cube?"}]},
        {
            "role": "assistant",
            "content": '{"label": "cube", "point_format": "xy", "point": [500, 500]}',
        },
    ]
    snapshot = copy.deepcopy(messages)
    out = _messages_vqa_to_loc(messages, target_indices=[1])
    assert out[1]["content"] == "cube <loc0512><loc0512>"
    # Input messages are not mutated. Compare the whole structure, not just a
    # prefix of the rewritten turn, so in-place edits to any turn are caught.
    assert messages == snapshot
|
||
|
||
|
||
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
    """Target text that is not a spatial JSON answer passes through unchanged."""
    plain_target = "pick up the cube"
    conversation = [
        {"role": "user", "content": "pick the cube"},
        {"role": "assistant", "content": plain_target},
    ]
    converted = _messages_vqa_to_loc(conversation, target_indices=[1])
    assert converted[1]["content"] == plain_target
|
||
|
||
|
||
def test_messages_vqa_to_loc_noop_without_target_indices():
    """With no target turns the helper returns the very same list object."""
    spatial_json = '{"label": "c", "point_format": "xy", "point": [1, 2]}'
    conversation = [{"role": "assistant", "content": spatial_json}]
    result = _messages_vqa_to_loc(conversation, [])
    assert result is conversation
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Round-trip: training-side JSON -> <loc> -> runtime-side parse back
#
# Pins that the conversion preserves coordinate *order* (JSON is x-first,
# PaliGemma <loc> is y-first) and the 0–1000 → [0, 1023] scaling. The
# only loss is quantization to the 1024-bucket <loc> grid. Each coord
# therefore round-trips to within half a bucket:
# 0.5 × 1000/1023 = 1000/2046 ≈ 0.49 on the 0–1000 scale.
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_loc_round_trip_keypoint_preserves_normalized_coords():
    """JSON keypoint -> ``<loc>`` text -> runtime parser recovers the coords and label."""
    from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer

    half_bucket = 1000.0 / 2046 + 1e-6  # quantization bound plus float slack
    want_x, want_y = 640, 480
    answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
    payload = parse_vqa_answer(_vqa_answer_to_loc(answer))["payload"]
    nx, ny = payload["point"]
    # parse_vqa_answer returns [0, 1] normalized; rescale back to 0–1000.
    assert abs(nx * 1000.0 - want_x) <= half_bucket
    assert abs(ny * 1000.0 - want_y) <= half_bucket
    assert payload["label"] == "blue cube"
|
||
|
||
|
||
def test_loc_round_trip_bbox_preserves_order_and_scale():
    """A bbox survives JSON -> ``<loc>`` -> parse with its x/y order and scale intact."""
    from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer

    tolerance = 1000.0 / 2046 + 1e-6  # half a loc bucket plus float slack
    detection = {"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}
    parsed = parse_vqa_answer(_vqa_answer_to_loc({"detections": [detection]}))
    # Unpacking also enforces that exactly four coordinates come back.
    x1, y1, x2, y2 = parsed["payload"]["detections"][0]["bbox"]
    expected = (100, 200, 800, 900)
    for coord, target in zip((x1, y1, x2, y2), expected):
        assert abs(coord * 1000.0 - target) <= tolerance
|