mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)
VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -72,6 +72,14 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||
# carrying a `vqa` language annotation often enough that they make
|
||||
# up roughly this fraction of the training stream. VQA annotations
|
||||
# are typically sparse, so without this they are underrepresented.
|
||||
# `None` (default) keeps uniform episode-aware sampling.
|
||||
vqa_target_fraction: float | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
|
||||
@@ -47,7 +47,7 @@ from .language import (
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -75,6 +75,7 @@ __all__ = [
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"WeightedEpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
|
||||
@@ -56,13 +56,9 @@ def active_at(
|
||||
uniformity but is not consulted: only persistent styles are valid here.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
matches = [row for row in matches if _timestamp(row) <= t]
|
||||
return _select_latest(
|
||||
matches, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
return _select_latest(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def emitted_at(
|
||||
@@ -88,9 +84,7 @@ def emitted_at(
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_one(
|
||||
@@ -101,9 +95,7 @@ def emitted_at(
|
||||
camera=camera,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
matches = _matching_rows(
|
||||
events, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(
|
||||
matches,
|
||||
style=style,
|
||||
@@ -192,6 +184,29 @@ def render_sample(
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
|
||||
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||
# recipe-side training share equal the VQA-annotation density (the
|
||||
# maximum reachable without a dataset-level oversampling sampler).
|
||||
if recipe.blend is not None:
|
||||
vqa_rendered = _render_vqa_if_present(
|
||||
recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
if vqa_rendered is not None:
|
||||
return vqa_rendered
|
||||
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
@@ -205,6 +220,59 @@ def render_sample(
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _render_vqa_if_present(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
|
||||
annotation; otherwise return ``None`` so the caller falls back to the
|
||||
normal weighted blend.
|
||||
|
||||
When several VQA sub-recipes resolve (e.g. a frame annotated for more
|
||||
than one camera), one is chosen deterministically by relative weight.
|
||||
"""
|
||||
assert recipe.blend is not None
|
||||
renderable: list[tuple[float, RenderedMessages]] = []
|
||||
for name, component in recipe.blend.items():
|
||||
if not name.startswith("ask_vqa"):
|
||||
continue
|
||||
bindings = _resolve_bindings(
|
||||
component,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
rendered = _render_message_recipe(component, bindings)
|
||||
if rendered is not None:
|
||||
renderable.append((float(component.weight or 0.0), rendered))
|
||||
|
||||
if not renderable:
|
||||
return None
|
||||
if len(renderable) == 1:
|
||||
return renderable[0][1]
|
||||
|
||||
# Multiple cameras have a VQA for this frame — deterministic pick by
|
||||
# relative weight (fall back to a uniform draw if all weights are 0).
|
||||
total = sum(w for w, _ in renderable) or float(len(renderable))
|
||||
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total
|
||||
cumulative = 0.0
|
||||
for w, rendered in renderable:
|
||||
cumulative += w or (total / len(renderable))
|
||||
if draw < cumulative:
|
||||
return rendered
|
||||
return renderable[-1][1]
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
@@ -239,9 +307,7 @@ def _resolve_bindings(
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(
|
||||
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
||||
),
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
@@ -275,18 +341,12 @@ def _resolve_task(
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [
|
||||
r
|
||||
for r in persistent
|
||||
if r.get("style") == "task_aug" and r.get("role") == "user"
|
||||
]
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(
|
||||
f"task_aug:{sample_idx}".encode(), digest_size=8
|
||||
).digest()
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
@@ -444,10 +504,7 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
# Valid iff it supervises something: a text-CE target turn OR a
|
||||
# ``low_level`` stream turn (flow / action supervision).
|
||||
if not target_indices and not any(s == "low_level" for s in streams):
|
||||
raise ValueError(
|
||||
"Rendered samples must contain a target message or a "
|
||||
"low_level-stream message."
|
||||
)
|
||||
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
|
||||
@@ -84,3 +84,66 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
|
||||
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
|
||||
proportion to per-frame weights.
|
||||
|
||||
Used to oversample frames carrying a sparse annotation (e.g. a VQA
|
||||
question) so the policy sees them more often than their natural
|
||||
dataset density. One epoch still yields ``len(self.indices)``
|
||||
samples — the weights only change the *composition* of the stream,
|
||||
not its length. Each epoch re-draws, so the oversampled subset
|
||||
varies run to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
frame_weights,
|
||||
*,
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
|
||||
dataset_to_indices: Episode end indices.
|
||||
frame_weights: 1-D sequence/tensor of non-negative weights, one per
|
||||
dataset frame (length == total dataset frames). Higher weight ⇒
|
||||
that frame is sampled more often.
|
||||
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
|
||||
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
|
||||
frame filtering is applied first, then weighting is restricted
|
||||
to the surviving frames.
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_from_indices,
|
||||
dataset_to_indices,
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
drop_n_first_frames=drop_n_first_frames,
|
||||
drop_n_last_frames=drop_n_last_frames,
|
||||
shuffle=False,
|
||||
)
|
||||
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
|
||||
idx = torch.tensor(self.indices, dtype=torch.long)
|
||||
if weights.numel() <= int(idx.max()):
|
||||
raise ValueError(
|
||||
f"frame_weights has {weights.numel()} entries but the sampler "
|
||||
f"references frame index {int(idx.max())}."
|
||||
)
|
||||
selected = weights[idx]
|
||||
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
|
||||
raise ValueError("frame_weights must be finite and non-negative.")
|
||||
if float(selected.sum()) <= 0.0:
|
||||
# All surviving frames have zero weight — fall back to uniform.
|
||||
selected = torch.ones_like(selected)
|
||||
self._weights = selected
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
|
||||
for i in picks.tolist():
|
||||
yield self.indices[i]
|
||||
|
||||
@@ -43,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -156,6 +156,61 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
|
||||
"""Build per-frame sampling weights that oversample VQA-annotated frames.
|
||||
|
||||
Scans the dataset's ``language_events`` column for frames carrying a
|
||||
``vqa``-style annotation and returns a weight tensor (length == total
|
||||
dataset frames) such that, under multinomial sampling, VQA frames make up
|
||||
roughly ``target_fraction`` of the training stream.
|
||||
|
||||
Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA
|
||||
frames cannot be detected or there are none.
|
||||
"""
|
||||
if not 0.0 < target_fraction < 1.0:
|
||||
logging.warning(
|
||||
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
|
||||
target_fraction,
|
||||
)
|
||||
return None
|
||||
hf = getattr(dataset, "hf_dataset", None)
|
||||
if hf is None or "language_events" not in getattr(hf, "column_names", []):
|
||||
logging.warning(
|
||||
"Dataset has no `language_events` column — VQA oversampling disabled."
|
||||
)
|
||||
return None
|
||||
|
||||
events_col = hf["language_events"]
|
||||
n_frames = len(events_col)
|
||||
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
|
||||
for i, rows in enumerate(events_col):
|
||||
if rows and any((row or {}).get("style") == "vqa" for row in rows):
|
||||
is_vqa[i] = True
|
||||
|
||||
n_vqa = int(is_vqa.sum())
|
||||
if n_vqa == 0:
|
||||
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
|
||||
return None
|
||||
n_other = n_frames - n_vqa
|
||||
|
||||
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
|
||||
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
|
||||
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
|
||||
weight = max(weight, 1.0)
|
||||
weights = torch.ones(n_frames, dtype=torch.double)
|
||||
weights[is_vqa] = weight
|
||||
logging.info(
|
||||
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
|
||||
"weighting them x%.2f to target ~%.0f%% of the training stream.",
|
||||
n_vqa,
|
||||
n_frames,
|
||||
100.0 * n_vqa / n_frames,
|
||||
weight,
|
||||
100.0 * target_fraction,
|
||||
)
|
||||
return weights
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
"""
|
||||
@@ -376,13 +431,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
from_indices = dataset.meta.episodes["dataset_from_index"]
|
||||
to_indices = dataset.meta.episodes["dataset_to_index"]
|
||||
# When `vqa_target_fraction` is set, oversample VQA-annotated
|
||||
# frames via a weighted sampler; otherwise plain episode-aware.
|
||||
vqa_weights = None
|
||||
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
|
||||
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
|
||||
if vqa_weights is not None:
|
||||
sampler = WeightedEpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
vqa_weights,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
)
|
||||
else:
|
||||
sampler = EpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
@@ -207,12 +207,8 @@ def test_per_camera_blend_renders_both_views():
|
||||
"top": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
),
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
@@ -236,12 +232,8 @@ def test_per_camera_blend_renders_both_views():
|
||||
"wrist": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
),
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
@@ -319,11 +311,19 @@ def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
assert seen == {r["content"] for r in rephrasings}
|
||||
# Same sample_idx → same pick (determinism).
|
||||
a = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
b = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||
@@ -402,6 +402,52 @@ def test_flow_only_low_level_recipe_renders_without_target():
|
||||
assert rendered["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_vqa_frame_is_consumed_over_the_weighted_blend():
|
||||
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
|
||||
even when its blend weight is tiny — VQA annotations are sparse and must
|
||||
never be wasted on a subtask/action draw."""
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"high_level_subtask": TrainingRecipe(
|
||||
weight=0.99,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
"ask_vqa_top": TrainingRecipe(
|
||||
weight=0.01,
|
||||
bindings={
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
|
||||
# ask_vqa weight being only 0.01.
|
||||
for sample_idx in range(20):
|
||||
rendered = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
|
||||
)
|
||||
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
|
||||
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
|
||||
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
|
||||
assert rendered["messages"][-1]["content"] == "a subtask"
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
|
||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||
|
||||
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
|
||||
from lerobot.datasets.io_utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||
@@ -137,3 +137,49 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- WeightedEpisodeAwareSampler --------------------------------------------
|
||||
|
||||
|
||||
def test_weighted_sampler_respects_episode_drop_and_length():
|
||||
"""The episode-boundary frame filtering is applied before weighting,
|
||||
and one epoch still yields ``len(indices)`` samples."""
|
||||
# One episode, 10 frames; drop the last 2.
|
||||
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
|
||||
assert sampler.indices == list(range(8))
|
||||
assert len(sampler) == 8
|
||||
draws = list(sampler)
|
||||
assert len(draws) == 8
|
||||
# Dropped frames 8 and 9 must never be sampled.
|
||||
assert all(d in set(range(8)) for d in draws)
|
||||
|
||||
|
||||
def test_weighted_sampler_oversamples_high_weight_frames():
|
||||
"""A heavily-weighted frame dominates the draws."""
|
||||
torch.manual_seed(0)
|
||||
# 100 frames, frame 7 is weighted 1000x.
|
||||
weights = torch.ones(100)
|
||||
weights[7] = 1000.0
|
||||
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
|
||||
counts = {}
|
||||
for _ in range(20): # 20 epochs
|
||||
for d in sampler:
|
||||
counts[d] = counts.get(d, 0) + 1
|
||||
total = sum(counts.values())
|
||||
# Frame 7 should be the overwhelming majority of the 2000 draws.
|
||||
assert counts.get(7, 0) / total > 0.9
|
||||
|
||||
|
||||
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
|
||||
"""If every surviving frame has zero weight, sampling is uniform
|
||||
rather than crashing."""
|
||||
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
|
||||
draws = set(sampler)
|
||||
assert draws.issubset(set(range(6)))
|
||||
assert len(list(sampler)) == 6
|
||||
|
||||
|
||||
def test_weighted_sampler_rejects_short_weight_vector():
|
||||
with pytest.raises(ValueError, match="frame_weights"):
|
||||
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))
|
||||
|
||||
Reference in New Issue
Block a user