mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)
VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -72,6 +72,14 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
peft: PeftConfig | None = None
|
peft: PeftConfig | None = None
|
||||||
|
|
||||||
|
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||||
|
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||||
|
# carrying a `vqa` language annotation often enough that they make
|
||||||
|
# up roughly this fraction of the training stream. VQA annotations
|
||||||
|
# are typically sparse, so without this they are underrepresented.
|
||||||
|
# `None` (default) keeps uniform episode-aware sampling.
|
||||||
|
vqa_target_fraction: float | None = None
|
||||||
|
|
||||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||||
use_rabc: bool = False # Enable reward-weighted training
|
use_rabc: bool = False # Enable reward-weighted training
|
||||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ from .language import (
|
|||||||
from .lerobot_dataset import LeRobotDataset
|
from .lerobot_dataset import LeRobotDataset
|
||||||
from .multi_dataset import MultiLeRobotDataset
|
from .multi_dataset import MultiLeRobotDataset
|
||||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from .sampler import EpisodeAwareSampler
|
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||||
from .streaming_dataset import StreamingLeRobotDataset
|
from .streaming_dataset import StreamingLeRobotDataset
|
||||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||||
from .video_utils import VideoEncodingManager
|
from .video_utils import VideoEncodingManager
|
||||||
@@ -75,6 +75,7 @@ __all__ = [
|
|||||||
"DEFAULT_QUANTILES",
|
"DEFAULT_QUANTILES",
|
||||||
"EVENT_ONLY_STYLES",
|
"EVENT_ONLY_STYLES",
|
||||||
"EpisodeAwareSampler",
|
"EpisodeAwareSampler",
|
||||||
|
"WeightedEpisodeAwareSampler",
|
||||||
"LANGUAGE_EVENTS",
|
"LANGUAGE_EVENTS",
|
||||||
"LANGUAGE_PERSISTENT",
|
"LANGUAGE_PERSISTENT",
|
||||||
"LeRobotDataset",
|
"LeRobotDataset",
|
||||||
|
|||||||
@@ -56,13 +56,9 @@ def active_at(
|
|||||||
uniformity but is not consulted: only persistent styles are valid here.
|
uniformity but is not consulted: only persistent styles are valid here.
|
||||||
"""
|
"""
|
||||||
_validate_persistent_resolver("active_at", style)
|
_validate_persistent_resolver("active_at", style)
|
||||||
matches = _matching_rows(
|
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
matches = [row for row in matches if _timestamp(row) <= t]
|
matches = [row for row in matches if _timestamp(row) <= t]
|
||||||
return _select_latest(
|
return _select_latest(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
matches, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def emitted_at(
|
def emitted_at(
|
||||||
@@ -88,9 +84,7 @@ def emitted_at(
|
|||||||
if column == LANGUAGE_PERSISTENT:
|
if column == LANGUAGE_PERSISTENT:
|
||||||
matches = [
|
matches = [
|
||||||
row
|
row
|
||||||
for row in _matching_rows(
|
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
if _timestamp(row) == t
|
if _timestamp(row) == t
|
||||||
]
|
]
|
||||||
return _select_one(
|
return _select_one(
|
||||||
@@ -101,9 +95,7 @@ def emitted_at(
|
|||||||
camera=camera,
|
camera=camera,
|
||||||
sort_key=_persistent_sort_key,
|
sort_key=_persistent_sort_key,
|
||||||
)
|
)
|
||||||
matches = _matching_rows(
|
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
events, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
return _select_one(
|
return _select_one(
|
||||||
matches,
|
matches,
|
||||||
style=style,
|
style=style,
|
||||||
@@ -192,6 +184,29 @@ def render_sample(
|
|||||||
"""
|
"""
|
||||||
persistent_rows = _normalize_rows(persistent or [])
|
persistent_rows = _normalize_rows(persistent or [])
|
||||||
event_rows = _normalize_rows(events or [])
|
event_rows = _normalize_rows(events or [])
|
||||||
|
|
||||||
|
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||||
|
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||||
|
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||||
|
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||||
|
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||||
|
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||||
|
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||||
|
# recipe-side training share equal the VQA-annotation density (the
|
||||||
|
# maximum reachable without a dataset-level oversampling sampler).
|
||||||
|
if recipe.blend is not None:
|
||||||
|
vqa_rendered = _render_vqa_if_present(
|
||||||
|
recipe,
|
||||||
|
persistent=persistent_rows,
|
||||||
|
events=event_rows,
|
||||||
|
t=t,
|
||||||
|
sample_idx=sample_idx,
|
||||||
|
task=task,
|
||||||
|
dataset_ctx=dataset_ctx,
|
||||||
|
)
|
||||||
|
if vqa_rendered is not None:
|
||||||
|
return vqa_rendered
|
||||||
|
|
||||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||||
bindings = _resolve_bindings(
|
bindings = _resolve_bindings(
|
||||||
selected_recipe,
|
selected_recipe,
|
||||||
@@ -205,6 +220,59 @@ def render_sample(
|
|||||||
return _render_message_recipe(selected_recipe, bindings)
|
return _render_message_recipe(selected_recipe, bindings)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_vqa_if_present(
    recipe: TrainingRecipe,
    *,
    persistent: Sequence[LanguageRow],
    events: Sequence[LanguageRow],
    t: float,
    sample_idx: int,
    task: str | None,
    dataset_ctx: Any | None,
) -> RenderedMessages | None:
    """Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA annotation.

    Every blend component whose name starts with ``ask_vqa`` is resolved and
    rendered against this frame. Components whose rendering comes back as
    ``None`` are ignored.

    Args:
        recipe: Blend recipe whose ``ask_vqa*`` components are tried.
        persistent: Persistent language rows, forwarded to ``_resolve_bindings``.
        events: Event language rows, forwarded to ``_resolve_bindings``.
        t: Frame timestamp, forwarded to ``_resolve_bindings``.
        sample_idx: Sample index; also seeds the deterministic pick when
            several VQA components render.
        task: Optional task string, forwarded to ``_resolve_bindings``.
        dataset_ctx: Optional dataset context, forwarded to ``_resolve_bindings``.

    Returns:
        The rendered VQA messages, or ``None`` when the recipe has no blend or
        no ``ask_vqa*`` component renders, so the caller falls back to the
        normal weighted blend.

    When several VQA sub-recipes render (e.g. a frame annotated for more than
    one camera), one is chosen deterministically from ``sample_idx`` in
    proportion to its weight. A uniform pick is used only when *all* candidate
    weights are zero; a zero-weight candidate is never chosen while another
    candidate has positive weight.
    """
    # Explicit guard rather than ``assert``: asserts are stripped under ``-O``.
    if recipe.blend is None:
        return None

    renderable: list[tuple[float, RenderedMessages]] = []
    for name, component in recipe.blend.items():
        if not name.startswith("ask_vqa"):
            continue
        bindings = _resolve_bindings(
            component,
            persistent=persistent,
            events=events,
            t=t,
            sample_idx=sample_idx,
            task=task,
            dataset_ctx=dataset_ctx,
        )
        rendered = _render_message_recipe(component, bindings)
        if rendered is not None:
            renderable.append((float(component.weight or 0.0), rendered))

    if not renderable:
        return None
    if len(renderable) == 1:
        return renderable[0][1]

    # Uniform fallback applies only when no candidate has positive weight.
    # Mixing per-item fallbacks with real weights would hand zero-weight
    # candidates probability mass taken from the weighted ones.
    total = sum(w for w, _ in renderable)
    if total <= 0.0:
        renderable = [(1.0, rendered) for _, rendered in renderable]
        total = float(len(renderable))

    # blake2b keyed on sample_idx keeps the pick reproducible across runs.
    digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
    draw = int.from_bytes(digest, "big") / 2**64 * total

    cumulative = 0.0
    chosen = renderable[-1][1]
    for w, rendered in renderable:
        if w <= 0.0:
            continue
        cumulative += w
        # Remember the last positive-weight candidate so float rounding at
        # the top of the range can never land on a zero-weight one.
        chosen = rendered
        if draw < cumulative:
            return rendered
    return chosen
|
||||||
|
|
||||||
|
|
||||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||||
if recipe.blend is None:
|
if recipe.blend is None:
|
||||||
@@ -239,9 +307,7 @@ def _resolve_bindings(
|
|||||||
) -> dict[str, LanguageRow | str | None]:
|
) -> dict[str, LanguageRow | str | None]:
|
||||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||||
bindings: dict[str, LanguageRow | str | None] = {
|
bindings: dict[str, LanguageRow | str | None] = {
|
||||||
"task": _resolve_task(
|
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||||
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||||
for name, spec in specs.items():
|
for name, spec in specs.items():
|
||||||
@@ -275,18 +341,12 @@ def _resolve_task(
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
aug_rows = [
|
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||||
r
|
|
||||||
for r in persistent
|
|
||||||
if r.get("style") == "task_aug" and r.get("role") == "user"
|
|
||||||
]
|
|
||||||
if aug_rows:
|
if aug_rows:
|
||||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||||
# is process-randomized).
|
# is process-randomized).
|
||||||
digest = hashlib.blake2b(
|
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||||
f"task_aug:{sample_idx}".encode(), digest_size=8
|
|
||||||
).digest()
|
|
||||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||||
chosen = aug_rows[idx].get("content")
|
chosen = aug_rows[idx].get("content")
|
||||||
if chosen:
|
if chosen:
|
||||||
@@ -444,10 +504,7 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
|||||||
# Valid iff it supervises something: a text-CE target turn OR a
|
# Valid iff it supervises something: a text-CE target turn OR a
|
||||||
# ``low_level`` stream turn (flow / action supervision).
|
# ``low_level`` stream turn (flow / action supervision).
|
||||||
if not target_indices and not any(s == "low_level" for s in streams):
|
if not target_indices and not any(s == "low_level" for s in streams):
|
||||||
raise ValueError(
|
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||||
"Rendered samples must contain a target message or a "
|
|
||||||
"low_level-stream message."
|
|
||||||
)
|
|
||||||
for idx in target_indices:
|
for idx in target_indices:
|
||||||
if idx < 0 or idx >= len(messages):
|
if idx < 0 or idx >= len(messages):
|
||||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||||
|
|||||||
@@ -84,3 +84,66 @@ class EpisodeAwareSampler:
|
|||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.indices)
|
return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
    """``EpisodeAwareSampler`` that draws frames *with replacement* in
    proportion to per-frame weights.

    Used to oversample frames carrying a sparse annotation (e.g. a VQA
    question) so the policy sees them more often than their natural
    dataset density. One epoch still yields ``len(self.indices)``
    samples — the weights only change the *composition* of the stream,
    not its length. Every call to ``__iter__`` re-draws, so the
    oversampled subset varies from epoch to epoch.
    """

    # ``torch.multinomial`` rejects inputs with more than 2**24 categories;
    # above this size ``__iter__`` switches to inverse-CDF sampling.
    _MULTINOMIAL_MAX_CATEGORIES = 2**24

    def __init__(
        self,
        dataset_from_indices: list[int],
        dataset_to_indices: list[int],
        frame_weights,
        *,
        episode_indices_to_use: list | None = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
    ):
        """
        Args:
            dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
            dataset_to_indices: Episode end indices.
            frame_weights: 1-D sequence/tensor of non-negative weights, one per
                dataset frame (length == total dataset frames). Higher weight ⇒
                that frame is sampled more often.
            episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
                Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
                frame filtering is applied first, then weighting is restricted
                to the surviving frames.

        Raises:
            ValueError: If ``frame_weights`` is too short to cover the largest
                surviving frame index, or if any surviving weight is negative
                or non-finite.
        """
        super().__init__(
            dataset_from_indices,
            dataset_to_indices,
            episode_indices_to_use=episode_indices_to_use,
            drop_n_first_frames=drop_n_first_frames,
            drop_n_last_frames=drop_n_last_frames,
            shuffle=False,
        )
        weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
        idx = torch.tensor(self.indices, dtype=torch.long)
        if idx.numel() == 0:
            # No frame survived the episode filtering. ``idx.max()`` would
            # raise on an empty tensor, so keep an empty sampler instead.
            self._weights = torch.empty(0, dtype=torch.double)
            self._last_positive = -1
            return

        max_index = int(idx.max())
        if weights.numel() <= max_index:
            raise ValueError(
                f"frame_weights has {weights.numel()} entries but the sampler "
                f"references frame index {max_index}."
            )
        selected = weights[idx]
        if not torch.isfinite(selected).all() or bool((selected < 0).any()):
            raise ValueError("frame_weights must be finite and non-negative.")
        if float(selected.sum()) <= 0.0:
            # All surviving frames have zero weight — fall back to uniform.
            selected = torch.ones_like(selected)
        self._weights = selected
        # Position of the last frame that can legitimately be drawn; used to
        # clamp the inverse-CDF fallback against float rounding.
        self._last_positive = int(torch.nonzero(selected).max())

    def __iter__(self) -> Iterator[int]:
        """Yield ``len(self.indices)`` frame indices drawn with replacement."""
        n = len(self.indices)
        if n == 0:
            return
        if n <= self._MULTINOMIAL_MAX_CATEGORIES:
            picks = torch.multinomial(self._weights, num_samples=n, replacement=True)
        else:
            # Inverse-CDF sampling: ``right=True`` returns the first position
            # whose cumulative weight exceeds the draw, so zero-weight frames
            # (flat CDF segments) are never selected.
            cdf = torch.cumsum(self._weights, dim=0)
            draws = torch.rand(n, dtype=torch.double) * cdf[-1]
            picks = torch.searchsorted(cdf, draws, right=True).clamp_(max=self._last_positive)
        for i in picks.tolist():
            yield self.indices[i]
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from lerobot.common.train_utils import (
|
|||||||
from lerobot.common.wandb_utils import WandBLogger
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
|
||||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
@@ -156,6 +156,61 @@ def update_policy(
|
|||||||
return train_metrics, output_dict
|
return train_metrics, output_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
    """Compute per-frame sampling weights that boost VQA-annotated frames.

    Reads ``dataset.hf_dataset["language_events"]`` and flags every frame
    whose event rows include one with ``style == "vqa"``. Flagged frames
    share one weight and all other frames get weight 1, chosen so that under
    weighted sampling the flagged frames make up roughly ``target_fraction``
    of the draws.

    Args:
        dataset: Object exposing an ``hf_dataset`` attribute with
            ``column_names`` and column access by name.
        target_fraction: Desired share of VQA frames, strictly inside (0, 1).

    Returns:
        A ``torch.double`` tensor with one weight per row of the
        ``language_events`` column, or ``None`` when ``target_fraction`` is
        out of range, the column is missing, or no frame carries a ``vqa``
        annotation. A warning is logged in each ``None`` case.
    """
    in_range = 0.0 < target_fraction < 1.0
    if not in_range:
        logging.warning(
            "vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
            target_fraction,
        )
        return None

    hf_dataset = getattr(dataset, "hf_dataset", None)
    has_events = hf_dataset is not None and "language_events" in getattr(hf_dataset, "column_names", [])
    if not has_events:
        logging.warning("Dataset has no `language_events` column — VQA oversampling disabled.")
        return None

    events_column = hf_dataset["language_events"]
    total_frames = len(events_column)

    # Flag each frame that has at least one event row styled "vqa"; empty or
    # missing row lists and missing rows count as "no annotation".
    vqa_mask = torch.zeros(total_frames, dtype=torch.bool)
    for frame_idx, frame_rows in enumerate(events_column):
        if not frame_rows:
            continue
        for row in frame_rows:
            if (row or {}).get("style") == "vqa":
                vqa_mask[frame_idx] = True
                break

    vqa_count = int(vqa_mask.sum())
    if vqa_count == 0:
        logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
        return None
    other_count = total_frames - vqa_count

    # From target = (n_vqa·w) / (n_vqa·w + n_other), solved for w. The floor
    # of 1 means VQA frames are never weighted below the other frames.
    numerator = target_fraction * other_count
    denominator = (1.0 - target_fraction) * max(vqa_count, 1)
    vqa_weight = max(numerator / denominator, 1.0)

    frame_weights = torch.ones(total_frames, dtype=torch.double)
    frame_weights[vqa_mask] = vqa_weight
    logging.info(
        "VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
        "weighting them x%.2f to target ~%.0f%% of the training stream.",
        vqa_count,
        total_frames,
        100.0 * vqa_count / total_frames,
        vqa_weight,
        100.0 * target_fraction,
    )
    return frame_weights
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||||
"""
|
"""
|
||||||
@@ -376,9 +431,25 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||||
shuffle = False
|
shuffle = False
|
||||||
|
from_indices = dataset.meta.episodes["dataset_from_index"]
|
||||||
|
to_indices = dataset.meta.episodes["dataset_to_index"]
|
||||||
|
# When `vqa_target_fraction` is set, oversample VQA-annotated
|
||||||
|
# frames via a weighted sampler; otherwise plain episode-aware.
|
||||||
|
vqa_weights = None
|
||||||
|
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
|
||||||
|
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
|
||||||
|
if vqa_weights is not None:
|
||||||
|
sampler = WeightedEpisodeAwareSampler(
|
||||||
|
from_indices,
|
||||||
|
to_indices,
|
||||||
|
vqa_weights,
|
||||||
|
episode_indices_to_use=dataset.episodes,
|
||||||
|
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||||
|
)
|
||||||
|
else:
|
||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.meta.episodes["dataset_from_index"],
|
from_indices,
|
||||||
dataset.meta.episodes["dataset_to_index"],
|
to_indices,
|
||||||
episode_indices_to_use=dataset.episodes,
|
episode_indices_to_use=dataset.episodes,
|
||||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
|||||||
@@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views():
|
|||||||
"top": TrainingRecipe(
|
"top": TrainingRecipe(
|
||||||
weight=1.0,
|
weight=1.0,
|
||||||
bindings={
|
bindings={
|
||||||
"vqa_query": (
|
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
||||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
||||||
),
|
|
||||||
"vqa": (
|
|
||||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
messages=[
|
messages=[
|
||||||
MessageTurn(
|
MessageTurn(
|
||||||
@@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views():
|
|||||||
"wrist": TrainingRecipe(
|
"wrist": TrainingRecipe(
|
||||||
weight=1.0,
|
weight=1.0,
|
||||||
bindings={
|
bindings={
|
||||||
"vqa_query": (
|
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
||||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
||||||
),
|
|
||||||
"vqa": (
|
|
||||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
messages=[
|
messages=[
|
||||||
MessageTurn(
|
MessageTurn(
|
||||||
@@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
|||||||
assert seen == {r["content"] for r in rephrasings}
|
assert seen == {r["content"] for r in rephrasings}
|
||||||
# Same sample_idx → same pick (determinism).
|
# Same sample_idx → same pick (determinism).
|
||||||
a = render_sample(
|
a = render_sample(
|
||||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=42,
|
||||||
dataset_ctx={"task": "canonical"},
|
dataset_ctx={"task": "canonical"},
|
||||||
)
|
)
|
||||||
b = render_sample(
|
b = render_sample(
|
||||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=42,
|
||||||
dataset_ctx={"task": "canonical"},
|
dataset_ctx={"task": "canonical"},
|
||||||
)
|
)
|
||||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||||
@@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target():
|
|||||||
assert rendered["target_message_indices"] == []
|
assert rendered["target_message_indices"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_vqa_frame_is_consumed_over_the_weighted_blend():
    """VQA-annotated frames always render the ``ask_vqa*`` branch.

    The VQA branch has a tiny blend weight (0.01) next to a dominant subtask
    branch (0.99). A frame with a VQA event must still render VQA for every
    ``sample_idx``; a frame without one must use the regular blend.
    """
    subtask_branch = TrainingRecipe(
        weight=0.99,
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
        ],
    )
    vqa_branch = TrainingRecipe(
        weight=0.01,
        bindings={
            "vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
            "vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
        },
        messages=[
            MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
            MessageTurn(
                role="assistant",
                content="${vqa}",
                stream="high_level",
                target=True,
                if_present="vqa",
            ),
        ],
    )
    recipe = TrainingRecipe(blend={"high_level_subtask": subtask_branch, "ask_vqa_top": vqa_branch})

    # With a vqa event on the frame, every sample index renders the VQA answer.
    for sample_idx in range(20):
        with_vqa = render_sample(
            recipe=recipe,
            persistent=PERSISTENT,
            events=EVENTS_AT_1,
            t=1.0,
            sample_idx=sample_idx,
            task="x",
        )
        assert with_vqa["messages"][-1]["content"] == '{"count": 2}', sample_idx

    # Without a vqa event the weighted blend decides, as before.
    without_vqa = render_sample(
        recipe=recipe,
        persistent=PERSISTENT,
        events=[],
        t=1.0,
        sample_idx=0,
        task="x",
    )
    assert without_vqa["messages"][-1]["content"] == "a subtask"
|
||||||
|
|
||||||
|
|
||||||
def test_canonical_recipe_can_render_low_level_branch():
|
def test_canonical_recipe_can_render_low_level_branch():
|
||||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
|
|||||||
from lerobot.datasets.io_utils import (
|
from lerobot.datasets.io_utils import (
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||||
|
|
||||||
|
|
||||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||||
@@ -137,3 +137,49 @@ def test_partial_episode_drop_warns(caplog):
|
|||||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||||
assert sampler.indices == [2, 3, 4, 5]
|
assert sampler.indices == [2, 3, 4, 5]
|
||||||
assert "Episode 0" in caplog.text
|
assert "Episode 0" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
# --- WeightedEpisodeAwareSampler --------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_sampler_respects_episode_drop_and_length():
    """Frame dropping happens before weighting, and an epoch has
    ``len(indices)`` draws."""
    # Single 10-frame episode with its last 2 frames dropped.
    kept_frames = list(range(8))
    sampler = WeightedEpisodeAwareSampler(
        [0],
        [10],
        frame_weights=torch.ones(10),
        drop_n_last_frames=2,
    )
    assert sampler.indices == kept_frames
    assert len(sampler) == 8

    epoch = list(sampler)
    assert len(epoch) == 8
    # Frames 8 and 9 were dropped, so they must never be drawn.
    allowed = set(kept_frames)
    assert all(frame in allowed for frame in epoch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_sampler_oversamples_high_weight_frames():
    """A frame weighted far above the rest accounts for most draws."""
    torch.manual_seed(0)
    # 100 frames with frame 7 weighted 1000x.
    frame_weights = torch.ones(100)
    frame_weights[7] = 1000.0
    sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=frame_weights)

    # Collect 20 epochs of draws (2000 in total).
    all_draws = []
    for _ in range(20):
        all_draws.extend(sampler)

    # Frame 7 should make up the overwhelming majority.
    assert all_draws.count(7) / len(all_draws) > 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
    """All-zero weights on the surviving frames must not crash; the sampler
    still produces a full epoch of valid frame indices."""
    all_zero = torch.zeros(6)
    sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=all_zero)

    seen = set(sampler)
    assert seen <= set(range(6))

    second_epoch = list(sampler)
    assert len(second_epoch) == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_weighted_sampler_rejects_short_weight_vector():
    """A weight vector shorter than the frame range is rejected."""
    # 10 frames in the episode but only 5 weights supplied.
    too_short = torch.ones(5)
    with pytest.raises(ValueError, match="frame_weights"):
        WeightedEpisodeAwareSampler([0], [10], frame_weights=too_short)
|
||||||
|
|||||||
Reference in New Issue
Block a user