From fbcb9225f5010120e42e9d4a07a5c9ade273eb58 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 18 May 2026 15:30:00 +0200 Subject: [PATCH] feat: oversample sparse VQA annotations (recipe consumption + weighted sampler) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/configs/train.py | 8 ++ src/lerobot/datasets/__init__.py | 3 +- src/lerobot/datasets/language_render.py | 111 ++++++++++++++++++------ src/lerobot/datasets/sampler.py | 63 ++++++++++++++ src/lerobot/scripts/lerobot_train.py | 87 +++++++++++++++++-- tests/datasets/test_language_render.py | 74 +++++++++++++--- tests/datasets/test_sampler.py | 48 +++++++++- 7 files changed, 343 insertions(+), 51 deletions(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 924bcf5bb..3bdcbf6d5 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -72,6 +72,14 @@ class TrainPipelineConfig(HubMixin): wandb: WandBConfig = field(default_factory=WandBConfig) peft: PeftConfig | None = None + # VQA oversampling. When set (a fraction in (0, 1)), the training + # dataloader uses a WeightedEpisodeAwareSampler that draws frames + # carrying a `vqa` language annotation often enough that they make + # up roughly this fraction of the training stream. VQA annotations + # are typically sparse, so without this they are underrepresented. + # `None` (default) keeps uniform episode-aware sampling. + vqa_target_fraction: float | None = None + # RA-BC (Reward-Aligned Behavior Cloning) parameters use_rabc: bool = False # Enable reward-weighted training rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 067f91091..7b62d961b 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -47,7 +47,7 @@ from .language import ( from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from .sampler import EpisodeAwareSampler +from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager @@ -75,6 +75,7 @@ __all__ = [ "DEFAULT_QUANTILES", "EVENT_ONLY_STYLES", "EpisodeAwareSampler", + "WeightedEpisodeAwareSampler", "LANGUAGE_EVENTS", "LANGUAGE_PERSISTENT", "LeRobotDataset", diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 42206907c..e6c749418 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -56,13 +56,9 @@ def active_at( uniformity but is not consulted: only persistent styles are valid here. """ _validate_persistent_resolver("active_at", style) - matches = _matching_rows( - persistent, style=style, role=role, tool_name=tool_name, camera=camera - ) + matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) matches = [row for row in matches if _timestamp(row) <= t] - return _select_latest( - matches, style=style, role=role, tool_name=tool_name, camera=camera - ) + return _select_latest(matches, style=style, role=role, tool_name=tool_name, camera=camera) def emitted_at( @@ -88,9 +84,7 @@ def emitted_at( if column == LANGUAGE_PERSISTENT: matches = [ row - for row in _matching_rows( - persistent, style=style, role=role, tool_name=tool_name, camera=camera - ) + for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) if _timestamp(row) == t ] return _select_one( @@ -101,9 +95,7 @@ def emitted_at( camera=camera, sort_key=_persistent_sort_key, ) - matches = _matching_rows( - events, style=style, role=role, tool_name=tool_name, camera=camera - ) + matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera) return _select_one( matches, style=style, @@ -192,6 +184,29 @@ def render_sample( """ persistent_rows = _normalize_rows(persistent or []) event_rows = _normalize_rows(events or []) + + # VQA-priority routing. A ``vqa`` annotation is sparse and + # view-dependent; the plain weighted blend would (a) waste a draw + # whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has + # no VQA, and (b) silently drop a VQA-annotated frame whenever it + # picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*`` + # sub-recipes and *this* frame carries one of their VQA bindings, + # render VQA here regardless of the weighted draw. That makes VQA's + # recipe-side training share equal the VQA-annotation density (the + # maximum reachable without a dataset-level oversampling sampler). + if recipe.blend is not None: + vqa_rendered = _render_vqa_if_present( + recipe, + persistent=persistent_rows, + events=event_rows, + t=t, + sample_idx=sample_idx, + task=task, + dataset_ctx=dataset_ctx, + ) + if vqa_rendered is not None: + return vqa_rendered + selected_recipe = _select_recipe(recipe, sample_idx) bindings = _resolve_bindings( selected_recipe, @@ -205,6 +220,59 @@ def render_sample( return _render_message_recipe(selected_recipe, bindings) +def _render_vqa_if_present( + recipe: TrainingRecipe, + *, + persistent: Sequence[LanguageRow], + events: Sequence[LanguageRow], + t: float, + sample_idx: int, + task: str | None, + dataset_ctx: Any | None, +) -> RenderedMessages | None: + """Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA + annotation; otherwise return ``None`` so the caller falls back to the + normal weighted blend. + + When several VQA sub-recipes resolve (e.g. a frame annotated for more + than one camera), one is chosen deterministically by relative weight. + """ + assert recipe.blend is not None + renderable: list[tuple[float, RenderedMessages]] = [] + for name, component in recipe.blend.items(): + if not name.startswith("ask_vqa"): + continue + bindings = _resolve_bindings( + component, + persistent=persistent, + events=events, + t=t, + sample_idx=sample_idx, + task=task, + dataset_ctx=dataset_ctx, + ) + rendered = _render_message_recipe(component, bindings) + if rendered is not None: + renderable.append((float(component.weight or 0.0), rendered)) + + if not renderable: + return None + if len(renderable) == 1: + return renderable[0][1] + + # Multiple cameras have a VQA for this frame — deterministic pick by + # relative weight (fall back to a uniform draw if all weights are 0). + total = sum(w for w, _ in renderable) or float(len(renderable)) + digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest() + draw = int.from_bytes(digest, "big") / 2**64 * total + cumulative = 0.0 + for w, rendered in renderable: + cumulative += w or (total / len(renderable)) + if draw < cumulative: + return rendered + return renderable[-1][1] + + def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe: """Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``).""" if recipe.blend is None: @@ -239,9 +307,7 @@ def _resolve_bindings( ) -> dict[str, LanguageRow | str | None]: """Resolve every binding in ``recipe`` (plus ``task``) at time ``t``.""" bindings: dict[str, LanguageRow | str | None] = { - "task": _resolve_task( - task, dataset_ctx, persistent=persistent, sample_idx=sample_idx - ), + "task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx), } specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})} for name, spec in specs.items(): @@ -275,18 +341,12 @@ def _resolve_task( if task is not None: return task - aug_rows = [ - r - for r in persistent - if r.get("style") == "task_aug" and r.get("role") == "user" - ] + aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"] if aug_rows: # Deterministic, blake2b-based pick keyed on sample_idx so the # rotation is reproducible across runs (Python's built-in ``hash`` # is process-randomized). - digest = hashlib.blake2b( - f"task_aug:{sample_idx}".encode(), digest_size=8 - ).digest() + digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest() idx = int.from_bytes(digest, "big") % len(aug_rows) chosen = aug_rows[idx].get("content") if chosen: @@ -444,10 +504,7 @@ def _validate_rendered(rendered: RenderedMessages) -> None: # Valid iff it supervises something: a text-CE target turn OR a # ``low_level`` stream turn (flow / action supervision). if not target_indices and not any(s == "low_level" for s in streams): - raise ValueError( - "Rendered samples must contain a target message or a " - "low_level-stream message." - ) + raise ValueError("Rendered samples must contain a target message or a low_level-stream message.") for idx in target_indices: if idx < 0 or idx >= len(messages): raise ValueError(f"Target message index {idx} is out of bounds.") diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 2bf7ab922..c03194b63 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -84,3 +84,66 @@ class EpisodeAwareSampler: def __len__(self) -> int: return len(self.indices) + + +class WeightedEpisodeAwareSampler(EpisodeAwareSampler): + """``EpisodeAwareSampler`` that draws frames *with replacement* in + proportion to per-frame weights. + + Used to oversample frames carrying a sparse annotation (e.g. a VQA + question) so the policy sees them more often than their natural + dataset density. One epoch still yields ``len(self.indices)`` + samples — the weights only change the *composition* of the stream, + not its length. Each epoch re-draws, so the oversampled subset + varies run to run. + """ + + def __init__( + self, + dataset_from_indices: list[int], + dataset_to_indices: list[int], + frame_weights, + *, + episode_indices_to_use: list | None = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + ): + """ + Args: + dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``). + dataset_to_indices: Episode end indices. + frame_weights: 1-D sequence/tensor of non-negative weights, one per + dataset frame (length == total dataset frames). Higher weight ⇒ + that frame is sampled more often. + episode_indices_to_use / drop_n_first_frames / drop_n_last_frames: + Same meaning as ``EpisodeAwareSampler`` — the episode-boundary + frame filtering is applied first, then weighting is restricted + to the surviving frames. + """ + super().__init__( + dataset_from_indices, + dataset_to_indices, + episode_indices_to_use=episode_indices_to_use, + drop_n_first_frames=drop_n_first_frames, + drop_n_last_frames=drop_n_last_frames, + shuffle=False, + ) + weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten() + idx = torch.tensor(self.indices, dtype=torch.long) + if weights.numel() <= int(idx.max()): + raise ValueError( + f"frame_weights has {weights.numel()} entries but the sampler " + f"references frame index {int(idx.max())}." + ) + selected = weights[idx] + if not torch.isfinite(selected).all() or bool((selected < 0).any()): + raise ValueError("frame_weights must be finite and non-negative.") + if float(selected.sum()) <= 0.0: + # All surviving frames have zero weight — fall back to uniform. + selected = torch.ones_like(selected) + self._weights = selected + + def __iter__(self) -> Iterator[int]: + picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True) + for i in picks.tolist(): + yield self.indices[i] diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a0cd204f2..2c9b3ad56 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -43,7 +43,7 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import EpisodeAwareSampler, make_dataset +from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -156,6 +156,61 @@ def update_policy( return train_metrics, output_dict +def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None": + """Build per-frame sampling weights that oversample VQA-annotated frames. + + Scans the dataset's ``language_events`` column for frames carrying a + ``vqa``-style annotation and returns a weight tensor (length == total + dataset frames) such that, under multinomial sampling, VQA frames make up + roughly ``target_fraction`` of the training stream. + + Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA + frames cannot be detected or there are none. + """ + if not 0.0 < target_fraction < 1.0: + logging.warning( + "vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.", + target_fraction, + ) + return None + hf = getattr(dataset, "hf_dataset", None) + if hf is None or "language_events" not in getattr(hf, "column_names", []): + logging.warning( + "Dataset has no `language_events` column — VQA oversampling disabled." + ) + return None + + events_col = hf["language_events"] + n_frames = len(events_col) + is_vqa = torch.zeros(n_frames, dtype=torch.bool) + for i, rows in enumerate(events_col): + if rows and any((row or {}).get("style") == "vqa" for row in rows): + is_vqa[i] = True + + n_vqa = int(is_vqa.sum()) + if n_vqa == 0: + logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.") + return None + n_other = n_frames - n_vqa + + # Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w. + # Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform. + weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1)) + weight = max(weight, 1.0) + weights = torch.ones(n_frames, dtype=torch.double) + weights[is_vqa] = weight + logging.info( + "VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); " + "weighting them x%.2f to target ~%.0f%% of the training stream.", + n_vqa, + n_frames, + 100.0 * n_vqa / n_frames, + weight, + 100.0 * target_fraction, + ) + return weights + + @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): """ @@ -376,13 +431,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False - sampler = EpisodeAwareSampler( - dataset.meta.episodes["dataset_from_index"], - dataset.meta.episodes["dataset_to_index"], - episode_indices_to_use=dataset.episodes, - drop_n_last_frames=cfg.policy.drop_n_last_frames, - shuffle=True, - ) + from_indices = dataset.meta.episodes["dataset_from_index"] + to_indices = dataset.meta.episodes["dataset_to_index"] + # When `vqa_target_fraction` is set, oversample VQA-annotated + # frames via a weighted sampler; otherwise plain episode-aware. + vqa_weights = None + if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming: + vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction) + if vqa_weights is not None: + sampler = WeightedEpisodeAwareSampler( + from_indices, + to_indices, + vqa_weights, + episode_indices_to_use=dataset.episodes, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + ) + else: + sampler = EpisodeAwareSampler( + from_indices, + to_indices, + episode_indices_to_use=dataset.episodes, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + shuffle=True, + ) else: shuffle = True sampler = None diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index e604ede4a..f8bd7ce4f 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views(): "top": TrainingRecipe( weight=1.0, bindings={ - "vqa_query": ( - "emitted_at(t, style=vqa, role=user, camera=observation.images.top)" - ), - "vqa": ( - "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)" - ), + "vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"), + "vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"), }, messages=[ MessageTurn( @@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views(): "wrist": TrainingRecipe( weight=1.0, bindings={ - "vqa_query": ( - "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)" - ), - "vqa": ( - "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)" - ), + "vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"), + "vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"), }, messages=[ MessageTurn( @@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample(): assert seen == {r["content"] for r in rephrasings} # Same sample_idx → same pick (determinism). a = render_sample( - recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42, + recipe=recipe, + persistent=rephrasings, + events=[], + t=0.0, + sample_idx=42, dataset_ctx={"task": "canonical"}, ) b = render_sample( - recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42, + recipe=recipe, + persistent=rephrasings, + events=[], + t=0.0, + sample_idx=42, dataset_ctx={"task": "canonical"}, ) assert a["messages"][0]["content"] == b["messages"][0]["content"] @@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target(): assert rendered["target_message_indices"] == [] +def test_vqa_frame_is_consumed_over_the_weighted_blend(): + """A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe + even when its blend weight is tiny — VQA annotations are sparse and must + never be wasted on a subtask/action draw.""" + recipe = TrainingRecipe( + blend={ + "high_level_subtask": TrainingRecipe( + weight=0.99, + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True), + ], + ), + "ask_vqa_top": TrainingRecipe( + weight=0.01, + bindings={ + "vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)", + "vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)", + }, + messages=[ + MessageTurn( + role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query" + ), + MessageTurn( + role="assistant", + content="${vqa}", + stream="high_level", + target=True, + if_present="vqa", + ), + ], + ), + } + ) + # A frame WITH a vqa event renders VQA on every sample_idx, despite the + # ask_vqa weight being only 0.01. + for sample_idx in range(20): + rendered = render_sample( + recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x" + ) + assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx + # A frame WITHOUT a vqa event falls back to the normal weighted blend. + rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x") + assert rendered["messages"][-1]["content"] == "a subtask" + + def test_canonical_recipe_can_render_low_level_branch(): recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]}) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 8bb3be8e9..f7ea5aca5 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402 from lerobot.datasets.io_utils import ( hf_transform_to_torch, ) -from lerobot.datasets.sampler import EpisodeAwareSampler +from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: @@ -137,3 +137,49 @@ def test_partial_episode_drop_warns(caplog): # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 assert sampler.indices == [2, 3, 4, 5] assert "Episode 0" in caplog.text + + +# --- WeightedEpisodeAwareSampler -------------------------------------------- + + +def test_weighted_sampler_respects_episode_drop_and_length(): + """The episode-boundary frame filtering is applied before weighting, + and one epoch still yields ``len(indices)`` samples.""" + # One episode, 10 frames; drop the last 2. + sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2) + assert sampler.indices == list(range(8)) + assert len(sampler) == 8 + draws = list(sampler) + assert len(draws) == 8 + # Dropped frames 8 and 9 must never be sampled. + assert all(d in set(range(8)) for d in draws) + + +def test_weighted_sampler_oversamples_high_weight_frames(): + """A heavily-weighted frame dominates the draws.""" + torch.manual_seed(0) + # 100 frames, frame 7 is weighted 1000x. + weights = torch.ones(100) + weights[7] = 1000.0 + sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights) + counts = {} + for _ in range(20): # 20 epochs + for d in sampler: + counts[d] = counts.get(d, 0) + 1 + total = sum(counts.values()) + # Frame 7 should be the overwhelming majority of the 2000 draws. + assert counts.get(7, 0) / total > 0.9 + + +def test_weighted_sampler_zero_weights_fall_back_to_uniform(): + """If every surviving frame has zero weight, sampling is uniform + rather than crashing.""" + sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6)) + draws = set(sampler) + assert draws.issubset(set(range(6))) + assert len(list(sampler)) == 6 + + +def test_weighted_sampler_rejects_short_weight_vector(): + with pytest.raises(ValueError, match="frame_weights"): + WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))