mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)
VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views():
|
||||
"top": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
),
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
@@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views():
|
||||
"wrist": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
),
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
@@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
assert seen == {r["content"] for r in rephrasings}
|
||||
# Same sample_idx → same pick (determinism).
|
||||
a = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
b = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||
@@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target():
|
||||
assert rendered["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_vqa_frame_is_consumed_over_the_weighted_blend():
|
||||
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
|
||||
even when its blend weight is tiny — VQA annotations are sparse and must
|
||||
never be wasted on a subtask/action draw."""
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"high_level_subtask": TrainingRecipe(
|
||||
weight=0.99,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
"ask_vqa_top": TrainingRecipe(
|
||||
weight=0.01,
|
||||
bindings={
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
|
||||
# ask_vqa weight being only 0.01.
|
||||
for sample_idx in range(20):
|
||||
rendered = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
|
||||
)
|
||||
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
|
||||
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
|
||||
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
|
||||
assert rendered["messages"][-1]["content"] == "a subtask"
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
|
||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||
|
||||
Reference in New Issue
Block a user