feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)

VQA annotations are sparse, so VQA was badly underrepresented in training:
its effective share was weight x density, and blend draws that picked an
ask_vqa* sub-recipe for a non-VQA frame were wasted entirely.

Two pieces:

1. Recipe-side consumption (language_render.py): render_sample now routes
   any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe,
   regardless of the weighted blend draw. No VQA annotation is wasted and no
   draw lands on a non-renderable VQA recipe — VQA's recipe-side share now
   equals the VQA-annotation density.

2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction):
   a new weighted, episode-aware sampler draws frames with replacement by
   per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the
   train script scans language_events, weights VQA frames so they make up
   ~that fraction of the training stream, and uses the weighted sampler. This
   is what actually lets VQA exceed its natural density. Default None keeps
   uniform episode-aware sampling unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:30:00 +02:00
parent b319ccf688
commit fbcb9225f5
7 changed files with 343 additions and 51 deletions
+47 -1
View File
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
)
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
@@ -137,3 +137,49 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- WeightedEpisodeAwareSampler --------------------------------------------
def test_weighted_sampler_respects_episode_drop_and_length():
"""The episode-boundary frame filtering is applied before weighting,
and one epoch still yields ``len(indices)`` samples."""
# One episode, 10 frames; drop the last 2.
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
assert sampler.indices == list(range(8))
assert len(sampler) == 8
draws = list(sampler)
assert len(draws) == 8
# Dropped frames 8 and 9 must never be sampled.
assert all(d in set(range(8)) for d in draws)
def test_weighted_sampler_oversamples_high_weight_frames():
"""A heavily-weighted frame dominates the draws."""
torch.manual_seed(0)
# 100 frames, frame 7 is weighted 1000x.
weights = torch.ones(100)
weights[7] = 1000.0
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
counts = {}
for _ in range(20): # 20 epochs
for d in sampler:
counts[d] = counts.get(d, 0) + 1
total = sum(counts.values())
# Frame 7 should be the overwhelming majority of the 2000 draws.
assert counts.get(7, 0) / total > 0.9
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
"""If every surviving frame has zero weight, sampling is uniform
rather than crashing."""
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
draws = set(sampler)
assert draws.issubset(set(range(6)))
assert len(list(sampler)) == 6
def test_weighted_sampler_rejects_short_weight_vector():
with pytest.raises(ValueError, match="frame_weights"):
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))