refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives

The custom episode pool becomes a pure `datasets` pipeline:

  split_dataset_by_node -> batch(by_column="episode_index")
    -> shuffle(buffer=episode_pool_size)            # episode pool
    -> map(explode + exact delta windows)           # episode -> frames
    -> shuffle(buffer=frame_shuffle_buffer_size)    # frame interleave

and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.

Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.

Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time - use a colocated
bucket (data_files_root) at large scale.

The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.

Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 21:03:09 +02:00
parent 984b400e5c
commit 894fc6bfb5
4 changed files with 258 additions and 567 deletions
+29 -16
View File
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
@@ -85,7 +85,16 @@ def make_dataloader(
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
)
loader = DataLoader(
# torchdata's StatefulDataLoader checkpoints each worker's dataset state through the
# dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back
# to the plain DataLoader (resume then requires num_workers=0).
try:
from torchdata.stateful_dataloader import StatefulDataLoader
loader_cls = StatefulDataLoader
except ImportError:
loader_cls = DataLoader
loader = loader_cls(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
@@ -124,13 +133,17 @@ def main() -> None:
# of it). Batches are moved to the device manually in the loop.
model, optimizer = accelerator.prepare(model, optimizer)
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
# worker re-derives its own skip. Same file works for every rank.
# Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
# the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
can_checkpoint_loader = hasattr(loader, "state_dict")
if args.resume_from is not None:
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
dataset.load_state_dict(state)
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
state = torch.load(state_path, weights_only=False) # plain dict of stream offsets # nosec B614
if can_checkpoint_loader:
loader.load_state_dict(state)
else:
dataset.load_state_dict(state)
accelerator.print(f"Resumed dataset stream from {state_path}")
step = 0
frames_seen = 0
@@ -157,15 +170,15 @@ def main() -> None:
)
window_start = time.perf_counter()
if step % args.save_freq == 0 and accelerator.is_main_process:
if step % args.save_freq == 0:
ckpt = output_dir / f"checkpoint-{step}"
ckpt.mkdir(parents=True, exist_ok=True)
# Save the consumed-batch counters so a restart fast-forwards to this position.
torch.save(
{"batches_consumed": step, "batch_size": args.batch_size},
ckpt / "dataset_state.pt",
)
if model is not None:
if accelerator.is_main_process:
ckpt.mkdir(parents=True, exist_ok=True)
accelerator.wait_for_everyone()
# Every rank saves its own stream state: shard positions differ per rank.
state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
if model is not None and accelerator.is_main_process:
accelerator.unwrap_model(model).save_pretrained(ckpt)
if step >= args.steps:
+193 -430
View File
@@ -13,18 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import logging
import os
import shutil
import time
from collections.abc import Callable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
import datasets
import fsspec
import numpy as np
import torch
from datasets import load_dataset
@@ -39,7 +33,6 @@ from .utils import (
check_version_compatibility,
find_float_index,
is_float_in_list,
safe_shard,
)
from .video_utils import (
VideoDecoderCache,
@@ -48,134 +41,38 @@ from .video_utils import (
logger = logging.getLogger(__name__)
# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation,
# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop.
_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters
_MASK_64 = (1 << 64) - 1
def _mix64(x: int) -> int:
"""SplitMix64 finalizer (64-bit integer hash) for seed derivation."""
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
x ^= x >> 27
x = (x * 0x94D049BB133111EB) & _MASK_64
x ^= x >> 31
return x
@contextlib.contextmanager
def _suppress_hf_worker_split():
"""Hide the torch DataLoader worker context from `datasets` while we drain its streams.
`datasets` detects torch workers and re-splits its shards across them internally
(`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second
split silently drops data whenever a per-worker stream has fewer internal shards than there
are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local
to this DataLoader worker process and restored on exit.
"""
original = torch.utils.data.get_worker_info
torch.utils.data.get_worker_info = lambda: None
try:
yield
finally:
torch.utils.data.get_worker_info = original
class _PooledEpisode:
"""A fully-loaded episode's tabular rows plus emission bookkeeping."""
__slots__ = ("episode_index", "rows", "remaining", "video_rel_paths")
def __init__(self, episode_index: int, rows: list[dict], video_rel_paths: list[str]):
self.episode_index = episode_index
self.rows = rows
self.remaining = list(range(len(rows)))
self.video_rel_paths = video_rel_paths
class _VideoPrefetcher:
"""Background downloader of episode video files into a local cache (decode-on-exit support).
Files are refcounted because LeRobot v3 packs several episodes per video file: a file is
downloaded once when the first pooled episode referencing it is admitted and deleted when
the last one is evicted. Downloads resolve through fsspec (hf://, s3://, https://, ...).
"""
def __init__(self, remote_root: str, cache_dir: Path, max_workers: int = 4):
self._remote_root = remote_root.rstrip("/")
self._cache_dir = cache_dir
self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="video-prefetch")
self._refcounts: dict[str, int] = {}
self._futures: dict[str, Future] = {}
def acquire(self, rel_path: str) -> None:
self._refcounts[rel_path] = self._refcounts.get(rel_path, 0) + 1
if rel_path not in self._futures:
self._futures[rel_path] = self._executor.submit(self._download, rel_path)
def _download(self, rel_path: str) -> Path:
local = self._cache_dir / rel_path
if local.exists():
return local
local.parent.mkdir(parents=True, exist_ok=True)
tmp = local.with_suffix(local.suffix + ".tmp")
with fsspec.open(f"{self._remote_root}/{rel_path}", "rb") as src, open(tmp, "wb") as dst:
shutil.copyfileobj(src, dst, length=1 << 22)
tmp.rename(local)
return local
def wait_local(self, rel_path: str) -> Path | None:
"""Block until the file is cached; None when not tracked or the download failed."""
future = self._futures.get(rel_path)
if future is None:
return None
try:
return future.result()
except Exception as e:
logger.warning(f"Video prefetch failed for {rel_path} ({e}); decoding from remote instead.")
return None
def release(self, rel_path: str) -> None:
count = self._refcounts.get(rel_path, 0) - 1
if count > 0:
self._refcounts[rel_path] = count
return
self._refcounts.pop(rel_path, None)
future = self._futures.pop(rel_path, None)
if future is None:
return
if not future.cancel():
try:
local = future.result()
local.unlink(missing_ok=True)
except Exception:
logger.debug(f"Could not delete cached video {rel_path}.", exc_info=True)
def shutdown(self) -> None:
self._executor.shutdown(wait=False, cancel_futures=True)
# Bound the default frame-level shuffle buffer: rows are tabular-only (~KB each), so this is
# roughly a few hundred MB of host RAM per consumer at the cap.
_MAX_DEFAULT_FRAME_BUFFER = 200_000
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""LeRobotDataset with streaming capabilities.
"""LeRobotDataset with streaming capabilities, built on native HF `datasets` primitives.
Streams frames from the Hub (or any fsspec source) without downloading the dataset. The
iteration strategy is an *episode pool*: each consumer keeps ``episode_pool_size`` whole
episodes' tabular rows in RAM (a few KB per episode) and emits uniformly random frames
across them, so a batch mixes up to ``batch_size`` distinct episodes. Because a frame's
whole episode is resident, ``delta_timestamps`` windows are exact array slices with correct
padding at episode boundaries. Video is decoded only when a sample is emitted
(decode-on-exit), so pool memory stays tabular-sized; when streaming from a remote source,
each pooled episode's video files are prefetched to a local cache in the background and
deleted on eviction.
The tabular side is a pure `datasets` pipeline::
Distribution: ranks stream disjoint shards via ``split_dataset_by_node`` and DataLoader
workers split a rank's shards further, so every frame is consumed exactly once per epoch
across the whole fleet. Each consumer's order is a pure function of
``(seed, epoch, rank, worker)``, which makes resume a deterministic fast-forward (see
:meth:`load_state_dict`).
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
-> batch(by_column="episode_index") # whole episodes
-> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob)
-> map(explode + exact delta windows) # episode -> frames, windows are exact
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
and this class is a thin torch ``IterableDataset`` wrapper around it that decodes video
per emitted sample (decode-on-exit), applies image transforms, and attaches the task
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
and resume uses the native ``state_dict`` / ``load_state_dict``.
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
exact slices of the resident episode with correct padding at episode boundaries.
Resume: ``state_dict()`` / ``load_state_dict()`` delegate to `datasets`. Samples sitting in
the shuffle buffers at checkpoint time are skipped on resume (documented `datasets`
behavior), so resume never repeats data but may drop up to roughly
``episode_pool_size x episode_len + frame_shuffle_buffer_size`` frames — negligible at
training scale. The contract is exact with ``num_workers=0``; with DataLoader workers use
``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's
dataset state through this same protocol.
Example:
```python
@@ -201,6 +98,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
force_cache_sync: bool = False,
streaming: bool = True,
episode_pool_size: int | None = 64,
frame_shuffle_buffer_size: int | None = None,
buffer_size: int | None = None,
max_num_shards: int | None = None,
seed: int = 42,
@@ -212,8 +110,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
video_decode_device: str = "cpu",
prefetch_videos: bool = True,
video_prefetch_workers: int = 4,
):
"""Initialize a StreamingLeRobotDataset.
@@ -229,16 +125,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
revision (str, optional): Git revision id (branch name, tag, or commit hash).
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
episode_pool_size (int, optional): Number of whole episodes each consumer keeps open to
shuffle across — the randomness knob. Larger mixes more episodes per batch (closer to
map-style uniform) at the cost of cold-start latency; RAM stays small because the pool
holds tabular rows only. Defaults to 64.
episode_pool_size (int, optional): Whole episodes each consumer keeps open to shuffle
across — the randomness knob. Larger mixes more episodes per batch (closer to
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
Defaults to 64.
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
which matches the pool's mixing radius.
buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``.
max_num_shards (int | None, optional): Cap on the number of stream shards. None (default)
uses every underlying parquet shard, which is required to feed many DataLoader workers.
max_num_shards (int | None, optional): Deprecated; `datasets` handles shard-to-worker
assignment natively.
seed (int, optional): Reproducibility random seed.
rng (np.random.Generator | None, optional): Deprecated; ignored (the RNG is derived from
``(seed, epoch, rank, worker)`` so consumers are decorrelated and runs reproducible).
rng (np.random.Generator | None, optional): Deprecated; ignored.
shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order.
rank (int | None, optional): This process' rank for distributed training. Each rank streams
a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from
@@ -253,10 +151,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
from there while metadata still loads from ``repo_id`` on the Hub.
video_decode_device (str, optional): Device for torchcodec decode. ``"cuda"`` offloads to
NVDEC (needs a CUDA torchcodec build and ``spawn`` DataLoader workers).
prefetch_videos (bool, optional): When streaming from a remote source, download each pooled
episode's video files to a local cache in the background so decode-on-exit reads local
bytes instead of paying network seek latency. Defaults to True.
video_prefetch_workers (int, optional): Download threads per consumer. Defaults to 4.
"""
super().__init__()
self.repo_id = repo_id
@@ -271,37 +165,35 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.seed = seed
if rng is not None:
logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.")
self.shuffle = shuffle
self.streaming = streaming
if buffer_size is not None:
logger.warning(
"StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; "
"use `episode_pool_size` (whole episodes, not frames)."
)
if max_num_shards is not None:
logger.warning(
"StreamingLeRobotDataset: `max_num_shards` is deprecated and ignored; "
"`datasets` assigns shards to DataLoader workers natively."
)
self.shuffle = shuffle
self.streaming = streaming
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 64
self.max_num_shards = max_num_shards
self._return_uint8 = return_uint8
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
self.video_decoder_cache_size = video_decoder_cache_size
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
self.video_decode_device = video_decode_device
self.prefetch_videos = prefetch_videos
self.video_prefetch_workers = video_prefetch_workers
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self._prefetcher: _VideoPrefetcher | None = None
# Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate
# decoder-cache stats and component timings into one place the main process can read after
# iteration (see video_decoder_cache_stats() / timing_stats()).
self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_()
# Deterministic fast-forward resume (see load_state_dict): per-consumer epoch counter and
# number of samples still to skip.
self._epoch = 0
self._ff_remaining = 0
self._resume_state: dict | None = None
self._in_flight_epoch = 0
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
@@ -349,12 +241,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
if extra_columns:
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
self.num_shards = (
self.hf_dataset.num_shards
if self.max_num_shards is None
else min(self.hf_dataset.num_shards, self.max_num_shards)
self.num_shards = self.hf_dataset.num_shards
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
self.frame_shuffle_buffer_size = (
frame_shuffle_buffer_size
if frame_shuffle_buffer_size is not None
else min(self.episode_pool_size * avg_episode_len, _MAX_DEFAULT_FRAME_BUFFER)
)
self._pipeline = self._build_pipeline()
@property
def num_frames(self):
return self.meta.total_frames
@@ -374,6 +271,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
"""
import os
if rank is not None and world_size is not None:
return rank, world_size
@@ -393,27 +292,67 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return 0, 1
def _consumer_rng(self, epoch: int, worker_id: int) -> np.random.Generator:
"""RNG derived from (seed, epoch, rank, worker): reproducible, decorrelated consumers."""
state = _mix64(self.seed)
for salt in (self.rank, worker_id, epoch if self.shuffle else 0):
state = _mix64(state ^ _mix64(salt))
return np.random.default_rng(state)
def _build_pipeline(self) -> datasets.IterableDataset:
"""Assemble the native tabular pipeline (everything except video decode)."""
ds = self.hf_dataset
if self.world_size > 1:
if ds.num_shards % self.world_size != 0:
logger.warning(
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
"datasets falls back to example-level splitting where every rank reads (and pays "
"for) the full stream. Re-shard the dataset or adjust world size."
)
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
def _shard_order(self, epoch: int, num_shards: int) -> list[int]:
"""Seeded permutation of this rank's shard indices, re-drawn every epoch.
In a sub-epoch run over a corpus consolidated source-by-source, index-order shard
assignment means training on whatever the first N% of files contains; permuting the
shard order turns that into a uniform sample of files. Seeded by (seed, epoch, rank)
only — every DataLoader worker of the rank must agree on this list, because workers
stride it and disagreement would create overlapping shard assignments.
"""
order = list(range(num_shards))
ds = ds.batch(by_column="episode_index")
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
if self.shuffle:
state = _mix64(self.seed) ^ _mix64(0x5EED5EED) ^ _mix64(self.rank) ^ _mix64(epoch)
np.random.default_rng(_mix64(state)).shuffle(order)
return order
ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size)
# A row-count-changing batched map must drop the input columns explicitly; the exploded
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
if self.shuffle:
ds = ds.shuffle(seed=self.seed + 1, buffer_size=max(2, self.frame_shuffle_buffer_size))
return ds
def _tabular_window_keys(self) -> list[str]:
if self.delta_indices is None:
return []
return [key for key in self.delta_indices if key not in self.meta.video_keys]
def _explode_episodes(self, episode_batch: dict[str, list[list]]) -> dict[str, list]:
"""Episode batches -> per-frame rows, with exact tabular delta windows and pad masks.
Runs inside the `datasets` pipeline (plain Python values, no torch). For each windowed key
the original per-frame value is replaced by its delta window (list of values, clamped to
the episode bounds) plus a ``{key}_is_pad`` mask, mirroring the map-style dataset.
"""
window_keys = set(self._tabular_window_keys())
out: dict[str, list] = {key: [] for key in episode_batch if key not in window_keys}
for key in window_keys:
out[key] = []
out[f"{key}_is_pad"] = []
num_episodes = len(episode_batch["episode_index"])
for e in range(num_episodes):
length = len(episode_batch["episode_index"][e])
for key, column in episode_batch.items():
if key in window_keys:
continue
out[key].extend(column[e])
for key in window_keys:
episode_column = episode_batch[key][e]
deltas = self.delta_indices[key]
for t in range(length):
window = []
is_pad = []
for delta in deltas:
j = t + delta
window.append(episode_column[min(max(j, 0), length - 1)])
is_pad.append(not 0 <= j < length)
out[key].append(window)
out[f"{key}_is_pad"].append(is_pad)
return out
def _make_video_decoder_cache(self) -> VideoDecoderCache:
"""Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128."""
@@ -432,231 +371,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
device=self.video_decode_device,
)
def _make_prefetcher(self) -> _VideoPrefetcher | None:
if not self.prefetch_videos or len(self.meta.video_keys) == 0:
return None
if self.data_files_root is not None:
remote_root = self.data_files_root
elif self.streaming and not self.streaming_from_local:
remote_root = self.meta.url_root
else:
return None # video bytes are already local
return _VideoPrefetcher(
remote_root,
cache_dir=self.root / "streaming_video_cache",
max_workers=self.video_prefetch_workers,
)
@staticmethod
def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]:
"""Yield (episode_index, rows) for each complete episode of a shard stream.
On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)``
(one accumulation per episode instead of one Python dict per row); older versions
use the equivalent row loop.
"""
if _HAS_BATCH_BY_COLUMN:
for batch in shard.batch(by_column="episode_index"):
keys = list(batch.keys())
num_rows = len(batch["episode_index"])
rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)]
yield int(batch["episode_index"][0]), rows
return
rows: list[dict] = []
current: int | None = None
for item in shard:
ep_idx = int(item["episode_index"])
if current is None:
current = ep_idx
if ep_idx != current:
yield current, rows
rows = []
current = ep_idx
rows.append(item)
if rows:
yield current, rows
def _admit_episode(self, ep_idx: int, rows: list[dict], prefetcher: _VideoPrefetcher | None):
video_rel_paths = [str(self.meta.get_video_file_path(ep_idx, key)) for key in self.meta.video_keys]
if prefetcher is not None:
for rel in video_rel_paths:
prefetcher.acquire(rel)
torch_rows = [item_to_torch(row) for row in rows]
return _PooledEpisode(ep_idx, torch_rows, video_rel_paths)
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
ds = self.hf_dataset
if self.world_size > 1:
if ds.num_shards % self.world_size != 0:
logger.warning(
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
"datasets falls back to example-level splitting where every rank reads (and pays "
"for) the full stream. Re-shard the dataset or adjust world size."
)
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
num_shards = ds.num_shards if self.max_num_shards is None else min(ds.num_shards, self.max_num_shards)
epoch = self._epoch
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
# stream position actually belongs to.
self._in_flight_epoch = self._epoch
self._pipeline.set_epoch(self._in_flight_epoch)
self._epoch += 1
shard_indices = self._shard_order(epoch, num_shards)
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
worker_info = torch.utils.data.get_worker_info()
worker_id, num_workers = (worker_info.id, worker_info.num_workers) if worker_info else (0, 1)
shard_indices = shard_indices[worker_id::num_workers]
if not shard_indices:
logger.warning(
f"Worker {worker_id} owns no shards ({num_shards} shards < {num_workers} workers): "
"it will yield nothing. Reduce num_workers or re-shard the dataset."
)
return
self.video_decoder_cache = self._make_video_decoder_cache()
prefetcher = self._make_prefetcher()
self._prefetcher = prefetcher
rng = self._consumer_rng(epoch, worker_id)
# Workers beyond the shard count yield nothing and are stopped by the DataLoader, so the
# batch round-robin effectively runs over min(num_workers, num_shards) active workers.
self._consume_resume_state(worker_id, min(num_workers, num_shards))
iterator = iter(self._pipeline)
while True:
fetch_start = time.perf_counter_ns()
try:
row = next(iterator)
except StopIteration:
return
finally:
self._cache_counters[4] += time.perf_counter_ns() - fetch_start
yield self._finalize_sample(row)
# Round-robin episode admission across this consumer's shard streams (deterministic).
streams = [self._iter_shard_episodes(safe_shard(ds, idx, num_shards)) for idx in shard_indices]
next_stream = 0
pool: list[_PooledEpisode] = []
total_remaining = 0
def admit() -> int:
nonlocal next_stream, total_remaining
admitted = 0
while len(pool) < self.episode_pool_size and streams:
stream = streams[next_stream % len(streams)]
fetch_start = time.perf_counter_ns()
try:
ep_idx, rows = next(stream)
except StopIteration:
streams.remove(stream)
continue
finally:
self._cache_counters[4] += time.perf_counter_ns() - fetch_start
next_stream += 1
episode = self._admit_episode(ep_idx, rows, prefetcher)
pool.append(episode)
total_remaining += len(episode.remaining)
admitted += 1
return admitted
worker_split_guard = (
_suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext()
)
try:
with worker_split_guard:
admit()
while pool:
if self.shuffle:
# Uniform draw over every remaining frame in the pool: pick the episode by
# cumulative remaining count, then a random remaining position (swap-pop).
draw = int(rng.integers(total_remaining))
for episode in pool:
if draw < len(episode.remaining):
break
draw -= len(episode.remaining)
pick = int(rng.integers(len(episode.remaining)))
frame_pos = episode.remaining[pick]
episode.remaining[pick] = episode.remaining[-1]
episode.remaining.pop()
else:
episode = pool[0]
frame_pos = episode.remaining.pop(0)
total_remaining -= 1
if self._ff_remaining > 0:
self._ff_remaining -= 1
else:
yield self._make_pool_sample(episode, frame_pos)
if not episode.remaining:
pool.remove(episode)
if prefetcher is not None:
for rel in episode.video_rel_paths:
prefetcher.release(rel)
admit()
finally:
if prefetcher is not None:
prefetcher.shutdown()
self._prefetcher = None
def load_state_dict(self, state_dict: dict) -> None:
"""Stage a deterministic fast-forward resume, applied from the next ``__iter__``.
``state_dict`` holds ``{"batches_consumed": int, "batch_size": int}`` — what the trainer
already knows at checkpoint time. Because every consumer's order is a pure function of
(seed, epoch, rank, worker), resume replays the stream while skipping emission (tabular
reads only, no video decode) until each worker reaches its own consumed count; the
DataLoader's round-robin batch assignment makes that count derivable per worker. Exact
within an epoch; crossing epoch boundaries may drift by < one batch per worker per epoch
when ``drop_last`` discards partial batches.
"""
self._resume_state = {
"batches_consumed": int(state_dict["batches_consumed"]),
"batch_size": int(state_dict["batch_size"]),
}
def _consume_resume_state(self, worker_id: int, active_workers: int) -> None:
if self._resume_state is None:
return
batches = self._resume_state["batches_consumed"]
batch_size = self._resume_state["batch_size"]
self._resume_state = None
if worker_id >= active_workers:
return # this worker owns no shards and never delivered a batch
# The DataLoader assigns batch j to active worker j % active_workers.
my_batches = batches // active_workers + (1 if batches % active_workers > worker_id else 0)
self._ff_remaining = my_batches * batch_size
if self._ff_remaining:
logger.info(
f"Streaming resume: worker {worker_id} fast-forwarding {self._ff_remaining} samples "
"(tabular reads only, no video decode)."
)
def video_decoder_cache_stats(self) -> dict[str, int | float]:
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
total = hits + misses
return {
"hits": hits,
"misses": misses,
"evictions": evictions,
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def timing_stats(self) -> dict[str, float]:
"""Cumulative seconds spent in video decode and episode (tabular) fetch, summed across
DataLoader workers via the shared counter tensor. These overlap in wall-clock (workers run
in parallel), so compare them to ``num_workers x wallclock`` for time fractions.
"""
decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist())
return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)}
def _make_pool_sample(self, episode: _PooledEpisode, frame_pos: int) -> dict:
"""Assemble a full training sample for one pooled frame (tabular slices + video decode)."""
rows = episode.rows
item = dict(rows[frame_pos])
ep_idx = episode.episode_index
num_rows = len(rows)
current_ts = float(item["timestamp"])
updates: list[dict] = []
if self.delta_indices is not None:
updates.extend(self._pool_delta_frames(rows, frame_pos, num_rows))
def _finalize_sample(self, row: dict) -> dict:
"""Torch conversion + video decode (decode-on-exit) + transforms + task for one frame."""
window_keys = self._tabular_window_keys()
pad_masks = {f"{key}_is_pad": torch.BoolTensor(row.pop(f"{key}_is_pad")) for key in window_keys}
item = item_to_torch(row)
item.update(pad_masks)
if len(self.meta.video_keys) > 0:
ep_idx = int(item["episode_index"])
current_ts = float(item["timestamp"])
# Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of
# decoding against a neighbouring episode sharing the same video file.
episode_boundaries_ts = {
@@ -679,35 +424,57 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for cam in self.meta.camera_keys:
video_frames[cam] = self.image_transforms(video_frames[cam])
updates.append(video_frames)
item.update(video_frames)
if self.delta_indices is not None:
updates.append(
item.update(
self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps)
)
result = item
for update in updates:
result.update(update)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
return result
item["task"] = self.meta.tasks.iloc[int(item["task_index"])].name
return item
def _pool_delta_frames(self, rows: list[dict], frame_pos: int, num_rows: int) -> list[dict]:
"""Exact delta windows by slicing the resident episode; clamped + padded at boundaries."""
query_result: dict = {}
padding: dict = {}
for key, deltas in self.delta_indices.items():
if key in self.meta.video_keys:
continue # visual frames are decoded separately
frames = []
is_pad = []
for delta in deltas:
j = frame_pos + delta
valid = 0 <= j < num_rows
frames.append(rows[min(max(j, 0), num_rows - 1)][key])
is_pad.append(not valid)
query_result[key] = torch.stack(frames)
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
return [query_result, padding]
def set_epoch(self, epoch: int) -> None:
"""Set the epoch the next ``__iter__`` will use (reshuffles the native pipeline)."""
self._epoch = epoch
def state_dict(self) -> dict:
"""Native `datasets` stream state. Exact contract with ``num_workers=0``; with DataLoader
workers use ``torchdata.stateful_dataloader.StatefulDataLoader`` (it checkpoints each
worker's copy through this protocol). Samples in the shuffle buffers are skipped on
resume (never repeated), bounded by the pool + frame buffer sizes.
"""
return {"pipeline": self._pipeline.state_dict(), "epoch": self._in_flight_epoch}
def load_state_dict(self, state_dict: dict) -> None:
# Resume continues inside the recorded epoch: the next __iter__ replays that epoch's
# shuffle order from the restored stream position, then advances normally.
self._epoch = int(state_dict.get("epoch", 0))
self._pipeline.load_state_dict(state_dict["pipeline"])
def video_decoder_cache_stats(self) -> dict[str, int | float]:
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
total = hits + misses
return {
"hits": hits,
"misses": misses,
"evictions": evictions,
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def timing_stats(self) -> dict[str, float]:
"""Cumulative seconds spent in video decode and in the upstream tabular pipeline (parquet
fetch + grouping + shuffles + explode), summed across DataLoader workers via the shared
counter tensor. These overlap in wall-clock (workers run in parallel), so compare them to
``num_workers x wallclock`` for time fractions.
"""
decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist())
return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)}
def _make_timestamps_from_indices(
self, start_ts: float, indices: dict[str, list[int]] | None = None
@@ -785,17 +552,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
rel_path = str(self.meta.get_video_file_path(ep_idx, video_key))
local = self._prefetcher.wait_local(rel_path) if self._prefetcher is not None else None
if local is not None:
video_path = str(local)
if self.data_files_root is not None:
root = self.data_files_root
elif self.streaming and not self.streaming_from_local:
root = self.meta.url_root
else:
if self.data_files_root is not None:
root = self.data_files_root
elif self.streaming and not self.streaming_from_local:
root = self.meta.url_root
else:
root = self.root
video_path = f"{root}/{rel_path}"
root = self.root
video_path = f"{root}/{rel_path}"
frames = decode_video_frames_torchcodec(
video_path,
shifted_query_ts,
+5
View File
@@ -218,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
check = torch.allclose(left, right) and left.shape == right.shape
else:
# Scalar numerics: streaming yields python floats/ints where map-style yields
# 0-dim tensors (long-standing accepted difference). Compare by value.
check = float(left) == float(right)
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
+31 -121
View File
@@ -148,37 +148,6 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory):
"""Resume replays the deterministic stream and continues at the exact sample."""
repo_id = f"{DUMMY_REPO_ID}-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=7,
episode_pool_size=3,
max_num_shards=1,
)
full_epoch = _stream_indices(fresh_ds())
assert sorted(full_epoch) == list(range(total_frames))
batches_consumed, batch_size = 5, 4 # 20 samples in
resumed_ds = fresh_ds()
resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size})
resumed = _stream_indices(resumed_ds)
assert resumed == full_epoch[batches_consumed * batch_size :], (
"fast-forward resume did not continue at the exact sample"
)
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
repo_id = f"{DUMMY_REPO_ID}-seeds"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
@@ -226,27 +195,6 @@ def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
def test_video_prefetcher_refcounted_lifecycle(tmp_path):
from lerobot.datasets.streaming_dataset import _VideoPrefetcher
remote = tmp_path / "remote"
(remote / "videos").mkdir(parents=True)
payload = b"x" * 1024
(remote / "videos" / "a.mp4").write_bytes(payload)
prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1)
prefetcher.acquire("videos/a.mp4")
prefetcher.acquire("videos/a.mp4") # second pooled episode sharing the file
local = prefetcher.wait_local("videos/a.mp4")
assert local is not None and local.read_bytes() == payload
prefetcher.release("videos/a.mp4")
assert local.exists(), "file deleted while still referenced"
prefetcher.release("videos/a.mp4")
assert not local.exists(), "file not deleted at refcount zero"
prefetcher.shutdown()
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
repo_id = f"{DUMMY_REPO_ID}-parity"
@@ -318,87 +266,49 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
assert shuffled != ordered, "shuffle did not decorrelate output order"
def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_factory):
"""Resume must be exact under num_workers > 0: each worker re-derives its own skip."""
from torch.utils.data import DataLoader
repo_id = f"{DUMMY_REPO_ID}-resume-workers"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=120)
num_workers = 2
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
repo_id = f"{DUMMY_REPO_ID}-native-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=11,
episode_pool_size=3,
max_num_shards=4,
seed=7,
episode_pool_size=2,
frame_shuffle_buffer_size=8,
)
def epoch_samples(ds):
# batch_size=None yields raw samples; the DataLoader round-robins them across workers,
# which is batch_size=1 in the resume arithmetic.
loader = DataLoader(ds, batch_size=None, num_workers=num_workers)
return [int(sample["index"]) for sample in loader]
ds = fresh_ds()
it = iter(ds)
consumed = [int(next(it)["index"]) for _ in range(30)]
state = ds.state_dict()
full = epoch_samples(fresh_ds())
samples_consumed = 17
resumed_ds = fresh_ds()
resumed_ds.load_state_dict({"batches_consumed": samples_consumed, "batch_size": 1})
resumed = epoch_samples(resumed_ds)
resumed_ds.load_state_dict(state)
rest = [int(frame["index"]) for frame in resumed_ds]
assert resumed == full[samples_consumed:], (
"fast-forward resume with DataLoader workers did not continue at the exact sample"
)
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
# in-flight buffer contents are skipped on resume (documented datasets behavior):
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
covered = len(set(consumed) | set(rest))
max_in_flight = 2 * 30 + 8
assert covered >= total_frames - max_in_flight
assert covered + len(consumed) >= total_frames - max_in_flight
def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch):
"""The datasets>=5 batch(by_column=...) path must group episodes identically to the row loop."""
import lerobot.datasets.streaming_dataset as sd
repo_id = f"{DUMMY_REPO_ID}-grouping"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1)
def episode_signature(use_native):
monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native)
return [
(ep_idx, [int(row["index"]) for row in rows])
for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset)
]
fallback = episode_signature(False)
assert len(fallback) == 5
if not sd._HAS_BATCH_BY_COLUMN and "by_column" not in str(
type(ds.hf_dataset).batch.__doc__ or ""
): # datasets < 5: only the fallback path exists
return
native = episode_signature(True)
assert native == fallback
def test_shard_order_permutation_properties(tmp_path, lerobot_dataset_factory):
"""Shard order: a valid permutation, deterministic per (seed, epoch, rank), worker-independent
(workers stride the same list, so it must not depend on worker id), reshuffled across epochs,
and identity when shuffle is off."""
repo_id = f"{DUMMY_REPO_ID}-shardorder"
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
import datasets as hf_datasets
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5)
num_shards = 32
order_epoch0 = ds._shard_order(0, num_shards)
assert sorted(order_epoch0) == list(range(num_shards))
assert ds._shard_order(0, num_shards) == order_epoch0 # deterministic
assert ds._shard_order(1, num_shards) != order_epoch0 # reshuffles per epoch
assert order_epoch0 != list(range(num_shards)) # actually permuted (P=1/32! of false alarm)
other_rank = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5, rank=1, world_size=2
)
assert other_rank._shard_order(0, num_shards) != order_epoch0 # ranks decorrelated
unshuffled = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, seed=5)
assert unshuffled._shard_order(0, num_shards) == list(range(num_shards))
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
assert state is not None