feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)

VQA annotations are sparse, so VQA was badly underrepresented in training:
its effective share was weight x density, and blend draws that picked an
ask_vqa* sub-recipe for a non-VQA frame were wasted entirely.

Two pieces:

1. Recipe-side consumption (language_render.py): render_sample now routes
   any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe,
   regardless of the weighted blend draw. No VQA annotation is wasted and no
   draw lands on a non-renderable VQA recipe — VQA's recipe-side share now
   equals the VQA-annotation density.

2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction):
   a new weighted, episode-aware sampler draws frames with replacement by
   per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the
   train script scans language_events, weights VQA frames so they make up
   ~that fraction of the training stream, and uses the weighted sampler. This
   is what actually lets VQA exceed its natural density. Default None keeps
   uniform episode-aware sampling unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:30:00 +02:00
parent b319ccf688
commit fbcb9225f5
7 changed files with 343 additions and 51 deletions
+8
View File
@@ -72,6 +72,14 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
# VQA oversampling. When set (a fraction in (0, 1)), the training
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
# carrying a `vqa` language annotation often enough that they make
# up roughly this fraction of the training stream. VQA annotations
# are typically sparse, so without this they are underrepresented.
# `None` (default) keeps uniform episode-aware sampling.
vqa_target_fraction: float | None = None
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
+2 -1
View File
@@ -47,7 +47,7 @@ from .language import (
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .sampler import EpisodeAwareSampler
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -75,6 +75,7 @@ __all__ = [
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"WeightedEpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
+84 -27
View File
@@ -56,13 +56,9 @@ def active_at(
uniformity but is not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style)
matches = _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
matches = [row for row in matches if _timestamp(row) <= t]
return _select_latest(
matches, style=style, role=role, tool_name=tool_name, camera=camera
)
return _select_latest(matches, style=style, role=role, tool_name=tool_name, camera=camera)
def emitted_at(
@@ -88,9 +84,7 @@ def emitted_at(
if column == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) == t
]
return _select_one(
@@ -101,9 +95,7 @@ def emitted_at(
camera=camera,
sort_key=_persistent_sort_key,
)
matches = _matching_rows(
events, style=style, role=role, tool_name=tool_name, camera=camera
)
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
return _select_one(
matches,
style=style,
@@ -192,6 +184,29 @@ def render_sample(
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
# VQA-priority routing. A ``vqa`` annotation is sparse and
# view-dependent; the plain weighted blend would (a) waste a draw
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
# sub-recipes and *this* frame carries one of their VQA bindings,
# render VQA here regardless of the weighted draw. That makes VQA's
# recipe-side training share equal the VQA-annotation density (the
# maximum reachable without a dataset-level oversampling sampler).
if recipe.blend is not None:
vqa_rendered = _render_vqa_if_present(
recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
if vqa_rendered is not None:
return vqa_rendered
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
@@ -205,6 +220,59 @@ def render_sample(
return _render_message_recipe(selected_recipe, bindings)
def _render_vqa_if_present(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> RenderedMessages | None:
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
annotation; otherwise return ``None`` so the caller falls back to the
normal weighted blend.
When several VQA sub-recipes resolve (e.g. a frame annotated for more
than one camera), one is chosen deterministically by relative weight.
"""
assert recipe.blend is not None
renderable: list[tuple[float, RenderedMessages]] = []
for name, component in recipe.blend.items():
if not name.startswith("ask_vqa"):
continue
bindings = _resolve_bindings(
component,
persistent=persistent,
events=events,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
rendered = _render_message_recipe(component, bindings)
if rendered is not None:
renderable.append((float(component.weight or 0.0), rendered))
if not renderable:
return None
if len(renderable) == 1:
return renderable[0][1]
# Multiple cameras have a VQA for this frame — deterministic pick by
# relative weight (fall back to a uniform draw if all weights are 0).
total = sum(w for w, _ in renderable) or float(len(renderable))
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total
cumulative = 0.0
for w, rendered in renderable:
cumulative += w or (total / len(renderable))
if draw < cumulative:
return rendered
return renderable[-1][1]
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
@@ -239,9 +307,7 @@ def _resolve_bindings(
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
),
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
@@ -275,18 +341,12 @@ def _resolve_task(
if task is not None:
return task
aug_rows = [
r
for r in persistent
if r.get("style") == "task_aug" and r.get("role") == "user"
]
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(
f"task_aug:{sample_idx}".encode(), digest_size=8
).digest()
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
@@ -444,10 +504,7 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
# Valid iff it supervises something: a text-CE target turn OR a
# ``low_level`` stream turn (flow / action supervision).
if not target_indices and not any(s == "low_level" for s in streams):
raise ValueError(
"Rendered samples must contain a target message or a "
"low_level-stream message."
)
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
+63
View File
@@ -84,3 +84,66 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
proportion to per-frame weights.
Used to oversample frames carrying a sparse annotation (e.g. a VQA
question) so the policy sees them more often than their natural
dataset density. One epoch still yields ``len(self.indices)``
samples — the weights only change the *composition* of the stream,
not its length. Each epoch re-draws, so the oversampled subset
varies run to run.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
frame_weights,
*,
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""
Args:
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
dataset_to_indices: Episode end indices.
frame_weights: 1-D sequence/tensor of non-negative weights, one per
dataset frame (length == total dataset frames). Higher weight ⇒
that frame is sampled more often.
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
frame filtering is applied first, then weighting is restricted
to the surviving frames.
"""
super().__init__(
dataset_from_indices,
dataset_to_indices,
episode_indices_to_use=episode_indices_to_use,
drop_n_first_frames=drop_n_first_frames,
drop_n_last_frames=drop_n_last_frames,
shuffle=False,
)
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
idx = torch.tensor(self.indices, dtype=torch.long)
if weights.numel() <= int(idx.max()):
raise ValueError(
f"frame_weights has {weights.numel()} entries but the sampler "
f"references frame index {int(idx.max())}."
)
selected = weights[idx]
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
raise ValueError("frame_weights must be finite and non-negative.")
if float(selected.sum()) <= 0.0:
# All surviving frames have zero weight — fall back to uniform.
selected = torch.ones_like(selected)
self._weights = selected
def __iter__(self) -> Iterator[int]:
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
for i in picks.tolist():
yield self.indices[i]
+79 -8
View File
@@ -43,7 +43,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -156,6 +156,61 @@ def update_policy(
return train_metrics, output_dict
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
"""Build per-frame sampling weights that oversample VQA-annotated frames.
Scans the dataset's ``language_events`` column for frames carrying a
``vqa``-style annotation and returns a weight tensor (length == total
dataset frames) such that, under multinomial sampling, VQA frames make up
roughly ``target_fraction`` of the training stream.
Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA
frames cannot be detected or there are none.
"""
if not 0.0 < target_fraction < 1.0:
logging.warning(
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
target_fraction,
)
return None
hf = getattr(dataset, "hf_dataset", None)
if hf is None or "language_events" not in getattr(hf, "column_names", []):
logging.warning(
"Dataset has no `language_events` column — VQA oversampling disabled."
)
return None
events_col = hf["language_events"]
n_frames = len(events_col)
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
for i, rows in enumerate(events_col):
if rows and any((row or {}).get("style") == "vqa" for row in rows):
is_vqa[i] = True
n_vqa = int(is_vqa.sum())
if n_vqa == 0:
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
return None
n_other = n_frames - n_vqa
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
weight = max(weight, 1.0)
weights = torch.ones(n_frames, dtype=torch.double)
weights[is_vqa] = weight
logging.info(
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
"weighting them x%.2f to target ~%.0f%% of the training stream.",
n_vqa,
n_frames,
100.0 * n_vqa / n_frames,
weight,
100.0 * target_fraction,
)
return weights
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
"""
@@ -376,13 +431,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
from_indices = dataset.meta.episodes["dataset_from_index"]
to_indices = dataset.meta.episodes["dataset_to_index"]
# When `vqa_target_fraction` is set, oversample VQA-annotated
# frames via a weighted sampler; otherwise plain episode-aware.
vqa_weights = None
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
if vqa_weights is not None:
sampler = WeightedEpisodeAwareSampler(
from_indices,
to_indices,
vqa_weights,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
)
else:
sampler = EpisodeAwareSampler(
from_indices,
to_indices,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
+60 -14
View File
@@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views():
"top": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
),
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
},
messages=[
MessageTurn(
@@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views():
"wrist": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
),
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
},
messages=[
MessageTurn(
@@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample():
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
@@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target():
assert rendered["target_message_indices"] == []
def test_vqa_frame_is_consumed_over_the_weighted_blend():
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
even when its blend weight is tiny — VQA annotations are sparse and must
never be wasted on a subtask/action draw."""
recipe = TrainingRecipe(
blend={
"high_level_subtask": TrainingRecipe(
weight=0.99,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
],
),
"ask_vqa_top": TrainingRecipe(
weight=0.01,
bindings={
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
},
messages=[
MessageTurn(
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
}
)
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
# ask_vqa weight being only 0.01.
for sample_idx in range(20):
rendered = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
)
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
assert rendered["messages"][-1]["content"] == "a subtask"
def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
+47 -1
View File
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
)
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
@@ -137,3 +137,49 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- WeightedEpisodeAwareSampler --------------------------------------------
def test_weighted_sampler_respects_episode_drop_and_length():
"""The episode-boundary frame filtering is applied before weighting,
and one epoch still yields ``len(indices)`` samples."""
# One episode, 10 frames; drop the last 2.
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
assert sampler.indices == list(range(8))
assert len(sampler) == 8
draws = list(sampler)
assert len(draws) == 8
# Dropped frames 8 and 9 must never be sampled.
assert all(d in set(range(8)) for d in draws)
def test_weighted_sampler_oversamples_high_weight_frames():
"""A heavily-weighted frame dominates the draws."""
torch.manual_seed(0)
# 100 frames, frame 7 is weighted 1000x.
weights = torch.ones(100)
weights[7] = 1000.0
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
counts = {}
for _ in range(20): # 20 epochs
for d in sampler:
counts[d] = counts.get(d, 0) + 1
total = sum(counts.values())
# Frame 7 should be the overwhelming majority of the 2000 draws.
assert counts.get(7, 0) / total > 0.9
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
"""If every surviving frame has zero weight, sampling is uniform
rather than crashing."""
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
draws = set(sampler)
assert draws.issubset(set(range(6)))
assert len(list(sampler)) == 6
def test_weighted_sampler_rejects_short_weight_vector():
with pytest.raises(ValueError, match="frame_weights"):
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))