feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)

VQA annotations are sparse, so VQA was badly underrepresented in training:
its effective share was (blend weight × VQA-annotation density), and blend
draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted
entirely.

Two pieces:

1. Recipe-side consumption (language_render.py): render_sample now routes
   any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe,
   regardless of the weighted blend draw. No VQA annotation is wasted and no
   draw lands on a non-renderable VQA recipe — VQA's recipe-side share now
   equals the VQA-annotation density.

2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction):
   a new weighted, episode-aware sampler draws frames with replacement by
   per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the
   train script scans language_events, weights VQA frames so they make up
   ~that fraction of the training stream, and uses the weighted sampler. This
   is what actually lets VQA exceed its natural density. Default None keeps
   uniform episode-aware sampling unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:30:00 +02:00
parent b319ccf688
commit fbcb9225f5
7 changed files with 343 additions and 51 deletions
+60 -14
View File
@@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views():
"top": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
),
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
},
messages=[
MessageTurn(
@@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views():
"wrist": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
),
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
},
messages=[
MessageTurn(
@@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample():
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
@@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target():
assert rendered["target_message_indices"] == []
def test_vqa_frame_is_consumed_over_the_weighted_blend():
    """VQA-annotated frames must render ``ask_vqa*`` regardless of the blend draw.

    The blend gives the subtask branch weight 0.99 and the VQA branch only
    0.01. Because VQA annotations are sparse, a frame that has one should
    always be rendered through the VQA branch rather than being spent on a
    subtask/action draw; a frame without one should still go through the
    ordinary weighted blend.
    """
    subtask_branch = TrainingRecipe(
        weight=0.99,
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
        ],
    )
    vqa_branch = TrainingRecipe(
        weight=0.01,
        bindings={
            "vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
            "vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
        },
        messages=[
            MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
            MessageTurn(
                role="assistant",
                content="${vqa}",
                stream="high_level",
                target=True,
                if_present="vqa",
            ),
        ],
    )
    recipe = TrainingRecipe(blend={"high_level_subtask": subtask_branch, "ask_vqa_top": vqa_branch})

    # Arguments shared by every render call below.
    common = {"recipe": recipe, "persistent": PERSISTENT, "t": 1.0, "task": "x"}

    # With the EVENTS_AT_1 fixture (presumably holding the top-camera VQA
    # exchange at t=1.0 — confirm against the fixture definition), every one
    # of 20 sample indices must end on the VQA answer despite the 0.01 weight.
    for idx in range(20):
        with_vqa = render_sample(events=EVENTS_AT_1, sample_idx=idx, **common)
        last_turn = with_vqa["messages"][-1]
        assert last_turn["content"] == '{"count": 2}', idx

    # With no events at all there is nothing to consume, so the draw for
    # sample_idx=0 is expected to land on the subtask branch.
    without_vqa = render_sample(events=[], sample_idx=0, **common)
    last_turn = without_vqa["messages"][-1]
    assert last_turn["content"] == "a subtask"
def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})