From 894fc6bfb5c449ffc56fab46534c8189097a29d9 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 21:03:09 +0200 Subject: [PATCH] refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives The custom episode pool becomes a pure `datasets` pipeline: split_dataset_by_node -> batch(by_column="episode_index") -> shuffle(buffer=episode_pool_size) # episode pool -> map(explode + exact delta windows) # episode -> frames -> shuffle(buffer=frame_shuffle_buffer_size) # frame interleave and the torch IterableDataset wrapper keeps only per-sample video decode (decode-on-exit), image transforms, task lookup, and decode/fetch timing. Replaced by native machinery and deleted: the pooled-episode admission loop, the refcounted video prefetcher, manual worker shard striding plus the worker-split suppression patch, the per-(epoch, rank) shard-order permutation, the per-consumer SplitMix64 RNG, and fast-forward resume. DataLoader workers are split by `datasets` itself; .shuffle() permutes shard order per epoch natively; resume delegates to the native state_dict/load_state_dict (exact with num_workers=0; with workers use torchdata's StatefulDataLoader, which checkpoints per-worker state through the same protocol). An in-flight epoch counter ensures a mid-iteration state_dict records the epoch the stream position belongs to. Buffer contents are skipped on resume (documented datasets behavior): never repeats data, drops at most ~pool + frame-buffer frames. Randomness is unchanged: a batch still mixes up to episode_pool_size episodes; delta windows are still exact in-episode slices with correct boundary padding (value-verified against the map-style dataset). The known trade accepted with this rewrite: no video prefetch-on-admit, so remote decode pays per-frame range reads at yield time - use a colocated bucket (data_files_root) at large scale. The delta-consistency tests gained a scalar-comparison branch: they silently skipped python-scalar keys before (stale `check` variable), exposed by the new pipeline's key ordering. Requires datasets with #8259 (pinned to the merge commit on this branch). Example updated to per-rank native resume via torchdata's StatefulDataLoader when available. Co-Authored-By: Claude Fable 5 --- examples/scaling/train_streaming_multinode.py | 45 +- src/lerobot/datasets/streaming_dataset.py | 623 ++++++------------ tests/datasets/test_streaming.py | 5 + tests/datasets/test_streaming_native.py | 152 +---- 4 files changed, 258 insertions(+), 567 deletions(-) diff --git a/examples/scaling/train_streaming_multinode.py b/examples/scaling/train_streaming_multinode.py index af3e4c6b0..f2983bfbb 100644 --- a/examples/scaling/train_streaming_multinode.py +++ b/examples/scaling/train_streaming_multinode.py @@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`: - per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size`` are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly); - DataLoader-worker shard splitting (no duplicate frames within a rank); -- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only); +- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too); - an explicit video-decoder cache size so the working set of open decoders does not thrash. Launch with Accelerate (single node, N GPUs): @@ -85,7 +85,16 @@ def make_dataloader( video_decoder_cache_size=args.video_decoder_cache_size, tolerance_s=1e-3, ) - loader = DataLoader( + # torchdata's StatefulDataLoader checkpoints each worker's dataset state through the + # dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back + # to the plain DataLoader (resume then requires num_workers=0). + try: + from torchdata.stateful_dataloader import StatefulDataLoader + + loader_cls = StatefulDataLoader + except ImportError: + loader_cls = DataLoader + loader = loader_cls( dataset, batch_size=args.batch_size, num_workers=args.num_workers, @@ -124,13 +133,17 @@ def main() -> None: # of it). Batches are moved to the device manually in the loop. model, optimizer = accelerator.prepare(model, optimizer) - # Resume: deterministic fast-forward. Every consumer's order is a pure function of - # (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and - # worker re-derives its own skip. Same file works for every rank. + # Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader + # the state covers every worker; with the plain DataLoader it is exact for num_workers=0. + can_checkpoint_loader = hasattr(loader, "state_dict") if args.resume_from is not None: - state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True) - dataset.load_state_dict(state) - accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed") + state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt" + state = torch.load(state_path, weights_only=False) # plain dict of stream offsets # nosec B614 + if can_checkpoint_loader: + loader.load_state_dict(state) + else: + dataset.load_state_dict(state) + accelerator.print(f"Resumed dataset stream from {state_path}") step = 0 frames_seen = 0 @@ -157,15 +170,15 @@ def main() -> None: ) window_start = time.perf_counter() - if step % args.save_freq == 0 and accelerator.is_main_process: + if step % args.save_freq == 0: ckpt = output_dir / f"checkpoint-{step}" - ckpt.mkdir(parents=True, exist_ok=True) - # Save the consumed-batch counters so a restart fast-forwards to this position. - torch.save( - {"batches_consumed": step, "batch_size": args.batch_size}, - ckpt / "dataset_state.pt", - ) - if model is not None: + if accelerator.is_main_process: + ckpt.mkdir(parents=True, exist_ok=True) + accelerator.wait_for_everyone() + # Every rank saves its own stream state: shard positions differ per rank. + state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict() + torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt") + if model is not None and accelerator.is_main_process: accelerator.unwrap_model(model).save_pretrained(ckpt) if step >= args.steps: diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 2f69867e1..7d4f9fa8a 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -13,18 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import inspect import logging -import os -import shutil import time from collections.abc import Callable, Iterator -from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path import datasets -import fsspec import numpy as np import torch from datasets import load_dataset @@ -39,7 +33,6 @@ from .utils import ( check_version_compatibility, find_float_index, is_float_in_list, - safe_shard, ) from .video_utils import ( VideoDecoderCache, @@ -48,134 +41,38 @@ from .video_utils import ( logger = logging.getLogger(__name__) -# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation, -# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop. -_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters - -_MASK_64 = (1 << 64) - 1 - - -def _mix64(x: int) -> int: - """SplitMix64 finalizer (64-bit integer hash) for seed derivation.""" - x = (x + 0x9E3779B97F4A7C15) & _MASK_64 - x ^= x >> 30 - x = (x * 0xBF58476D1CE4E5B9) & _MASK_64 - x ^= x >> 27 - x = (x * 0x94D049BB133111EB) & _MASK_64 - x ^= x >> 31 - return x - - -@contextlib.contextmanager -def _suppress_hf_worker_split(): - """Hide the torch DataLoader worker context from `datasets` while we drain its streams. - - `datasets` detects torch workers and re-splits its shards across them internally - (`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second - split silently drops data whenever a per-worker stream has fewer internal shards than there - are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local - to this DataLoader worker process and restored on exit. - """ - original = torch.utils.data.get_worker_info - torch.utils.data.get_worker_info = lambda: None - try: - yield - finally: - torch.utils.data.get_worker_info = original - - -class _PooledEpisode: - """A fully-loaded episode's tabular rows plus emission bookkeeping.""" - - __slots__ = ("episode_index", "rows", "remaining", "video_rel_paths") - - def __init__(self, episode_index: int, rows: list[dict], video_rel_paths: list[str]): - self.episode_index = episode_index - self.rows = rows - self.remaining = list(range(len(rows))) - self.video_rel_paths = video_rel_paths - - -class _VideoPrefetcher: - """Background downloader of episode video files into a local cache (decode-on-exit support). - - Files are refcounted because LeRobot v3 packs several episodes per video file: a file is - downloaded once when the first pooled episode referencing it is admitted and deleted when - the last one is evicted. Downloads resolve through fsspec (hf://, s3://, https://, ...). - """ - - def __init__(self, remote_root: str, cache_dir: Path, max_workers: int = 4): - self._remote_root = remote_root.rstrip("/") - self._cache_dir = cache_dir - self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="video-prefetch") - self._refcounts: dict[str, int] = {} - self._futures: dict[str, Future] = {} - - def acquire(self, rel_path: str) -> None: - self._refcounts[rel_path] = self._refcounts.get(rel_path, 0) + 1 - if rel_path not in self._futures: - self._futures[rel_path] = self._executor.submit(self._download, rel_path) - - def _download(self, rel_path: str) -> Path: - local = self._cache_dir / rel_path - if local.exists(): - return local - local.parent.mkdir(parents=True, exist_ok=True) - tmp = local.with_suffix(local.suffix + ".tmp") - with fsspec.open(f"{self._remote_root}/{rel_path}", "rb") as src, open(tmp, "wb") as dst: - shutil.copyfileobj(src, dst, length=1 << 22) - tmp.rename(local) - return local - - def wait_local(self, rel_path: str) -> Path | None: - """Block until the file is cached; None when not tracked or the download failed.""" - future = self._futures.get(rel_path) - if future is None: - return None - try: - return future.result() - except Exception as e: - logger.warning(f"Video prefetch failed for {rel_path} ({e}); decoding from remote instead.") - return None - - def release(self, rel_path: str) -> None: - count = self._refcounts.get(rel_path, 0) - 1 - if count > 0: - self._refcounts[rel_path] = count - return - self._refcounts.pop(rel_path, None) - future = self._futures.pop(rel_path, None) - if future is None: - return - if not future.cancel(): - try: - local = future.result() - local.unlink(missing_ok=True) - except Exception: - logger.debug(f"Could not delete cached video {rel_path}.", exc_info=True) - - def shutdown(self) -> None: - self._executor.shutdown(wait=False, cancel_futures=True) +# Bound the default frame-level shuffle buffer: rows are tabular-only (~KB each), so this is +# roughly a few hundred MB of host RAM per consumer at the cap. +_MAX_DEFAULT_FRAME_BUFFER = 200_000 class StreamingLeRobotDataset(torch.utils.data.IterableDataset): - """LeRobotDataset with streaming capabilities. + """LeRobotDataset with streaming capabilities, built on native HF `datasets` primitives. - Streams frames from the Hub (or any fsspec source) without downloading the dataset. The - iteration strategy is an *episode pool*: each consumer keeps ``episode_pool_size`` whole - episodes' tabular rows in RAM (a few KB per episode) and emits uniformly random frames - across them, so a batch mixes up to ``batch_size`` distinct episodes. Because a frame's - whole episode is resident, ``delta_timestamps`` windows are exact array slices with correct - padding at episode boundaries. Video is decoded only when a sample is emitted - (decode-on-exit), so pool memory stays tabular-sized; when streaming from a remote source, - each pooled episode's video files are prefetched to a local cache in the background and - deleted on eviction. + The tabular side is a pure `datasets` pipeline:: - Distribution: ranks stream disjoint shards via ``split_dataset_by_node`` and DataLoader - workers split a rank's shards further, so every frame is consumed exactly once per epoch - across the whole fleet. Each consumer's order is a pure function of - ``(seed, epoch, rank, worker)``, which makes resume a deterministic fast-forward (see - :meth:`load_state_dict`). + load_dataset(streaming=True) # parquet shards from the Hub / a bucket + -> split_dataset_by_node(rank, world_size) # disjoint shards per rank + -> batch(by_column="episode_index") # whole episodes + -> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob) + -> map(explode + exact delta windows) # episode -> frames, windows are exact + -> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave + + and this class is a thin torch ``IterableDataset`` wrapper around it that decodes video + per emitted sample (decode-on-exit), applies image transforms, and attaches the task + string. DataLoader workers are split natively by `datasets` (disjoint shards per worker), + and resume uses the native ``state_dict`` / ``load_state_dict``. + + Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are + exact slices of the resident episode with correct padding at episode boundaries. + + Resume: ``state_dict()`` / ``load_state_dict()`` delegate to `datasets`. Samples sitting in + the shuffle buffers at checkpoint time are skipped on resume (documented `datasets` + behavior), so resume never repeats data but may drop up to roughly + ``episode_pool_size x episode_len + frame_shuffle_buffer_size`` frames — negligible at + training scale. The contract is exact with ``num_workers=0``; with DataLoader workers use + ``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's + dataset state through this same protocol. Example: ```python @@ -201,6 +98,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): force_cache_sync: bool = False, streaming: bool = True, episode_pool_size: int | None = 64, + frame_shuffle_buffer_size: int | None = None, buffer_size: int | None = None, max_num_shards: int | None = None, seed: int = 42, @@ -212,8 +110,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): video_decoder_cache_size: int | None = None, data_files_root: str | None = None, video_decode_device: str = "cpu", - prefetch_videos: bool = True, - video_prefetch_workers: int = 4, ): """Initialize a StreamingLeRobotDataset. @@ -229,16 +125,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): revision (str, optional): Git revision id (branch name, tag, or commit hash). force_cache_sync (bool, optional): Flag to sync and refresh local files first. streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True. - episode_pool_size (int, optional): Number of whole episodes each consumer keeps open to - shuffle across — the randomness knob. Larger mixes more episodes per batch (closer to - map-style uniform) at the cost of cold-start latency; RAM stays small because the pool - holds tabular rows only. Defaults to 64. + episode_pool_size (int, optional): Whole episodes each consumer keeps open to shuffle + across — the randomness knob. Larger mixes more episodes per batch (closer to + map-style uniform) at the cost of cold-start latency and frame-buffer RAM. + Defaults to 64. + frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the + episode pool. Defaults to ``episode_pool_size x average episode length`` (capped), + which matches the pool's mixing radius. buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``. - max_num_shards (int | None, optional): Cap on the number of stream shards. None (default) - uses every underlying parquet shard, which is required to feed many DataLoader workers. + max_num_shards (int | None, optional): Deprecated; `datasets` handles shard-to-worker + assignment natively. seed (int, optional): Reproducibility random seed. - rng (np.random.Generator | None, optional): Deprecated; ignored (the RNG is derived from - ``(seed, epoch, rank, worker)`` so consumers are decorrelated and runs reproducible). + rng (np.random.Generator | None, optional): Deprecated; ignored. shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order. rank (int | None, optional): This process' rank for distributed training. Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from @@ -253,10 +151,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): from there while metadata still loads from ``repo_id`` on the Hub. video_decode_device (str, optional): Device for torchcodec decode. ``"cuda"`` offloads to NVDEC (needs a CUDA torchcodec build and ``spawn`` DataLoader workers). - prefetch_videos (bool, optional): When streaming from a remote source, download each pooled - episode's video files to a local cache in the background so decode-on-exit reads local - bytes instead of paying network seek latency. Defaults to True. - video_prefetch_workers (int, optional): Download threads per consumer. Defaults to 4. """ super().__init__() self.repo_id = repo_id @@ -271,37 +165,35 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.seed = seed if rng is not None: logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.") - self.shuffle = shuffle - - self.streaming = streaming if buffer_size is not None: logger.warning( "StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; " "use `episode_pool_size` (whole episodes, not frames)." ) + if max_num_shards is not None: + logger.warning( + "StreamingLeRobotDataset: `max_num_shards` is deprecated and ignored; " + "`datasets` assigns shards to DataLoader workers natively." + ) + self.shuffle = shuffle + + self.streaming = streaming self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 64 - self.max_num_shards = max_num_shards self._return_uint8 = return_uint8 self.rank, self.world_size = self._resolve_distributed(rank, world_size) self.video_decoder_cache_size = video_decoder_cache_size self.data_files_root = data_files_root.rstrip("/") if data_files_root else None self.video_decode_device = video_decode_device - self.prefetch_videos = prefetch_videos - self.video_prefetch_workers = video_prefetch_workers # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None - self._prefetcher: _VideoPrefetcher | None = None # Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate # decoder-cache stats and component timings into one place the main process can read after # iteration (see video_decoder_cache_stats() / timing_stats()). self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_() - # Deterministic fast-forward resume (see load_state_dict): per-consumer epoch counter and - # number of samples still to skip. self._epoch = 0 - self._ff_remaining = 0 - self._resume_state: dict | None = None + self._in_flight_epoch = 0 if self._requested_root is not None: self.root.mkdir(exist_ok=True, parents=True) @@ -349,12 +241,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): if extra_columns: self.hf_dataset = self.hf_dataset.remove_columns(extra_columns) - self.num_shards = ( - self.hf_dataset.num_shards - if self.max_num_shards is None - else min(self.hf_dataset.num_shards, self.max_num_shards) + self.num_shards = self.hf_dataset.num_shards + + avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes))) + self.frame_shuffle_buffer_size = ( + frame_shuffle_buffer_size + if frame_shuffle_buffer_size is not None + else min(self.episode_pool_size * avg_episode_len, _MAX_DEFAULT_FRAME_BUFFER) ) + self._pipeline = self._build_pipeline() + @property def num_frames(self): return self.meta.total_frames @@ -374,6 +271,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the ``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1). """ + import os + if rank is not None and world_size is not None: return rank, world_size @@ -393,27 +292,67 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return 0, 1 - def _consumer_rng(self, epoch: int, worker_id: int) -> np.random.Generator: - """RNG derived from (seed, epoch, rank, worker): reproducible, decorrelated consumers.""" - state = _mix64(self.seed) - for salt in (self.rank, worker_id, epoch if self.shuffle else 0): - state = _mix64(state ^ _mix64(salt)) - return np.random.default_rng(state) + def _build_pipeline(self) -> datasets.IterableDataset: + """Assemble the native tabular pipeline (everything except video decode).""" + ds = self.hf_dataset + if self.world_size > 1: + if ds.num_shards % self.world_size != 0: + logger.warning( + f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): " + "datasets falls back to example-level splitting where every rank reads (and pays " + "for) the full stream. Re-shard the dataset or adjust world size." + ) + ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size) - def _shard_order(self, epoch: int, num_shards: int) -> list[int]: - """Seeded permutation of this rank's shard indices, re-drawn every epoch. - - In a sub-epoch run over a corpus consolidated source-by-source, index-order shard - assignment means training on whatever the first N% of files contains; permuting the - shard order turns that into a uniform sample of files. Seeded by (seed, epoch, rank) - only — every DataLoader worker of the rank must agree on this list, because workers - stride it and disagreement would create overlapping shard assignments. - """ - order = list(range(num_shards)) + ds = ds.batch(by_column="episode_index") + episode_columns = list(ds.column_names or self.hf_dataset.column_names or []) if self.shuffle: - state = _mix64(self.seed) ^ _mix64(0x5EED5EED) ^ _mix64(self.rank) ^ _mix64(epoch) - np.random.default_rng(_mix64(state)).shuffle(order) - return order + ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size) + # A row-count-changing batched map must drop the input columns explicitly; the exploded + # frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks). + ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns) + if self.shuffle: + ds = ds.shuffle(seed=self.seed + 1, buffer_size=max(2, self.frame_shuffle_buffer_size)) + return ds + + def _tabular_window_keys(self) -> list[str]: + if self.delta_indices is None: + return [] + return [key for key in self.delta_indices if key not in self.meta.video_keys] + + def _explode_episodes(self, episode_batch: dict[str, list[list]]) -> dict[str, list]: + """Episode batches -> per-frame rows, with exact tabular delta windows and pad masks. + + Runs inside the `datasets` pipeline (plain Python values, no torch). For each windowed key + the original per-frame value is replaced by its delta window (list of values, clamped to + the episode bounds) plus a ``{key}_is_pad`` mask, mirroring the map-style dataset. + """ + window_keys = set(self._tabular_window_keys()) + out: dict[str, list] = {key: [] for key in episode_batch if key not in window_keys} + for key in window_keys: + out[key] = [] + out[f"{key}_is_pad"] = [] + + num_episodes = len(episode_batch["episode_index"]) + for e in range(num_episodes): + length = len(episode_batch["episode_index"][e]) + for key, column in episode_batch.items(): + if key in window_keys: + continue + out[key].extend(column[e]) + for key in window_keys: + episode_column = episode_batch[key][e] + deltas = self.delta_indices[key] + for t in range(length): + window = [] + is_pad = [] + for delta in deltas: + j = t + delta + window.append(episode_column[min(max(j, 0), length - 1)]) + is_pad.append(not 0 <= j < length) + out[key].append(window) + out[f"{key}_is_pad"].append(is_pad) + return out def _make_video_decoder_cache(self) -> VideoDecoderCache: """Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128.""" @@ -432,231 +371,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): device=self.video_decode_device, ) - def _make_prefetcher(self) -> _VideoPrefetcher | None: - if not self.prefetch_videos or len(self.meta.video_keys) == 0: - return None - if self.data_files_root is not None: - remote_root = self.data_files_root - elif self.streaming and not self.streaming_from_local: - remote_root = self.meta.url_root - else: - return None # video bytes are already local - return _VideoPrefetcher( - remote_root, - cache_dir=self.root / "streaming_video_cache", - max_workers=self.video_prefetch_workers, - ) - - @staticmethod - def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]: - """Yield (episode_index, rows) for each complete episode of a shard stream. - - On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)`` - (one accumulation per episode instead of one Python dict per row); older versions - use the equivalent row loop. - """ - if _HAS_BATCH_BY_COLUMN: - for batch in shard.batch(by_column="episode_index"): - keys = list(batch.keys()) - num_rows = len(batch["episode_index"]) - rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)] - yield int(batch["episode_index"][0]), rows - return - rows: list[dict] = [] - current: int | None = None - for item in shard: - ep_idx = int(item["episode_index"]) - if current is None: - current = ep_idx - if ep_idx != current: - yield current, rows - rows = [] - current = ep_idx - rows.append(item) - if rows: - yield current, rows - - def _admit_episode(self, ep_idx: int, rows: list[dict], prefetcher: _VideoPrefetcher | None): - video_rel_paths = [str(self.meta.get_video_file_path(ep_idx, key)) for key in self.meta.video_keys] - if prefetcher is not None: - for rel in video_rel_paths: - prefetcher.acquire(rel) - torch_rows = [item_to_torch(row) for row in rows] - return _PooledEpisode(ep_idx, torch_rows, video_rel_paths) - def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: - ds = self.hf_dataset - if self.world_size > 1: - if ds.num_shards % self.world_size != 0: - logger.warning( - f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): " - "datasets falls back to example-level splitting where every rank reads (and pays " - "for) the full stream. Re-shard the dataset or adjust world size." - ) - ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size) - - num_shards = ds.num_shards if self.max_num_shards is None else min(ds.num_shards, self.max_num_shards) - epoch = self._epoch + # `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch); + # DataLoader workers each advance their own copy's counter in lockstep. The in-flight + # epoch is tracked separately so a mid-iteration state_dict() records the epoch the + # stream position actually belongs to. + self._in_flight_epoch = self._epoch + self._pipeline.set_epoch(self._in_flight_epoch) self._epoch += 1 - shard_indices = self._shard_order(epoch, num_shards) - - # DataLoader workers within this rank further split the shards so they don't yield duplicates. - worker_info = torch.utils.data.get_worker_info() - worker_id, num_workers = (worker_info.id, worker_info.num_workers) if worker_info else (0, 1) - shard_indices = shard_indices[worker_id::num_workers] - if not shard_indices: - logger.warning( - f"Worker {worker_id} owns no shards ({num_shards} shards < {num_workers} workers): " - "it will yield nothing. Reduce num_workers or re-shard the dataset." - ) - return - self.video_decoder_cache = self._make_video_decoder_cache() - prefetcher = self._make_prefetcher() - self._prefetcher = prefetcher - rng = self._consumer_rng(epoch, worker_id) - # Workers beyond the shard count yield nothing and are stopped by the DataLoader, so the - # batch round-robin effectively runs over min(num_workers, num_shards) active workers. - self._consume_resume_state(worker_id, min(num_workers, num_shards)) + iterator = iter(self._pipeline) + while True: + fetch_start = time.perf_counter_ns() + try: + row = next(iterator) + except StopIteration: + return + finally: + self._cache_counters[4] += time.perf_counter_ns() - fetch_start + yield self._finalize_sample(row) - # Round-robin episode admission across this consumer's shard streams (deterministic). - streams = [self._iter_shard_episodes(safe_shard(ds, idx, num_shards)) for idx in shard_indices] - next_stream = 0 - - pool: list[_PooledEpisode] = [] - total_remaining = 0 - - def admit() -> int: - nonlocal next_stream, total_remaining - admitted = 0 - while len(pool) < self.episode_pool_size and streams: - stream = streams[next_stream % len(streams)] - fetch_start = time.perf_counter_ns() - try: - ep_idx, rows = next(stream) - except StopIteration: - streams.remove(stream) - continue - finally: - self._cache_counters[4] += time.perf_counter_ns() - fetch_start - next_stream += 1 - episode = self._admit_episode(ep_idx, rows, prefetcher) - pool.append(episode) - total_remaining += len(episode.remaining) - admitted += 1 - return admitted - - worker_split_guard = ( - _suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext() - ) - try: - with worker_split_guard: - admit() - while pool: - if self.shuffle: - # Uniform draw over every remaining frame in the pool: pick the episode by - # cumulative remaining count, then a random remaining position (swap-pop). - draw = int(rng.integers(total_remaining)) - for episode in pool: - if draw < len(episode.remaining): - break - draw -= len(episode.remaining) - pick = int(rng.integers(len(episode.remaining))) - frame_pos = episode.remaining[pick] - episode.remaining[pick] = episode.remaining[-1] - episode.remaining.pop() - else: - episode = pool[0] - frame_pos = episode.remaining.pop(0) - total_remaining -= 1 - - if self._ff_remaining > 0: - self._ff_remaining -= 1 - else: - yield self._make_pool_sample(episode, frame_pos) - - if not episode.remaining: - pool.remove(episode) - if prefetcher is not None: - for rel in episode.video_rel_paths: - prefetcher.release(rel) - admit() - finally: - if prefetcher is not None: - prefetcher.shutdown() - self._prefetcher = None - - def load_state_dict(self, state_dict: dict) -> None: - """Stage a deterministic fast-forward resume, applied from the next ``__iter__``. - - ``state_dict`` holds ``{"batches_consumed": int, "batch_size": int}`` — what the trainer - already knows at checkpoint time. Because every consumer's order is a pure function of - (seed, epoch, rank, worker), resume replays the stream while skipping emission (tabular - reads only, no video decode) until each worker reaches its own consumed count; the - DataLoader's round-robin batch assignment makes that count derivable per worker. Exact - within an epoch; crossing epoch boundaries may drift by < one batch per worker per epoch - when ``drop_last`` discards partial batches. - """ - self._resume_state = { - "batches_consumed": int(state_dict["batches_consumed"]), - "batch_size": int(state_dict["batch_size"]), - } - - def _consume_resume_state(self, worker_id: int, active_workers: int) -> None: - if self._resume_state is None: - return - batches = self._resume_state["batches_consumed"] - batch_size = self._resume_state["batch_size"] - self._resume_state = None - if worker_id >= active_workers: - return # this worker owns no shards and never delivered a batch - # The DataLoader assigns batch j to active worker j % active_workers. - my_batches = batches // active_workers + (1 if batches % active_workers > worker_id else 0) - self._ff_remaining = my_batches * batch_size - if self._ff_remaining: - logger.info( - f"Streaming resume: worker {worker_id} fast-forwarding {self._ff_remaining} samples " - "(tabular reads only, no video decode)." - ) - - def video_decoder_cache_stats(self) -> dict[str, int | float]: - """Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor. - - Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums - hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as - approximate; the ``hit_rate`` ratio is preserved. - """ - hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist()) - total = hits + misses - return { - "hits": hits, - "misses": misses, - "evictions": evictions, - "hit_rate": round(hits / total, 4) if total else 0.0, - } - - def timing_stats(self) -> dict[str, float]: - """Cumulative seconds spent in video decode and episode (tabular) fetch, summed across - DataLoader workers via the shared counter tensor. These overlap in wall-clock (workers run - in parallel), so compare them to ``num_workers x wallclock`` for time fractions. - """ - decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist()) - return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)} - - def _make_pool_sample(self, episode: _PooledEpisode, frame_pos: int) -> dict: - """Assemble a full training sample for one pooled frame (tabular slices + video decode).""" - rows = episode.rows - item = dict(rows[frame_pos]) - ep_idx = episode.episode_index - num_rows = len(rows) - current_ts = float(item["timestamp"]) - - updates: list[dict] = [] - if self.delta_indices is not None: - updates.extend(self._pool_delta_frames(rows, frame_pos, num_rows)) + def _finalize_sample(self, row: dict) -> dict: + """Torch conversion + video decode (decode-on-exit) + transforms + task for one frame.""" + window_keys = self._tabular_window_keys() + pad_masks = {f"{key}_is_pad": torch.BoolTensor(row.pop(f"{key}_is_pad")) for key in window_keys} + item = item_to_torch(row) + item.update(pad_masks) if len(self.meta.video_keys) > 0: + ep_idx = int(item["episode_index"]) + current_ts = float(item["timestamp"]) # Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of # decoding against a neighbouring episode sharing the same video file. episode_boundaries_ts = { @@ -679,35 +424,57 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): for cam in self.meta.camera_keys: video_frames[cam] = self.image_transforms(video_frames[cam]) - updates.append(video_frames) + item.update(video_frames) if self.delta_indices is not None: - updates.append( + item.update( self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps) ) - result = item - for update in updates: - result.update(update) - result["task"] = self.meta.tasks.iloc[item["task_index"]].name - return result + item["task"] = self.meta.tasks.iloc[int(item["task_index"])].name + return item - def _pool_delta_frames(self, rows: list[dict], frame_pos: int, num_rows: int) -> list[dict]: - """Exact delta windows by slicing the resident episode; clamped + padded at boundaries.""" - query_result: dict = {} - padding: dict = {} - for key, deltas in self.delta_indices.items(): - if key in self.meta.video_keys: - continue # visual frames are decoded separately - frames = [] - is_pad = [] - for delta in deltas: - j = frame_pos + delta - valid = 0 <= j < num_rows - frames.append(rows[min(max(j, 0), num_rows - 1)][key]) - is_pad.append(not valid) - query_result[key] = torch.stack(frames) - padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad) - return [query_result, padding] + def set_epoch(self, epoch: int) -> None: + """Set the epoch the next ``__iter__`` will use (reshuffles the native pipeline).""" + self._epoch = epoch + + def state_dict(self) -> dict: + """Native `datasets` stream state. Exact contract with ``num_workers=0``; with DataLoader + workers use ``torchdata.stateful_dataloader.StatefulDataLoader`` (it checkpoints each + worker's copy through this protocol). Samples in the shuffle buffers are skipped on + resume (never repeated), bounded by the pool + frame buffer sizes. + """ + return {"pipeline": self._pipeline.state_dict(), "epoch": self._in_flight_epoch} + + def load_state_dict(self, state_dict: dict) -> None: + # Resume continues inside the recorded epoch: the next __iter__ replays that epoch's + # shuffle order from the restored stream position, then advances normally. + self._epoch = int(state_dict.get("epoch", 0)) + self._pipeline.load_state_dict(state_dict["pipeline"]) + + def video_decoder_cache_stats(self) -> dict[str, int | float]: + """Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor. + + Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums + hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as + approximate; the ``hit_rate`` ratio is preserved. + """ + hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist()) + total = hits + misses + return { + "hits": hits, + "misses": misses, + "evictions": evictions, + "hit_rate": round(hits / total, 4) if total else 0.0, + } + + def timing_stats(self) -> dict[str, float]: + """Cumulative seconds spent in video decode and in the upstream tabular pipeline (parquet + fetch + grouping + shuffles + explode), summed across DataLoader workers via the shared + counter tensor. These overlap in wall-clock (workers run in parallel), so compare them to + ``num_workers x wallclock`` for time fractions. + """ + decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist()) + return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)} def _make_timestamps_from_indices( self, start_ts: float, indices: dict[str, list[int]] | None = None @@ -785,17 +552,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] rel_path = str(self.meta.get_video_file_path(ep_idx, video_key)) - local = self._prefetcher.wait_local(rel_path) if self._prefetcher is not None else None - if local is not None: - video_path = str(local) + if self.data_files_root is not None: + root = self.data_files_root + elif self.streaming and not self.streaming_from_local: + root = self.meta.url_root else: - if self.data_files_root is not None: - root = self.data_files_root - elif self.streaming and not self.streaming_from_local: - root = self.meta.url_root - else: - root = self.root - video_path = f"{root}/{rel_path}" + root = self.root + video_path = f"{root}/{rel_path}" frames = decode_video_frames_torchcodec( video_path, shifted_query_ts, diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index 1d39c5a9a..e2298b4d3 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -218,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_ check = torch.allclose(left, right) and left.shape == right.shape + else: + # Scalar numerics: streaming yields python floats/ints where map-style yields + # 0-dim tensors (long-standing accepted difference). Compare by value. + check = float(left) == float(right) + key_checks.append((key, check)) assert all(t[1] for t in key_checks), ( diff --git a/tests/datasets/test_streaming_native.py b/tests/datasets/test_streaming_native.py index 2856cc4ff..de25eb144 100644 --- a/tests/datasets/test_streaming_native.py +++ b/tests/datasets/test_streaming_native.py @@ -148,37 +148,6 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas assert checked > 0, "test did not exercise any in-episode long-horizon frame" -def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory): - """Resume replays the deterministic stream and continues at the exact sample.""" - repo_id = f"{DUMMY_REPO_ID}-resume" - total_frames = 100 - _make_local_dataset( - lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames - ) - - def fresh_ds(): - return StreamingLeRobotDataset( - repo_id=repo_id, - root=tmp_path / "ds", - shuffle=True, - seed=7, - episode_pool_size=3, - max_num_shards=1, - ) - - full_epoch = _stream_indices(fresh_ds()) - assert sorted(full_epoch) == list(range(total_frames)) - - batches_consumed, batch_size = 5, 4 # 20 samples in - resumed_ds = fresh_ds() - resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size}) - resumed = _stream_indices(resumed_ds) - - assert resumed == full_epoch[batches_consumed * batch_size :], ( - "fast-forward resume did not continue at the exact sample" - ) - - def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory): repo_id = f"{DUMMY_REPO_ID}-seeds" _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120) @@ -226,27 +195,6 @@ def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory): assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}" -def test_video_prefetcher_refcounted_lifecycle(tmp_path): - from lerobot.datasets.streaming_dataset import _VideoPrefetcher - - remote = tmp_path / "remote" - (remote / "videos").mkdir(parents=True) - payload = b"x" * 1024 - (remote / "videos" / "a.mp4").write_bytes(payload) - - prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1) - prefetcher.acquire("videos/a.mp4") - prefetcher.acquire("videos/a.mp4") # second pooled episode sharing the file - local = prefetcher.wait_local("videos/a.mp4") - assert local is not None and local.read_bytes() == payload - - prefetcher.release("videos/a.mp4") - assert local.exists(), "file deleted while still referenced" - prefetcher.release("videos/a.mp4") - assert not local.exists(), "file not deleted at refcount zero" - prefetcher.shutdown() - - def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory): """Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset.""" repo_id = f"{DUMMY_REPO_ID}-parity" @@ -318,87 +266,49 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory): assert shuffled != ordered, "shuffle did not decorrelate output order" -def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_factory): - """Resume must be exact under num_workers > 0: each worker re-derives its own skip.""" - from torch.utils.data import DataLoader - - repo_id = f"{DUMMY_REPO_ID}-resume-workers" - _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=120) - - num_workers = 2 +def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory): + """Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers.""" + repo_id = f"{DUMMY_REPO_ID}-native-resume" + total_frames = 100 + _make_local_dataset( + lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames + ) def fresh_ds(): return StreamingLeRobotDataset( repo_id=repo_id, root=tmp_path / "ds", shuffle=True, - seed=11, - episode_pool_size=3, - max_num_shards=4, + seed=7, + episode_pool_size=2, + frame_shuffle_buffer_size=8, ) - def epoch_samples(ds): - # batch_size=None yields raw samples; the DataLoader round-robins them across workers, - # which is batch_size=1 in the resume arithmetic. - loader = DataLoader(ds, batch_size=None, num_workers=num_workers) - return [int(sample["index"]) for sample in loader] + ds = fresh_ds() + it = iter(ds) + consumed = [int(next(it)["index"]) for _ in range(30)] + state = ds.state_dict() - full = epoch_samples(fresh_ds()) - - samples_consumed = 17 resumed_ds = fresh_ds() - resumed_ds.load_state_dict({"batches_consumed": samples_consumed, "batch_size": 1}) - resumed = epoch_samples(resumed_ds) + resumed_ds.load_state_dict(state) + rest = [int(frame["index"]) for frame in resumed_ds] - assert resumed == full[samples_consumed:], ( - "fast-forward resume with DataLoader workers did not continue at the exact sample" - ) + assert not set(consumed) & set(rest), "resume re-yielded already-seen frames" + # in-flight buffer contents are skipped on resume (documented datasets behavior): + # bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8) + covered = len(set(consumed) | set(rest)) + max_in_flight = 2 * 30 + 8 + assert covered >= total_frames - max_in_flight + assert covered + len(consumed) >= total_frames - max_in_flight -def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch): - """The datasets>=5 batch(by_column=...) path must group episodes identically to the row loop.""" - import lerobot.datasets.streaming_dataset as sd - - repo_id = f"{DUMMY_REPO_ID}-grouping" - _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100) - ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1) - - def episode_signature(use_native): - monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native) - return [ - (ep_idx, [int(row["index"]) for row in rows]) - for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset) - ] - - fallback = episode_signature(False) - assert len(fallback) == 5 - if not sd._HAS_BATCH_BY_COLUMN and "by_column" not in str( - type(ds.hf_dataset).batch.__doc__ or "" - ): # datasets < 5: only the fallback path exists - return - native = episode_signature(True) - assert native == fallback - - -def test_shard_order_permutation_properties(tmp_path, lerobot_dataset_factory): - """Shard order: a valid permutation, deterministic per (seed, epoch, rank), worker-independent - (workers stride the same list, so it must not depend on worker id), reshuffled across epochs, - and identity when shuffle is off.""" - repo_id = f"{DUMMY_REPO_ID}-shardorder" +def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory): + """The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle.""" + repo_id = f"{DUMMY_REPO_ID}-native-pipe" _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80) + ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2) + import datasets as hf_datasets - ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5) - num_shards = 32 - order_epoch0 = ds._shard_order(0, num_shards) - assert sorted(order_epoch0) == list(range(num_shards)) - assert ds._shard_order(0, num_shards) == order_epoch0 # deterministic - assert ds._shard_order(1, num_shards) != order_epoch0 # reshuffles per epoch - assert order_epoch0 != list(range(num_shards)) # actually permuted (P=1/32! of false alarm) - - other_rank = StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5, rank=1, world_size=2 - ) - assert other_rank._shard_order(0, num_shards) != order_epoch0 # ranks decorrelated - - unshuffled = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, seed=5) - assert unshuffled._shard_order(0, num_shards) == list(range(num_shards)) + assert isinstance(ds._pipeline, hf_datasets.IterableDataset) + state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end + assert state is not None