mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
fbcb9225f5
VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
186 lines
7.0 KiB
Python
186 lines
7.0 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# `datasets` is an optional dependency: skip this module instead of failing at
# import time when it is not installed. Must run before the `datasets` import below.
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
|
|
|
from datasets import Dataset # noqa: E402
|
|
|
|
from lerobot.datasets.io_utils import (
|
|
hf_transform_to_torch,
|
|
)
|
|
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
|
|
|
|
|
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
    """Calculate episode data index for testing.

    Walks the ``episode_index`` column in row order and records one half-open
    ``[from, to)`` row range per run of consecutive equal episode ids.

    Args:
        hf_dataset: Dataset supporting ``len()`` and an ``"episode_index"`` column.

    Returns:
        ``{"from": Tensor, "to": Tensor}`` holding, per episode, the first row
        and one past the last row. Both tensors are int64, including the empty
        ones returned for a dataset without rows.
    """
    if len(hf_dataset) == 0:
        # A bare ``torch.tensor([])`` defaults to float32; pin the dtype so the
        # empty result matches the integer tensors of the non-empty path.
        return {
            "from": torch.tensor([], dtype=torch.long),
            "to": torch.tensor([], dtype=torch.long),
        }

    episode_data_index: dict[str, list[int]] = {"from": [], "to": []}
    current_episode = None
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # A new episode starts at this row, which also closes the previous one.
            episode_data_index["from"].append(idx)
            if current_episode is not None:
                episode_data_index["to"].append(idx)
            current_episode = episode_idx
    # Close the final episode one past the last row.
    episode_data_index["to"].append(idx + 1)
    return {k: torch.tensor(v) for k, v in episode_data_index.items()}
|
|
|
|
|
|
def test_drop_n_first_frames():
    """Dropping the first frame of every episode leaves frames 1, 4 and 5.

    The fixture has three episodes covering rows [0, 1], [2] and [3, 4, 5].
    """
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(hf_dataset)
    starts, ends = bounds["from"], bounds["to"]

    sampler = EpisodeAwareSampler(starts, ends, drop_n_first_frames=1)

    expected = [1, 4, 5]
    assert sampler.indices == expected
    assert len(sampler) == 3
    assert list(sampler) == expected
|
|
|
|
|
|
def test_drop_n_last_frames():
    """Dropping the last frame of every episode leaves frames 0, 3 and 4.

    The fixture has three episodes covering rows [0, 1], [2] and [3, 4, 5].
    """
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(hf_dataset)
    starts, ends = bounds["from"], bounds["to"]

    sampler = EpisodeAwareSampler(starts, ends, drop_n_last_frames=1)

    expected = [0, 3, 4]
    assert sampler.indices == expected
    assert len(sampler) == 3
    assert list(sampler) == expected
|
|
|
|
|
|
def test_episode_indices_to_use():
    """Restricting the sampler to episodes 0 and 2 excludes the frame of episode 1.

    The fixture has three episodes covering rows [0, 1], [2] and [3, 4, 5].
    """
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(hf_dataset)
    starts, ends = bounds["from"], bounds["to"]

    sampler = EpisodeAwareSampler(starts, ends, episode_indices_to_use=[0, 2])

    expected = [0, 1, 3, 4, 5]
    assert sampler.indices == expected
    assert len(sampler) == 5
    assert list(sampler) == expected
|
|
|
|
|
|
def test_shuffle():
    """With and without shuffling the sampler exposes the same six frames.

    Unshuffled iteration must follow index order; shuffled iteration is only
    checked for covering the same set of frames.
    """
    frames = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    hf_dataset = Dataset.from_dict(frames)
    hf_dataset.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(hf_dataset)
    starts, ends = bounds["from"], bounds["to"]
    all_frames = [0, 1, 2, 3, 4, 5]

    ordered = EpisodeAwareSampler(starts, ends, shuffle=False)
    assert ordered.indices == all_frames
    assert len(ordered) == 6
    assert list(ordered) == all_frames

    shuffled = EpisodeAwareSampler(starts, ends, shuffle=True)
    assert shuffled.indices == all_frames
    assert len(shuffled) == 6
    assert set(shuffled) == {0, 1, 2, 3, 4, 5}
|
|
|
|
|
|
def test_negative_drop_first_frames_raises():
    """A negative ``drop_n_first_frames`` is rejected with a ValueError."""
    expected_message = "drop_n_first_frames must be >= 0"
    with pytest.raises(ValueError, match=expected_message):
        EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
|
|
|
|
|
def test_negative_drop_last_frames_raises():
    """A negative ``drop_n_last_frames`` is rejected with a ValueError."""
    expected_message = "drop_n_last_frames must be >= 0"
    with pytest.raises(ValueError, match=expected_message):
        EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
|
|
|
|
|
def test_all_episodes_dropped_raises():
    """The sampler raises when frame dropping leaves nothing to sample."""
    # Three single-frame episodes: dropping the first frame of each empties them all.
    episode_starts = [0, 1, 2]
    episode_ends = [1, 2, 3]
    with pytest.raises(ValueError, match="No valid frames remain"):
        EpisodeAwareSampler(episode_starts, episode_ends, drop_n_first_frames=1)
|
|
|
|
|
|
def test_partial_episode_drop_warns(caplog):
    """An episode emptied by frame dropping is skipped and mentioned in the log."""
    # Episode 0 spans one frame (fully dropped); episode 1 spans five frames.
    episode_starts = [0, 1]
    episode_ends = [1, 6]
    with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
        sampler = EpisodeAwareSampler(episode_starts, episode_ends, drop_n_first_frames=1)

    # Only episode 1 survives, minus its first frame: frames 2-5.
    surviving_frames = [2, 3, 4, 5]
    assert sampler.indices == surviving_frames
    assert "Episode 0" in caplog.text
|
|
|
|
|
|
# --- WeightedEpisodeAwareSampler --------------------------------------------
|
|
|
|
|
|
def test_weighted_sampler_respects_episode_drop_and_length():
    """Frame dropping at episode boundaries still applies to the weighted
    sampler, and one epoch yields exactly ``len(indices)`` draws."""
    # A single 10-frame episode with its last two frames dropped.
    uniform_weights = torch.ones(10)
    sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=uniform_weights, drop_n_last_frames=2)
    kept_frames = list(range(8))

    assert sampler.indices == kept_frames
    assert len(sampler) == 8

    epoch = list(sampler)
    assert len(epoch) == 8
    # Frames 8 and 9 were dropped, so no draw may fall outside 0-7.
    assert set(epoch) <= set(kept_frames)
|
|
|
|
|
|
def test_weighted_sampler_oversamples_high_weight_frames():
    """A frame weighted far above the rest makes up most of the draws."""
    torch.manual_seed(0)
    # 100 frames with unit weight, except frame 7 at 1000x.
    frame_weights = torch.ones(100)
    frame_weights[7] = 1000.0
    sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=frame_weights)

    # Collect the draws of 20 consecutive epochs.
    all_draws = [frame for _ in range(20) for frame in sampler]
    heavy_hits = sum(1 for frame in all_draws if frame == 7)

    # Frame 7 should be the overwhelming majority of the 2000 draws.
    assert heavy_hits / len(all_draws) > 0.9
|
|
|
|
|
|
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
    """All-zero weights must not crash the sampler: it still produces a full
    epoch of in-range frame indices."""
    all_zero = torch.zeros(6)
    sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=all_zero)

    first_epoch = set(sampler)
    assert first_epoch <= set(range(6))

    second_epoch = list(sampler)
    assert len(second_epoch) == 6
|
|
|
|
|
|
def test_weighted_sampler_rejects_short_weight_vector():
    """A weight vector shorter than the frame range is rejected with a ValueError."""
    # Ten frames in the episode, but only five weights supplied.
    too_few_weights = torch.ones(5)
    with pytest.raises(ValueError, match="frame_weights"):
        WeightedEpisodeAwareSampler([0], [10], frame_weights=too_few_weights)
|