feat(streaming): random-episode admission via reshard() + multi-input-shard shuffle

Reshard parquet per row group (1 shard == 1 row group == 1 episode) and feed the
episode-pool shuffle with max_buffer_input_shards so the pool is a uniform random
sample of the corpus, independent of episodes-per-file. Add validate_row_groups
guardrails (collapsed-row-group + distributed divisibility), require datasets>=5.0.0,
make the test fixture write one row group per episode, and plumb max_buffer_input_shards
through the dataloading benchmark.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-15 13:33:27 +00:00
parent 3ec60da82b
commit 4940281120
5 changed files with 291 additions and 16 deletions
+16
View File
@@ -222,8 +222,12 @@ def build_dataset(scenario: str, args: argparse.Namespace):
revision=MOLMO_REVISION,
data_files_root=data_files_root,
episode_pool_size=args.episode_pool_size,
max_buffer_input_shards=args.max_buffer_input_shards,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
validate_row_groups=False,
)
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
@@ -362,6 +366,9 @@ def run_scenario(scenario: str, args: argparse.Namespace) -> None:
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"episode_pool_size": None if is_map_style else args.episode_pool_size,
"max_buffer_input_shards": None
if is_map_style
else (args.max_buffer_input_shards or args.episode_pool_size),
**info,
"num_cameras": num_cameras,
"image_shape": image_shape,
@@ -419,6 +426,8 @@ def submit_chain(args: argparse.Namespace) -> None:
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
)
if args.max_buffer_input_shards is not None:
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
if args.local_root:
common += f" --local_root {args.local_root}"
env_prefix = "export TOKENIZERS_PARALLELISM=false"
@@ -491,6 +500,13 @@ def parse_args() -> argparse.Namespace:
p.add_argument(
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
)
p.add_argument(
"--max_buffer_input_shards",
type=int,
default=None,
help="Concurrently-live random episodes feeding the pool after reshard() "
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
)
p.add_argument(
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
)
+4 -3
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.7.0,<6.0.0",
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -334,8 +334,9 @@ explicit = true
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
# re-creation, fixed in datasets#8259 (merged, not yet released). Pin to the merge commit until the
# next datasets release ships it, then drop this and bump the floor in `dependencies`.
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
# (#8194) — all merged, but not yet in a tagged 5.0 release. Pin to the merge commit until the next
# datasets release ships them, then drop this pin and rely on the `datasets>=5.0.0` floor in the `dataset` extra.
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
[tool.setuptools.package-data]
+133 -10
View File
@@ -51,9 +51,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
The tabular side is a pure `datasets` pipeline::
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
-> reshard() # 1 shard == 1 row group == 1 episode
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
-> batch(by_column="episode_index") # whole episodes
-> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob)
-> batch(by_column="episode_index") # whole episodes (one per shard)
-> shuffle(episode_pool_size, max_buffer_input_shards) # K random episodes, global perm
-> map(explode + exact delta windows) # episode -> frames, windows are exact
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
@@ -62,6 +63,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
and resume uses the native ``state_dict`` / ``load_state_dict``.
Random-episode admission (Plan B): the LeRobot writer stores one Parquet row group per
episode, so ``datasets.IterableDataset.reshard()`` makes one shard == one episode (no new
files; shards are (file, row_group) pairs). ``shuffle`` then permutes shard order globally and
fills its buffer from ``max_buffer_input_shards`` shards concurrently, so the episode pool is a
uniformly-random sample of the corpus regardless of how many episodes are packed per file.
``max_buffer_input_shards`` is the number of concurrently-live random episodes; set it
``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
Requirement: ONE ROW GROUP PER EPISODE. Recorded datasets satisfy this; bulk
``df.to_parquet`` / ``push_to_hub`` / aggregate paths collapse row groups and are rejected at
init (see ``validate_row_groups``). Old collapsed datasets still load fine for the map-style
path; only this streaming random-episode path requires the invariant.
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
exact slices of the resident episode with correct padding at episode boundaries.
@@ -97,6 +111,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
force_cache_sync: bool = False,
streaming: bool = True,
episode_pool_size: int | None = 1024,
max_buffer_input_shards: int | None = None,
frame_shuffle_buffer_size: int | None = None,
buffer_size: int | None = None,
max_num_shards: int | None = None,
@@ -108,6 +123,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
world_size: int | None = None,
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
validate_row_groups: bool = True,
):
"""Initialize a StreamingLeRobotDataset.
@@ -127,6 +143,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
across — the randomness knob. Larger mixes more episodes per batch (closer to
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
Defaults to 1024.
max_buffer_input_shards (int | None, optional): Number of shards (== episodes, after
``reshard()``) the episode-pool ``shuffle`` reads from concurrently — i.e. the count
of concurrently-live random episodes feeding the pool from a global shard permutation.
Set ``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
Defaults to ``episode_pool_size``.
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
which matches the pool's mixing radius.
@@ -147,6 +168,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
trees (e.g. ``hf://buckets/<owner>/<name>``). When set, parquet and video bytes are read
from there while metadata still loads from ``repo_id`` on the Hub.
validate_row_groups (bool, optional): When True (default), verify at init that the dataset
stores one Parquet row group per episode (sampling data-file footers) and that
``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear
``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging);
the divisibility check then downgrades to a warning.
"""
super().__init__()
self.repo_id = repo_id
@@ -175,6 +201,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 1024
self.max_buffer_input_shards = (
max(1, max_buffer_input_shards) if max_buffer_input_shards else self.episode_pool_size
)
self.validate_row_groups = validate_row_groups
self._return_uint8 = return_uint8
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
@@ -232,8 +262,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
if extra_columns:
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
# Reshard Parquet per row group so 1 shard == 1 row group == 1 episode (the LeRobot writer
# emits one row group per episode). This lets the episode-pool shuffle admit uniformly-random
# episodes from a global shard permutation, independent of how many episodes are packed per file.
if self.streaming:
self.hf_dataset = self.hf_dataset.reshard()
self.num_shards = self.hf_dataset.num_shards
if self.validate_row_groups and self.streaming:
self._validate_row_groups_per_episode()
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
self.frame_shuffle_buffer_size = (
frame_shuffle_buffer_size
@@ -283,22 +321,102 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return 0, 1
def _resolve_data_root(self) -> str:
"""fsspec root that holds the bulk ``data/`` parquet tree (revision-qualified for the Hub)."""
if self.data_files_root is not None:
return self.data_files_root
if self.streaming and not self.streaming_from_local:
return f"hf://datasets/{self.repo_id}@{self.revision}"
return str(self.root)
def _episode_files(self) -> dict[tuple[int, int], list[int]]:
"""Map each data file ``(chunk_index, file_index)`` to the episode indices it stores."""
file_to_eps: dict[tuple[int, int], list[int]] = {}
for ep in range(self.meta.total_episodes):
row = self.meta.episodes[ep]
key = (int(row["data/chunk_index"]), int(row["data/file_index"]))
file_to_eps.setdefault(key, []).append(ep)
return file_to_eps
def _validate_row_groups_per_episode(self, sample_files: int = 32) -> None:
    """Check that the dataset stores ONE ROW GROUP PER EPISODE.

    The check reads Parquet footers only and inspects at most ``sample_files`` randomly chosen
    data files (seeded with ``self.seed``).

    Args:
        sample_files: Upper bound on the number of data files whose footer is inspected. Values
            ``<= 0`` disable the per-file sampling (the whole-dataset check still runs).

    Raises:
        ValueError: if a sampled data file holds fewer row groups than episodes, or if the
            resharded stream has no more shards than data files while holding more episodes
            than shards (one row group per file).

    A sampled file with MORE row groups than episodes does not raise; it is logged as a warning
    because at least one episode must then span several row groups.
    """
    import fsspec
    import pyarrow.parquet as pq

    file_to_eps = self._episode_files()
    num_data_files = len(file_to_eps)

    # Whole-dataset extreme: reshard() could not split beyond file granularity (one row group per
    # file) yet there are many more episodes than files -> collapsed.
    if self.num_shards <= num_data_files and self.meta.total_episodes > self.num_shards:
        raise ValueError(
            f"{self.repo_id}: after reshard() the stream still has only {self.num_shards} shard(s) "
            f"for {self.meta.total_episodes} episodes across {num_data_files} data file(s) — i.e. one "
            "row group per file. StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
            "GROUP PER EPISODE so each episode is an independently addressable shard after reshard(). "
            "Re-emit through the LeRobot writer (one write_table per episode) or fix the aggregate / "
            "annotate / push_to_hub writer that collapsed the row groups, then re-upload. Recorded "
            "datasets already satisfy this. Pass validate_row_groups=False to bypass (random-episode "
            "quality will degrade)."
        )

    keys = list(file_to_eps)
    # Clamp so a zero/negative ``sample_files`` samples nothing instead of handing numpy a
    # negative size; also skips the RNG call entirely when there is nothing to inspect.
    num_samples = max(0, min(sample_files, len(keys)))
    if num_samples == 0:
        return

    data_root = self._resolve_data_root()
    rng = np.random.default_rng(self.seed)
    chosen = rng.choice(len(keys), size=num_samples, replace=False)
    for i in chosen:
        chunk_idx, file_idx = keys[int(i)]
        n_ep = len(file_to_eps[(chunk_idx, file_idx)])
        rel = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
        path = f"{data_root}/{rel}"
        # Only the footer metadata is consulted; no row data is decoded here.
        with fsspec.open(path, "rb") as f:
            pf = pq.ParquetFile(f)
            n_rg = pf.num_row_groups
            num_rows = pf.metadata.num_rows
        if n_rg < n_ep:
            raise ValueError(
                f"{path}: stored as {n_rg} Parquet row group(s) ({num_rows} rows across "
                f"{n_ep} episodes). StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
                "GROUP PER EPISODE so each episode becomes an independently addressable shard after "
                "reshard(). This file was written by a bulk df.to_parquet / push_to_hub / aggregate "
                "path that collapses row groups. Re-emit through the LeRobot writer (one write_table "
                "per episode) or fix the aggregate/annotate writer, then re-upload. Recorded datasets "
                "already satisfy this. Pass validate_row_groups=False to bypass (quality will degrade)."
            )
        if n_rg > n_ep:
            # The opposite violation: more row groups than episodes means at least one episode
            # spans several row groups. NOTE(review): after reshard() such an episode presumably
            # ends up in several shards — confirm against the datasets version in use.
            logger.warning(
                "%s: stored as %d Parquet row group(s) for %d episode(s); at least one episode spans "
                "several row groups, so it may not stay a single shard after reshard().",
                path,
                n_rg,
                n_ep,
            )
def _build_pipeline(self) -> datasets.IterableDataset:
"""Assemble the native tabular pipeline (everything except video decode)."""
ds = self.hf_dataset
if self.world_size > 1:
if ds.num_shards % self.world_size != 0:
logger.warning(
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
"datasets falls back to example-level splitting where every rank reads (and pays "
"for) the full stream. Re-shard the dataset or adjust world size."
msg = (
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}). "
"After reshard() num_shards == the episode count, and split_dataset_by_node only "
"assigns shards evenly when num_shards % world_size == 0; otherwise every rank "
"streams (and pays for) the full dataset and keeps only 1/world_size of it. Pin "
"world_size to a divisor of the episode count, or drop/pad episodes to a divisible "
"count with the dataset tools. Set validate_row_groups=False to downgrade to a warning."
)
if self.validate_row_groups:
raise ValueError(msg)
logger.warning(msg)
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
ds = ds.batch(by_column="episode_index")
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
if self.shuffle:
ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size)
max_input_shards = max(1, min(self.max_buffer_input_shards, ds.num_shards))
ds = ds.shuffle(
seed=self.seed,
buffer_size=self.episode_pool_size,
max_buffer_input_shards=max_input_shards,
)
# A row-count-changing batched map must drop the input columns explicitly; the exploded
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
@@ -358,10 +476,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
# stream position actually belongs to.
self._in_flight_epoch = self._epoch
# stream position actually belongs to. Only advance when shuffling: after reshard() the
# stream has one shard per episode, and set_epoch(n>0) re-permutes shard order even without
# a shuffle op, so an unshuffled stream must pin epoch 0 to repeat the same order each pass.
if self.shuffle:
self._in_flight_epoch = self._epoch
self._epoch += 1
else:
self._in_flight_epoch = 0
self._pipeline.set_epoch(self._in_flight_epoch)
self._epoch += 1
self.video_decoder_cache = self._make_video_decoder_cache()
iterator = iter(self._pipeline)
+116
View File
@@ -312,3 +312,119 @@ def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
assert state is not None
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
    """reshard() must expose every episode as its own shard when the file has one row group per
    episode, so ``num_shards == total_episodes`` even though all episodes live in one data file."""
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-reshard"
    total_episodes = 3
    # The default (large) data-file size keeps all (unequal-length) episodes in a single file, so
    # per-row-group resharding is the only way num_shards can reach total_episodes.
    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=total_episodes,
        total_frames=90,
        use_videos=False,
    )

    ds = StreamingLeRobotDataset(repo_id=repo_id, root=dataset_root, shuffle=False, episode_pool_size=3)

    episodes_by_file = ds._episode_files()
    assert len(episodes_by_file) == 1, "test expects all episodes packed into a single data file"
    for file_key, episode_ids in episodes_by_file.items():
        chunk_idx, file_idx = file_key
        relative_path = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
        parquet_file = pq.ParquetFile(str(ds.root / relative_path))
        assert parquet_file.num_row_groups == len(episode_ids)
    assert ds.num_shards == total_episodes
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
    """With ``max_buffer_input_shards`` equal to the episode count, the first batch of frames
    should already draw from most of the episodes."""
    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-frac"
    total_episodes = 8
    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=total_episodes,
        total_frames=240,
        use_videos=False,
    )

    ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        shuffle=True,
        seed=0,
        episode_pool_size=total_episodes,
        max_buffer_input_shards=total_episodes,
    )
    assert ds.max_buffer_input_shards == total_episodes

    batch = 32
    # Collect the distinct episode ids seen in the first ``batch`` frames of the stream.
    head = set()
    for _, frame in zip(range(batch), ds, strict=False):
        head.add(int(frame["episode_index"]))
    required_distinct = min(total_episodes, batch) - 2
    assert len(head) >= required_distinct, f"batch did not mix random episodes: {head}"
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
    """A data file whose episodes were collapsed into a single row group must be rejected at
    construction time with the ONE ROW GROUP PER EPISODE error."""
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-collapsed"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    # Rewrite each data file in place with a plain write_table call (simulating the
    # aggregate/push_to_hub collapse into a single row group).
    for parquet_path in (dataset_root / "data").rglob("*.parquet"):
        collapsed_table = pq.read_table(parquet_path)
        pq.write_table(collapsed_table, parquet_path)

    with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
        StreamingLeRobotDataset(repo_id=repo_id, root=dataset_root, shuffle=False, episode_pool_size=3)
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
    """With ``validate_row_groups=False`` a collapsed dataset still constructs and yields every
    frame exactly once."""
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    # Same in-place rewrite as the rejection test: a plain write_table per data file.
    for parquet_path in (dataset_root / "data").rglob("*.parquet"):
        collapsed_table = pq.read_table(parquet_path)
        pq.write_table(collapsed_table, parquet_path)

    ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=dataset_root, shuffle=False, episode_pool_size=3, validate_row_groups=False
    )
    frame_indices = sorted(int(frame["index"]) for frame in ds)
    assert frame_indices == list(range(90))
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
    """Three episodes over ``world_size=2`` is not evenly divisible: construction must raise by
    default, and must succeed once ``validate_row_groups=False`` is passed."""
    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-divis"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    # Arguments shared by the failing and the bypassing construction.
    two_rank_kwargs = {
        "repo_id": repo_id,
        "root": dataset_root,
        "shuffle": False,
        "episode_pool_size": 3,
        "rank": 0,
        "world_size": 2,
    }

    with pytest.raises(ValueError, match="not divisible by world_size"):
        StreamingLeRobotDataset(**two_rank_kwargs)

    # Bypassing the guard must not raise.
    ds = StreamingLeRobotDataset(**two_rank_kwargs, validate_row_groups=False)
    assert ds.num_shards == 3
+22 -3
View File
@@ -17,6 +17,7 @@ from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
@@ -35,6 +36,24 @@ from lerobot.datasets.utils import (
)
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
    """Write ``hf_dataset`` to ``path``, issuing one ``write_table`` call per episode.

    Rows are cut wherever consecutive ``episode_index`` values differ, and each resulting slice
    is written separately so that every episode lands in its own Parquet row group. This is the
    layout ``StreamingLeRobotDataset`` expects after ``datasets.IterableDataset.reshard()``.

    Args:
        hf_dataset: Dataset with an ``episode_index`` column, rows already grouped by episode.
        path: Destination Parquet file.
    """
    arrow_table = hf_dataset.with_format("arrow")[:]
    episode_ids = np.asarray(hf_dataset["episode_index"])

    # Row positions at which a new episode begins (excluding row 0).
    cut_points = (np.flatnonzero(np.diff(episode_ids) != 0) + 1).tolist()
    # Consecutive offsets delimit the episodes: [0, cut_1, ..., cut_k, num_rows].
    offsets = [0, *cut_points, len(episode_ids)]

    with pq.ParquetWriter(str(path), arrow_table.schema) as writer:
        for begin, stop in zip(offsets[:-1], offsets[1:], strict=True):
            writer.write_table(arrow_table.slice(begin, stop - begin))
def write_hf_dataset(
hf_dataset: Dataset,
local_dir: Path,
@@ -67,7 +86,7 @@ def write_hf_dataset(
# If the dataset is small enough, write it to a single file.
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
path.parent.mkdir(parents=True, exist_ok=True)
hf_dataset.to_parquet(path)
_to_parquet_one_row_group_per_episode(hf_dataset, path)
return
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
@@ -114,8 +133,8 @@ def write_hf_dataset(
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
# Write the shard to a Parquet file.
dataset_shard.to_parquet(path)
# Write the shard to a Parquet file (one row group per episode).
_to_parquet_one_row_group_per_episode(dataset_shard, path)
# Update chunk and file indices for the next iteration.
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)