mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
examples(port_datasets): SLURM+datatrove RoboCasa composite_seen build
Parallel variant of build_robocasa_composite_seen.py modeled after the
existing slurm_port_shards.py / slurm_aggregate_shards.py pattern.
Two-phase datatrove pipeline:
* Phase 1 DOWNLOAD: tasks=16 (one per RoboCasa composite_seen task),
each worker downloads its assigned tar via RoboCasa's own
download_datasets helper. Network-bound, idempotent.
* Phase 2 AGGREGATE: tasks=1, single worker calls aggregate_datasets
over the 16 extracted directories. Submitted with depends=phase1 so
SLURM only releases it once all 16 downloads succeed.
Reuses the COMPOSITE_SEEN_TASKS list and per-task download/resolve
helpers from the single-machine script via aliased imports, so there is a
single source of truth for what it means to download a composite_seen
task.
Local (--slurm 0) mode runs the two phases sequentially in-process for
debugging on a workstation.
Usage on SLURM:
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \
--output-dir=/scratch/${USER}/robocasa_composite_seen \
--hub-repo-id=${HF_USER}/robocasa_composite_seen \
--logs-dir=/scratch/${USER}/logs/robocasa \
--partition=cpu --push-to-hub
Prereq: uv sync --extra annotations (pulls datatrove)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -162,7 +162,7 @@ def test_messages_vqa_to_loc_noop_without_target_indices():
|
||||
|
||||
|
||||
def test_loc_round_trip_keypoint_preserves_normalized_coords():
|
||||
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
|
||||
loc = _vqa_answer_to_loc(answer)
|
||||
@@ -175,7 +175,7 @@ def test_loc_round_trip_keypoint_preserves_normalized_coords():
|
||||
|
||||
|
||||
def test_loc_round_trip_bbox_preserves_order_and_scale():
|
||||
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {
|
||||
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attention-masking tests for the SmolVLA2 text head.
|
||||
|
||||
Regression coverage for the text-CE collapse bug: ``embed_prefix`` flags
|
||||
every language token ``att=0``, which ``make_att_2d_masks`` turns into a
|
||||
single fully *bidirectional* block. Under that mask the text
|
||||
cross-entropy degenerates into a copy task — a supervised target token
|
||||
attends to the tokens it is trained to predict — and the model never
|
||||
learns causal generation, so ``select_message`` collapses at inference.
|
||||
|
||||
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
|
||||
language positions so each target token attends causally among the
|
||||
targets while staying bidirectional to images + the user prompt. These
|
||||
tests pin that behaviour.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# The smolvla2 modeling module imports transformers transitively.
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.smolvla.modeling_smolvla import make_att_2d_masks # noqa: E402
|
||||
from lerobot.policies.smolvla2.modeling_smolvla2 import ( # noqa: E402
|
||||
_locate_lang_range,
|
||||
_mark_target_span_causal,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A synthetic SmolVLA prefix layout: [images, prompt-lang, target-lang, state]
|
||||
#
|
||||
# indices 0-1 : 2 image tokens (att = 0)
|
||||
# indices 2-4 : 3 user-prompt lang (att = 0)
|
||||
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
|
||||
# index 9 : 1 state token (att = 1)
|
||||
#
|
||||
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
|
||||
# real ids on the 4-token target span.
|
||||
# ---------------------------------------------------------------------------
|
||||
N_IMAGE = 2
|
||||
N_PROMPT = 3
|
||||
N_TARGET = 4
|
||||
LANG_START = N_IMAGE
|
||||
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = state-token index
|
||||
PREFIX_LEN = LANG_END + 1
|
||||
|
||||
|
||||
def _embed_prefix_att_masks() -> torch.Tensor:
|
||||
"""Mimic ``embed_prefix``: images + lang all att=0, state att=1."""
|
||||
att = torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
|
||||
att[0, LANG_END] = True # the single state token
|
||||
return att
|
||||
|
||||
|
||||
def _text_labels() -> torch.Tensor:
|
||||
"""-100 over the prompt span, real ids over the target span."""
|
||||
labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
|
||||
return labels
|
||||
|
||||
|
||||
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
|
||||
"""2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
|
||||
pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
|
||||
return make_att_2d_masks(pad, prefix_att_masks)[0]
|
||||
|
||||
|
||||
def test_locate_lang_range_anchors_on_state_token():
|
||||
"""``_locate_lang_range`` finds the lang span via the lone att=1 token."""
|
||||
lang_start, lang_end = _locate_lang_range(
|
||||
_embed_prefix_att_masks(), num_lang=N_PROMPT + N_TARGET
|
||||
)
|
||||
assert (lang_start, lang_end) == (LANG_START, LANG_END)
|
||||
|
||||
|
||||
def test_mark_sets_att_on_targets_only():
|
||||
"""Only the supervised target language positions flip to att=1."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
expected = [False] * PREFIX_LEN
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END): # target span
|
||||
expected[i] = True
|
||||
expected[LANG_END] = True # state token, untouched
|
||||
assert marked[0].tolist() == expected
|
||||
|
||||
|
||||
def test_target_tokens_attend_causally_among_themselves():
|
||||
"""A target token must NOT attend to later targets, but must attend
|
||||
to earlier ones — i.e. genuine causal next-token prediction."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
tgt = range(LANG_START + N_PROMPT, LANG_END)
|
||||
for i in tgt:
|
||||
for j in tgt:
|
||||
if j > i:
|
||||
assert not attends[i, j], f"target {i} must not see future target {j}"
|
||||
else:
|
||||
assert attends[i, j], f"target {i} must see earlier/self target {j}"
|
||||
|
||||
|
||||
def test_target_tokens_attend_prompt_and_images_bidirectionally():
|
||||
"""Targets keep full visibility of images + the user prompt."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END):
|
||||
for j in context:
|
||||
assert attends[i, j], f"target {i} must attend context {j}"
|
||||
|
||||
|
||||
def test_action_expert_token_still_sees_full_subtask():
|
||||
"""The state token (action-expert context) attends to every target —
|
||||
causal masking the targets must not hide them from the action path."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
for j in range(LANG_START + N_PROMPT, LANG_END):
|
||||
assert attends[LANG_END, j], f"state token must see target {j}"
|
||||
|
||||
|
||||
def test_non_target_subtask_stays_bidirectional():
|
||||
"""``low_level_execution`` renders the subtask as a user turn — its
|
||||
``text_labels`` are all -100, so the mask must be left untouched and
|
||||
the action expert reads the subtask bidirectionally."""
|
||||
all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
|
||||
)
|
||||
assert torch.equal(marked, _embed_prefix_att_masks())
|
||||
|
||||
|
||||
def test_unmarked_mask_is_bidirectional_the_bug():
|
||||
"""Documents the bug the fix prevents: without ``_mark_target_span_causal``
|
||||
a target token attends *bidirectionally* to later targets — the
|
||||
text-CE can copy the answer it is trained to predict."""
|
||||
attends = _attends(_embed_prefix_att_masks())
|
||||
first_tgt = LANG_START + N_PROMPT
|
||||
last_tgt = LANG_END - 1
|
||||
assert attends[first_tgt, last_tgt], (
|
||||
"raw embed_prefix mask is bidirectional over language — the first "
|
||||
"target token can see the last, which is the collapse bug"
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening.
|
||||
|
||||
``_split_plan_and_say`` (inference) expects the model to emit a textual
|
||||
``<say>...</say>`` marker. ``_flatten_say_tool_calls`` is the training-time
|
||||
serializer that produces it: it rewrites an assistant turn's structured
|
||||
``say`` tool call into that marker *inside the content text*, before
|
||||
``apply_chat_template`` runs — so the chat template only tokenizes plain
|
||||
text and the supervised target span trains the model to emit the marker
|
||||
the runtime parses back. These tests pin the round-trip.
|
||||
"""
|
||||
|
||||
from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls
|
||||
from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say
|
||||
|
||||
|
||||
def _say_call(text):
|
||||
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
|
||||
|
||||
|
||||
def test_flatten_appends_say_marker_and_drops_tool_calls():
|
||||
msg = {"role": "assistant", "content": "Pick up the blue cube.", "tool_calls": [_say_call("On it!")]}
|
||||
out = flatten_say_tool_calls(msg)
|
||||
assert "tool_calls" not in out
|
||||
assert out["content"] == "Pick up the blue cube.\n<say>On it!</say>"
|
||||
|
||||
|
||||
def test_flatten_roundtrips_through_inference_parser():
|
||||
"""The marker the serializer writes must be exactly what the inference
|
||||
parser reads back — this is the train/inference contract."""
|
||||
msg = {"role": "assistant", "content": "Move toward the cube.", "tool_calls": [_say_call("Working on it")]}
|
||||
flat = flatten_say_tool_calls(msg)["content"]
|
||||
plan, speech = _split_plan_and_say(flat)
|
||||
assert plan == "Move toward the cube."
|
||||
assert speech == "Working on it"
|
||||
|
||||
|
||||
def test_flatten_accepts_json_string_arguments():
|
||||
"""``arguments`` may arrive as a JSON string rather than a dict."""
|
||||
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
|
||||
out = flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
|
||||
assert out["content"] == "p\n<say>hello there</say>"
|
||||
|
||||
|
||||
def test_flatten_leaves_messages_without_tool_calls_untouched():
|
||||
msg = {"role": "assistant", "content": "just a plan"}
|
||||
assert flatten_say_tool_calls(msg) == msg
|
||||
|
||||
|
||||
def test_flatten_drops_empty_or_non_say_tool_calls():
|
||||
"""A non-``say`` call (or empty text) leaves content alone but still
|
||||
strips the structured calls so the template renders no JSON block."""
|
||||
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
|
||||
out = flatten_say_tool_calls({"role": "assistant", "content": "plan only", "tool_calls": [weather]})
|
||||
assert out["content"] == "plan only"
|
||||
assert "tool_calls" not in out
|
||||
|
||||
|
||||
def test_flatten_marker_only_when_content_empty():
|
||||
msg = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]}
|
||||
out = flatten_say_tool_calls(msg)
|
||||
assert out["content"] == "<say>hi</say>"
|
||||
@@ -1,228 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
||||
|
||||
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
||||
keypoint overlay drawing — the pure functions, no model load.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.smolvla2.inference.vqa import (
|
||||
answer_has_overlay,
|
||||
available_cameras,
|
||||
camera_short_name,
|
||||
draw_vqa_overlay,
|
||||
observation_image_to_pil,
|
||||
parse_vqa_answer,
|
||||
prompt_camera_choice,
|
||||
)
|
||||
|
||||
PIL = pytest.importorskip("PIL")
|
||||
from PIL import Image # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_available_cameras_extracts_and_sorts_image_keys():
|
||||
observation = {
|
||||
"observation.images.wrist": object(),
|
||||
"observation.state": object(),
|
||||
"observation.images.top": object(),
|
||||
"task": "x",
|
||||
}
|
||||
assert available_cameras(observation) == [
|
||||
"observation.images.top",
|
||||
"observation.images.wrist",
|
||||
]
|
||||
|
||||
|
||||
def test_available_cameras_handles_none_and_empty():
|
||||
assert available_cameras(None) == []
|
||||
assert available_cameras({}) == []
|
||||
|
||||
|
||||
def test_camera_short_name_strips_prefix():
|
||||
assert camera_short_name("observation.images.top") == "top"
|
||||
assert camera_short_name("top") == "top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_single_camera_auto_selects():
|
||||
cams = ["observation.images.top"]
|
||||
# input_fn must never be called for a single-camera setup.
|
||||
chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_number():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.wrist"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_name():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_invalid_returns_none():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None
|
||||
|
||||
|
||||
def _boom(*_args, **_kwargs):
|
||||
raise AssertionError("input_fn should not be called")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_bbox_answer():
|
||||
answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_keypoint_answer():
|
||||
answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_count_answer_is_not_an_overlay():
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
assert parsed["kind"] == "count"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_invalid_json_returns_none():
|
||||
assert parse_vqa_answer("not json at all") is None
|
||||
assert parse_vqa_answer("") is None
|
||||
# A JSON array is valid JSON but not a VQA answer object.
|
||||
assert parse_vqa_answer("[1, 2, 3]") is None
|
||||
|
||||
|
||||
def test_parse_unknown_shape():
|
||||
parsed = parse_vqa_answer('{"weird": "payload"}')
|
||||
assert parsed["kind"] == "unknown"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _blank(size=(160, 120)):
|
||||
return Image.new("RGB", size, (0, 0, 0))
|
||||
|
||||
|
||||
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer(
|
||||
'{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
)
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_keypoint_overlay_changes_pixels():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_non_spatial_leaves_image_unchanged():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.tobytes() == img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_tolerates_malformed_coordinates():
|
||||
img = _blank()
|
||||
# bbox with the wrong arity must not raise.
|
||||
out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}})
|
||||
assert out.size == img.size
|
||||
|
||||
|
||||
def test_observation_image_to_pil_from_batched_float_array():
|
||||
# (1, C, H, W) float array in [0, 1], the runtime observation shape.
|
||||
arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
|
||||
pil = observation_image_to_pil(arr)
|
||||
assert pil.size == (32, 24)
|
||||
assert pil.mode == "RGB"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_loc_keypoint_answer():
|
||||
# <locY><locX> label — y=512/1023≈0.5, x=256/1023≈0.25
|
||||
parsed = parse_vqa_answer("<loc0512><loc0256> blue cube")
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert parsed["normalized"] is True
|
||||
x, y = parsed["payload"]["point"]
|
||||
assert 0.24 < x < 0.26
|
||||
assert 0.49 < y < 0.51
|
||||
assert parsed["payload"]["label"] == "blue cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_bbox_answer():
|
||||
# <locY0><locX0><locY1><locX1> label
|
||||
parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert parsed["normalized"] is True
|
||||
det = parsed["payload"]["detections"][0]
|
||||
x1, y1, x2, y2 = det["bbox"]
|
||||
assert x1 < x2 and y1 < y2
|
||||
assert det["label"] == "yellow cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_multiple_boxes():
|
||||
answer = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert len(parsed["payload"]["detections"]) == 2
|
||||
|
||||
|
||||
def test_parse_loc_takes_precedence_over_json():
|
||||
# An answer with <loc> tokens is parsed as loc even if JSON-ish.
|
||||
assert parse_vqa_answer('{"x": <loc0001><loc0002>}')["normalized"] is True
|
||||
|
||||
|
||||
def test_draw_loc_overlay_denormalizes_to_pixels():
|
||||
img = _blank((200, 100))
|
||||
parsed = parse_vqa_answer("<loc0511><loc0511> cube") # ~centre
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
Reference in New Issue
Block a user