diff --git a/benchmarks/streaming/benchmark_streaming.py b/benchmarks/streaming/benchmark_streaming.py index f36f9b0e1..6a9ecd016 100644 --- a/benchmarks/streaming/benchmark_streaming.py +++ b/benchmarks/streaming/benchmark_streaming.py @@ -62,7 +62,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_workers", type=int, default=8) - parser.add_argument("--buffer_size", type=int, default=2000) + parser.add_argument("--episode_pool_size", type=int, default=64) parser.add_argument("--video_decoder_cache_size", type=int, default=None) parser.add_argument( "--video_decode_device", @@ -86,7 +86,7 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str root=args.root, data_files_root=args.data_files_root, delta_timestamps=delta_timestamps, - buffer_size=args.buffer_size, + episode_pool_size=args.episode_pool_size, video_decoder_cache_size=args.video_decoder_cache_size, video_decode_device=args.video_decode_device, tolerance_s=1e-3, @@ -172,7 +172,7 @@ def main() -> None: "mode": args.mode, "batch_size": args.batch_size, "num_workers": args.num_workers, - "buffer_size": args.buffer_size, + "episode_pool_size": args.episode_pool_size, "num_cameras": len(meta.video_keys), "fps": meta.fps, "device": str(device), diff --git a/examples/scaling/train_streaming_multinode.py b/examples/scaling/train_streaming_multinode.py index ed74a40b3..af3e4c6b0 100644 --- a/examples/scaling/train_streaming_multinode.py +++ b/examples/scaling/train_streaming_multinode.py @@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`: - per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size`` are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly); - DataLoader-worker shard splitting (no duplicate frames within a rank); -- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint; +- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only); - an explicit video-decoder cache size so the working set of open decoders does not thrash. Launch with Accelerate (single node, N GPUs): @@ -57,7 +57,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument( - "--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames." + "--episode_pool_size", + type=int, + default=64, + help="Whole episodes open per consumer (randomness knob).", ) parser.add_argument("--video_decoder_cache_size", type=int, default=None) parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).") @@ -78,7 +81,7 @@ def make_dataloader( args.repo_id, root=args.root, delta_timestamps=delta_timestamps, - buffer_size=args.buffer_size, + episode_pool_size=args.episode_pool_size, video_decoder_cache_size=args.video_decoder_cache_size, tolerance_s=1e-3, ) @@ -121,13 +124,13 @@ def main() -> None: # of it). Batches are moved to the device manually in the loop. model, optimizer = accelerator.prepare(model, optimizer) - # Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds - # plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a - # checkpoint this script wrote itself. + # Resume: deterministic fast-forward. Every consumer's order is a pure function of + # (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and + # worker re-derives its own skip. Same file works for every rank. if args.resume_from is not None: - state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614 + state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True) dataset.load_state_dict(state) - accelerator.print(f"Resumed dataset stream from {args.resume_from}") + accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed") step = 0 frames_seen = 0 @@ -157,8 +160,11 @@ def main() -> None: if step % args.save_freq == 0 and accelerator.is_main_process: ckpt = output_dir / f"checkpoint-{step}" ckpt.mkdir(parents=True, exist_ok=True) - # Save the dataset stream position alongside the model so a restart resumes mid-stream. - torch.save(dataset.state_dict(), ckpt / "dataset_state.pt") + # Save the consumed-batch counters so a restart fast-forwards to this position. + torch.save( + {"batches_consumed": step, "batch_size": args.batch_size}, + ckpt / "dataset_state.pt", + ) if model is not None: accelerator.unwrap_model(model).save_pretrained(ckpt) diff --git a/slurm/benchmark_streaming_robocasa.sh b/slurm/benchmark_streaming_robocasa.sh index 0ee150dcd..01311fd37 100644 --- a/slurm/benchmark_streaming_robocasa.sh +++ b/slurm/benchmark_streaming_robocasa.sh @@ -33,7 +33,7 @@ for MODE in single sarm; do --mode $MODE \ --batch_size 64 \ --num_workers 12 \ - --buffer_size 4000 \ + --episode_pool_size 64 \ --num_batches 300 \ --out_dir '"$OUT_DIR"'/node${SLURM_NODEID} done diff --git a/slurm/run_streaming_matrix.sh b/slurm/run_streaming_matrix.sh index a33e181fc..59f3de686 100755 --- a/slurm/run_streaming_matrix.sh +++ b/slurm/run_streaming_matrix.sh @@ -83,7 +83,7 @@ for SOURCE in $SOURCES; do $RUN benchmarks/streaming/benchmark_streaming.py \ --repo_id $REPO_ID $ROOTFLAG \ --mode $MODE --source $SOURCE --video_decode_device $DECODE \ - --batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \ + --batch_size $BATCH_SIZE --num_workers $W --episode_pool_size $B \ --num_batches $NUM_BATCHES --out_dir $OUT_DIR") jid=${jid%%;*} # strip ';cluster' suffix on federated setups echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}" diff --git a/slurm/train_streaming_robocasa.sh b/slurm/train_streaming_robocasa.sh index 31cfe2f4b..f71219dc5 100644 --- a/slurm/train_streaming_robocasa.sh +++ b/slurm/train_streaming_robocasa.sh @@ -42,7 +42,7 @@ accelerate launch \ --repo_id '"$REPO_ID"' \ --batch_size 64 \ --num_workers 12 \ - --buffer_size 4000 \ + --episode_pool_size 64 \ --steps 200000 \ --save_freq 2000 \ --log_freq 50 diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 9de5e6c0e..08fdda209 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -39,9 +39,10 @@ class DatasetConfig: # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. return_uint8: bool = False streaming: bool = False - # Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost - # of host RAM. Ignored when streaming is False. - streaming_buffer_size: int = 1000 + # Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob). + # Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because + # the pool holds tabular rows only. Ignored when streaming is False. + streaming_episode_pool_size: int = 64 def __post_init__(self) -> None: if self.episodes is not None: diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 47fe560e1..7b6a77883 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas delta_timestamps=delta_timestamps, image_transforms=image_transforms, revision=cfg.dataset.revision, - buffer_size=cfg.dataset.streaming_buffer_size, + episode_pool_size=cfg.dataset.streaming_episode_pool_size, tolerance_s=cfg.tolerance_s, return_uint8=True, ) diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 7cf61a6ed..787720803 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -14,19 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math import os -from collections import deque -from collections.abc import Callable, Generator, Iterable, Iterator +import shutil +from collections.abc import Callable, Iterator +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path import datasets +import fsspec import numpy as np import torch from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE +from lerobot.utils.constants import HF_LEROBOT_HOME from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from .feature_utils import get_delta_indices @@ -44,200 +45,122 @@ from .video_utils import ( logger = logging.getLogger(__name__) +_MASK_64 = (1 << 64) - 1 -class LookBackError(Exception): - """ - Exception raised when trying to look back in the history of a Backtrackable object. + +def _mix64(x: int) -> int: + """SplitMix64 finalizer (64-bit integer hash) for seed derivation.""" + x = (x + 0x9E3779B97F4A7C15) & _MASK_64 + x ^= x >> 30 + x = (x * 0xBF58476D1CE4E5B9) & _MASK_64 + x ^= x >> 27 + x = (x * 0x94D049BB133111EB) & _MASK_64 + x ^= x >> 31 + return x + + +class _PooledEpisode: + """A fully-loaded episode's tabular rows plus emission bookkeeping.""" + + __slots__ = ("episode_index", "rows", "remaining", "video_rel_paths") + + def __init__(self, episode_index: int, rows: list[dict], video_rel_paths: list[str]): + self.episode_index = episode_index + self.rows = rows + self.remaining = list(range(len(rows))) + self.video_rel_paths = video_rel_paths + + +class _VideoPrefetcher: + """Background downloader of episode video files into a local cache (decode-on-exit support). + + Files are refcounted because LeRobot v3 packs several episodes per video file: a file is + downloaded once when the first pooled episode referencing it is admitted and deleted when + the last one is evicted. Downloads resolve through fsspec (hf://, s3://, https://, ...). """ - pass + def __init__(self, remote_root: str, cache_dir: Path, max_workers: int = 4): + self._remote_root = remote_root.rstrip("/") + self._cache_dir = cache_dir + self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="video-prefetch") + self._refcounts: dict[str, int] = {} + self._futures: dict[str, Future] = {} + def acquire(self, rel_path: str) -> None: + self._refcounts[rel_path] = self._refcounts.get(rel_path, 0) + 1 + if rel_path not in self._futures: + self._futures[rel_path] = self._executor.submit(self._download, rel_path) -class LookAheadError(Exception): - """ - Exception raised when trying to look ahead in the future of a Backtrackable object. - """ + def _download(self, rel_path: str) -> Path: + local = self._cache_dir / rel_path + if local.exists(): + return local + local.parent.mkdir(parents=True, exist_ok=True) + tmp = local.with_suffix(local.suffix + ".tmp") + with fsspec.open(f"{self._remote_root}/{rel_path}", "rb") as src, open(tmp, "wb") as dst: + shutil.copyfileobj(src, dst, length=1 << 22) + tmp.rename(local) + return local - pass - - -class Backtrackable[T]: - """ - Wrap any iterator/iterable so you can step back up to `history` items - and look ahead up to `lookahead` items. - - This is useful for streaming datasets where you need to access previous and future items - but can't load the entire dataset into memory. - - Example: - ------- - ```python - ds = load_dataset("c4", "en", streaming=True, split="train") - rev = Backtrackable(ds, history=3, lookahead=2) - - x0 = next(rev) # forward - x1 = next(rev) - x2 = next(rev) - - # Look ahead - x3_peek = rev.peek_ahead(1) # next item without moving cursor - x4_peek = rev.peek_ahead(2) # two items ahead - - # Look back - x1_again = rev.peek_back(1) # previous item without moving cursor - x0_again = rev.peek_back(2) # two items back - - # Move backward - x1_back = rev.prev() # back one step - next(rev) # returns x2, continues forward from where we were - ``` - """ - - __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") - - def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): - if history < 1: - raise ValueError("history must be >= 1") - if lookahead <= 0: - raise ValueError("lookahead must be > 0") - - self._source: Iterator[T] = iter(iterable) - self._back_buf: deque[T] = deque(maxlen=history) - self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() - self._cursor: int = 0 - self._history = history - self._lookahead = lookahead - - def __iter__(self) -> "Backtrackable[T]": - return self - - def __next__(self) -> T: - # If we've stepped back, consume from back buffer first - if self._cursor < 0: # -1 means "last item", etc. - self._cursor += 1 - return self._back_buf[self._cursor] - - # If we have items in the ahead buffer, use them first - item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) - - # Add current item to back buffer and reset cursor - self._back_buf.append(item) - self._cursor = 0 - return item - - def prev(self) -> T: - """ - Step one item back in history and return it. - Raises IndexError if already at the oldest buffered item. - """ - if len(self._back_buf) + self._cursor <= 1: - raise LookBackError("At start of history") - - self._cursor -= 1 - return self._back_buf[self._cursor] - - def peek_back(self, n: int = 1) -> T: - """ - Look `n` items back (n=1 == previous item) without moving the cursor. - """ - if n < 0 or n + 1 > len(self._back_buf) + self._cursor: - raise LookBackError("peek_back distance out of range") - - return self._back_buf[self._cursor - (n + 1)] - - def peek_ahead(self, n: int = 1) -> T: - """ - Look `n` items ahead (n=1 == next item) without moving the cursor. - Fills the ahead buffer if necessary. - """ - if n < 1: - raise LookAheadError("peek_ahead distance must be 1 or more") - elif n > self._lookahead: - raise LookAheadError("peek_ahead distance exceeds lookahead limit") - - # Fill ahead buffer if we don't have enough items - while len(self._ahead_buf) < n: - try: - item = next(self._source) - self._ahead_buf.append(item) - - except StopIteration as err: - raise LookAheadError("peek_ahead: not enough items in source") from err - - return self._ahead_buf[n - 1] - - def history(self) -> list[T]: - """ - Return a copy of the buffered history (most recent last). - The list length ≤ `history` argument passed at construction. - """ - if self._cursor == 0: - return list(self._back_buf) - - # When cursor<0, slice so the order remains chronological - return list(self._back_buf)[: self._cursor or None] - - def can_peek_back(self, steps: int = 1) -> bool: - """ - Check if we can go back `steps` items without raising an IndexError. - """ - return steps <= len(self._back_buf) + self._cursor - - def can_peek_ahead(self, steps: int = 1) -> bool: - """ - Check if we can peek ahead `steps` items. - This may involve trying to fill the ahead buffer. - """ - if self._lookahead > 0 and steps > self._lookahead: - return False - - # Try to fill ahead buffer to check if we can peek that far + def wait_local(self, rel_path: str) -> Path | None: + """Block until the file is cached; None when not tracked or the download failed.""" + future = self._futures.get(rel_path) + if future is None: + return None try: - while len(self._ahead_buf) < steps: - if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: - return False - item = next(self._source) - self._ahead_buf.append(item) - return True - except StopIteration: - return False + return future.result() + except Exception as e: + logger.warning(f"Video prefetch failed for {rel_path} ({e}); decoding from remote instead.") + return None + + def release(self, rel_path: str) -> None: + count = self._refcounts.get(rel_path, 0) - 1 + if count > 0: + self._refcounts[rel_path] = count + return + self._refcounts.pop(rel_path, None) + future = self._futures.pop(rel_path, None) + if future is None: + return + if not future.cancel(): + try: + local = future.result() + local.unlink(missing_ok=True) + except Exception: + logger.debug(f"Could not delete cached video {rel_path}.", exc_info=True) + + def shutdown(self) -> None: + self._executor.shutdown(wait=False, cancel_futures=True) class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """LeRobotDataset with streaming capabilities. - This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed - rather than loaded entirely into memory. This is especially useful for large datasets that may - not fit in memory or when you want to quickly explore a dataset without downloading it completely. + Streams frames from the Hub (or any fsspec source) without downloading the dataset. The + iteration strategy is an *episode pool*: each consumer keeps ``episode_pool_size`` whole + episodes' tabular rows in RAM (a few KB per episode) and emits uniformly random frames + across them, so a batch mixes up to ``batch_size`` distinct episodes. Because a frame's + whole episode is resident, ``delta_timestamps`` windows are exact array slices with correct + padding at episode boundaries. Video is decoded only when a sample is emitted + (decode-on-exit), so pool memory stays tabular-sized; when streaming from a remote source, + each pooled episode's video files are prefetched to a local cache in the background and + deleted on eviction. - The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent - items, allowing us to access previous frames for delta timestamps without loading the entire - dataset into memory. + Distribution: ranks stream disjoint shards via ``split_dataset_by_node`` and DataLoader + workers split a rank's shards further, so every frame is consumed exactly once per epoch + across the whole fleet. Each consumer's order is a pure function of + ``(seed, epoch, rank, worker)``, which makes resume a deterministic fast-forward (see + :meth:`load_state_dict`). Example: - Basic usage: ```python - from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset - - # Create a streaming dataset with delta timestamps - delta_timestamps = { - "observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current - "action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future - } - dataset = StreamingLeRobotDataset( repo_id="your-dataset-repo-id", - delta_timestamps=delta_timestamps, - streaming=True, - buffer_size=1000, + delta_timestamps={"action": [0.0, 0.1, 0.2]}, + episode_pool_size=64, ) - - # Iterate over the dataset - for i, item in enumerate(dataset): - print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}") - # item will contain stacked frames according to delta_timestamps - if i >= 10: - break + for sample in dataset: + ... ``` """ @@ -252,8 +175,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): revision: str | None = None, force_cache_sync: bool = False, streaming: bool = True, - buffer_size: int = 1000, - max_num_shards: int = 16, + episode_pool_size: int = 64, + buffer_size: int | None = None, + max_num_shards: int | None = None, seed: int = 42, rng: np.random.Generator | None = None, shuffle: bool = True, @@ -263,6 +187,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): video_decoder_cache_size: int | None = None, data_files_root: str | None = None, video_decode_device: str = "cpu", + prefetch_videos: bool = True, + video_prefetch_workers: int = 4, ): """Initialize a StreamingLeRobotDataset. @@ -278,30 +204,34 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): revision (str, optional): Git revision id (branch name, tag, or commit hash). force_cache_sync (bool, optional): Flag to sync and refresh local files first. streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True. - buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000. - max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16. + episode_pool_size (int, optional): Number of whole episodes each consumer keeps open to + shuffle across — the randomness knob. Larger mixes more episodes per batch (closer to + map-style uniform) at the cost of cold-start latency; RAM stays small because the pool + holds tabular rows only. Defaults to 64. + buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``. + max_num_shards (int | None, optional): Cap on the number of stream shards. None (default) + uses every underlying parquet shard, which is required to feed many DataLoader workers. seed (int, optional): Reproducibility random seed. - rng (np.random.Generator | None, optional): Random number generator. - shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True. - rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training. - Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is - resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0. - world_size (int | None, optional): Total number of distributed processes. When omitted, resolved - from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding). - For an even per-rank split, ``num_shards % world_size == 0`` should hold. + rng (np.random.Generator | None, optional): Deprecated; ignored (the RNG is derived from + ``(seed, epoch, rank, worker)`` so consumers are decorrelated and runs reproducible). + shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order. + rank (int | None, optional): This process' rank for distributed training. Each rank streams + a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from + Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0. + world_size (int | None, optional): Total number of distributed processes. When omitted, + resolved from Accelerate or ``WORLD_SIZE``, defaulting to 1. For an even per-rank split, + ``num_shards % world_size == 0`` should hold (warned otherwise). video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain. - When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working - set of live decoders never thrashes. See :class:`VideoDecoderCache`. + When omitted, sized to the episode pool's working set, capped at 128. data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/`` - trees (e.g. an HF storage bucket ``hf://buckets//``). When set, parquet and - video frames are read from there while metadata still loads from ``repo_id`` on the Hub. - Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket - sources without copying the (small) metadata. - video_decode_device (str, optional): Device for video decoding, passed to the torchcodec - ``VideoDecoder``. Defaults to ``"cpu"``. Set to ``"cuda"`` to offload H.264/H.265 decode to - the GPU's dedicated NVDEC engine (independent of the training SMs), which requires a - CUDA-enabled torchcodec build. Note: ``"cuda"`` decode inside ``DataLoader`` workers needs - the ``spawn`` start method (CUDA cannot init in forked workers). + trees (e.g. ``hf://buckets//``). When set, parquet and video bytes are read + from there while metadata still loads from ``repo_id`` on the Hub. + video_decode_device (str, optional): Device for torchcodec decode. ``"cuda"`` offloads to + NVDEC (needs a CUDA torchcodec build and ``spawn`` DataLoader workers). + prefetch_videos (bool, optional): When streaming from a remote source, download each pooled + episode's video files to a local cache in the background so decode-on-exit reads local + bytes instead of paying network seek latency. Defaults to True. + video_prefetch_workers (int, optional): Download threads per consumer. Defaults to 4. """ super().__init__() self.repo_id = repo_id @@ -314,11 +244,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.seed = seed - self.rng = rng if rng is not None else np.random.default_rng(seed) + if rng is not None: + logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.") self.shuffle = shuffle self.streaming = streaming - self.buffer_size = buffer_size + if buffer_size is not None: + logger.warning( + "StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; " + "use `episode_pool_size` (whole episodes, not frames)." + ) + self.episode_pool_size = max(1, episode_pool_size) self.max_num_shards = max_num_shards self._return_uint8 = return_uint8 @@ -326,13 +262,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.video_decoder_cache_size = video_decoder_cache_size self.data_files_root = data_files_root.rstrip("/") if data_files_root else None self.video_decode_device = video_decode_device + self.prefetch_videos = prefetch_videos + self.video_prefetch_workers = video_prefetch_workers # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None + self._prefetcher: _VideoPrefetcher | None = None # Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into # one place the main process can read after iteration (see video_decoder_cache_stats()). self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_() - # Resume state captured by load_state_dict() and consumed at the next __iter__. + # Deterministic fast-forward resume (see load_state_dict): per-consumer epoch counter and + # number of samples still to skip. + self._epoch = 0 + self._ff_remaining = 0 self._resume_state: dict | None = None if self._requested_root is not None: @@ -381,7 +323,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): if extra_columns: self.hf_dataset = self.hf_dataset.remove_columns(extra_columns) - self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) + self.num_shards = ( + self.hf_dataset.num_shards + if self.max_num_shards is None + else min(self.hf_dataset.num_shards, self.max_num_shards) + ) @property def num_frames(self): @@ -395,18 +341,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): def fps(self): return self.meta.fps - @staticmethod - def _iter_random_indices( - rng: np.random.Generator, buffer_size: int, random_batch_size=100 - ) -> Iterator[int]: - while True: - yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size)) - - @staticmethod - def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]: - while True: - yield rng.choice(elements) - @staticmethod def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]: """Resolve (rank, world_size) for distributed streaming. @@ -433,13 +367,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return 0, 1 - def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache: - """Size the decoder cache to the working set of live shards so it does not thrash. + def _consumer_rng(self, epoch: int, worker_id: int) -> np.random.Generator: + """RNG derived from (seed, epoch, rank, worker): reproducible, decorrelated consumers.""" + state = _mix64(self.seed) + for salt in (self.rank, worker_id, epoch if self.shuffle else 0): + state = _mix64(state ^ _mix64(salt)) + return np.random.default_rng(state) - Each shard mid-episode keeps one open decoder per camera; with several shards iterated - concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of - margin so the round-robin never evicts a still-live decoder. - """ + def _make_video_decoder_cache(self) -> VideoDecoderCache: + """Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128.""" if self.video_decoder_cache_size is not None: return VideoDecoderCache( max_size=self.video_decoder_cache_size, @@ -450,105 +386,175 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): if num_cameras == 0: return VideoDecoderCache(counters=self._cache_counters, device=self.video_decode_device) return VideoDecoderCache( - max_size=(num_active_shards + 1) * num_cameras, + max_size=min((self.episode_pool_size + 1) * num_cameras, 128), counters=self._cache_counters, device=self.video_decode_device, ) - # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading. - # The current sequential iteration is a bottleneck. A producer-consumer pattern - # could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding) - # in parallel, feeding a queue from which this iterator will yield processed items. + def _make_prefetcher(self) -> _VideoPrefetcher | None: + if not self.prefetch_videos or len(self.meta.video_keys) == 0: + return None + if self.data_files_root is not None: + remote_root = self.data_files_root + elif self.streaming and not self.streaming_from_local: + remote_root = self.meta.url_root + else: + return None # video bytes are already local + return _VideoPrefetcher( + remote_root, + cache_dir=self.root / "streaming_video_cache", + max_workers=self.video_prefetch_workers, + ) + + @staticmethod + def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]: + """Yield (episode_index, rows) for each complete episode of a shard stream.""" + rows: list[dict] = [] + current: int | None = None + for item in shard: + ep_idx = int(item["episode_index"]) + if current is None: + current = ep_idx + if ep_idx != current: + yield current, rows + rows = [] + current = ep_idx + rows.append(item) + if rows: + yield current, rows + + def _admit_episode(self, ep_idx: int, rows: list[dict], prefetcher: _VideoPrefetcher | None): + video_rel_paths = [str(self.meta.get_video_file_path(ep_idx, key)) for key in self.meta.video_keys] + if prefetcher is not None: + for rel in video_rel_paths: + prefetcher.acquire(rel) + torch_rows = [item_to_torch(row) for row in rows] + return _PooledEpisode(ep_idx, torch_rows, video_rel_paths) + def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: - # Distributed correctness: each rank streams a disjoint set of shards (order preserved). ds = self.hf_dataset if self.world_size > 1: + if ds.num_shards % self.world_size != 0: + logger.warning( + f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): " + "datasets falls back to example-level splitting where every rank reads (and pays " + "for) the full stream. Re-shard the dataset or adjust world size." + ) ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size) - num_shards = min(ds.num_shards, self.max_num_shards) + num_shards = ds.num_shards if self.max_num_shards is None else min(ds.num_shards, self.max_num_shards) shard_indices = list(range(num_shards)) # DataLoader workers within this rank further split the shards so they don't yield duplicates. worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - shard_indices = shard_indices[worker_info.id :: worker_info.num_workers] + worker_id, num_workers = (worker_info.id, worker_info.num_workers) if worker_info else (0, 1) + shard_indices = shard_indices[worker_id::num_workers] + if not shard_indices: + logger.warning( + f"Worker {worker_id} owns no shards ({num_shards} shards < {num_workers} workers): " + "it will yield nothing. Reduce num_workers or re-shard the dataset." + ) + return - self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices)) + self.video_decoder_cache = self._make_video_decoder_cache() + prefetcher = self._make_prefetcher() + self._prefetcher = prefetcher - # keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions - rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng + epoch = self._epoch + self._epoch += 1 + rng = self._consumer_rng(epoch, worker_id) + self._consume_resume_state(worker_id, num_workers) - # Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The - # shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable - # for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead. - resume = self._resume_state - self._resume_state = None - self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set() - if resume is not None: - rng.bit_generator.state = resume["rng"] + # Round-robin episode admission across this consumer's shard streams (deterministic). + streams = [self._iter_shard_episodes(safe_shard(ds, idx, num_shards)) for idx in shard_indices] + next_stream = 0 - self._shards: dict[int, datasets.IterableDataset] = {} - for idx in shard_indices: - shard = safe_shard(ds, idx, num_shards) - if resume is not None and str(idx) in resume["shards"]: - shard.load_state_dict(resume["shards"][str(idx)]) - self._shards[idx] = shard + pool: list[_PooledEpisode] = [] + total_remaining = 0 - buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size) + def admit() -> int: + nonlocal next_stream, total_remaining + admitted = 0 + while len(pool) < self.episode_pool_size and streams: + stream = streams[next_stream % len(streams)] + try: + ep_idx, rows = next(stream) + except StopIteration: + streams.remove(stream) + continue + next_stream += 1 + episode = self._admit_episode(ep_idx, rows, prefetcher) + pool.append(episode) + total_remaining += len(episode.remaining) + admitted += 1 + return admitted - idx_to_backtrack_dataset = { - idx: self._make_backtrackable_dataset(shard) - for idx, shard in self._shards.items() - if idx not in self._exhausted - } + try: + admit() + while pool: + if self.shuffle: + # Uniform draw over every remaining frame in the pool: pick the episode by + # cumulative remaining count, then a random remaining position (swap-pop). + draw = int(rng.integers(total_remaining)) + for episode in pool: + if draw < len(episode.remaining): + break + draw -= len(episode.remaining) + pick = int(rng.integers(len(episode.remaining))) + frame_pos = episode.remaining[pick] + episode.remaining[pick] = episode.remaining[-1] + episode.remaining.pop() + else: + episode = pool[0] + frame_pos = episode.remaining.pop(0) + total_remaining -= 1 - # This buffer is populated while iterating on the dataset's shards - # the logic is to add 2 levels of randomness: - # (1) sample one shard at random from the ones available, and - # (2) sample one frame from the shard sampled at (1) - frames_buffer = [] - while available_shards := list(idx_to_backtrack_dataset.keys()): - shard_key = next(self._infinite_generator_over_elements(rng, available_shards)) - backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on + if self._ff_remaining > 0: + self._ff_remaining -= 1 + else: + yield self._make_pool_sample(episode, frame_pos) - try: - for frame in self.make_frame(backtrack_dataset): - if len(frames_buffer) == self.buffer_size: - i = next(buffer_indices_generator) # samples a element from the buffer - yield frames_buffer[i] - frames_buffer[i] = frame - else: - frames_buffer.append(frame) - break # random shard sampled, switch shard - except ( - RuntimeError, - StopIteration, - ): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7 - del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard - self._exhausted.add(shard_key) - - # Once shards are all exhausted, shuffle the buffer and yield the remaining frames - rng.shuffle(frames_buffer) - yield from frames_buffer - - def state_dict(self) -> dict: - """Capture resume state: per-shard HF stream position, exhausted shards, and RNG state. - - Must be called after iteration has started (so the shard streams exist). Restore the returned - dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so - resumption is not bit-exact — see :meth:`__iter__`. - """ - if not hasattr(self, "_shards"): - raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.") - return { - "shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()}, - "exhausted": sorted(self._exhausted), - "rng": self.rng.bit_generator.state, - } + if not episode.remaining: + pool.remove(episode) + if prefetcher is not None: + for rel in episode.video_rel_paths: + prefetcher.release(rel) + admit() + finally: + if prefetcher is not None: + prefetcher.shutdown() + self._prefetcher = None def load_state_dict(self, state_dict: dict) -> None: - """Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``.""" - self._resume_state = state_dict + """Stage a deterministic fast-forward resume, applied from the next ``__iter__``. + + ``state_dict`` holds ``{"batches_consumed": int, "batch_size": int}`` — what the trainer + already knows at checkpoint time. Because every consumer's order is a pure function of + (seed, epoch, rank, worker), resume replays the stream while skipping emission (tabular + reads only, no video decode) until each worker reaches its own consumed count; the + DataLoader's round-robin batch assignment makes that count derivable per worker. Exact + within an epoch; crossing epoch boundaries may drift by < one batch per worker per epoch + when ``drop_last`` discards partial batches. + """ + self._resume_state = { + "batches_consumed": int(state_dict["batches_consumed"]), + "batch_size": int(state_dict["batch_size"]), + } + + def _consume_resume_state(self, worker_id: int, num_workers: int) -> None: + if self._resume_state is None: + return + batches = self._resume_state["batches_consumed"] + batch_size = self._resume_state["batch_size"] + self._resume_state = None + # DataLoader assigns batch j to worker j % num_workers. + my_batches = batches // num_workers + (1 if batches % num_workers > worker_id else 0) + self._ff_remaining = my_batches * batch_size + if self._ff_remaining: + logger.info( + f"Streaming resume: worker {worker_id} fast-forwarding {self._ff_remaining} samples " + "(tabular reads only, no video decode)." + ) def video_decoder_cache_stats(self) -> dict[str, int | float]: """Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor. @@ -566,43 +572,75 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): "hit_rate": round(hits / total, 4) if total else 0.0, } - def _get_window_steps( - self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False - ) -> tuple[int, int]: - if delta_timestamps is None: - return 1, 1 + def _make_pool_sample(self, episode: _PooledEpisode, frame_pos: int) -> dict: + """Assemble a full training sample for one pooled frame (tabular slices + video decode).""" + rows = episode.rows + item = dict(rows[frame_pos]) + ep_idx = episode.episode_index + num_rows = len(rows) + current_ts = float(item["timestamp"]) - if not dynamic_bounds: - # Fix the windows - lookback = LOOKBACK_BACKTRACKTABLE - lookahead = LOOKAHEAD_BACKTRACKTABLE - else: - # Dynamically size the windows to exactly cover the requested delta_timestamps (in frames). - # This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for - # long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20). - all_timestamps = sum(delta_timestamps.values(), []) - lookback = math.floor(min(all_timestamps) * self.fps) - lookahead = math.ceil(max(all_timestamps) * self.fps) + updates: list[dict] = [] + if self.delta_indices is not None: + updates.extend(self._pool_delta_frames(rows, frame_pos, num_rows)) - # When lookback is >=0 it means no negative timesteps have been provided - lookback = 0 if lookback >= 0 else -lookback + if len(self.meta.video_keys) > 0: + # Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of + # decoding against a neighbouring episode sharing the same video file. + episode_boundaries_ts = { + key: ( + 0.0, + self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"] + - self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"], + ) + for key in self.meta.video_keys + } + original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) + query_timestamps = self._get_query_timestamps( + current_ts, self.delta_indices, episode_boundaries_ts + ) + video_frames = self._query_videos(query_timestamps, ep_idx) - return lookback, lookahead + if self.image_transforms is not None: + for cam in self.meta.camera_keys: + video_frames[cam] = self.image_transforms(video_frames[cam]) - def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable: - lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True) - # Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires - # history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1. - return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead)) + updates.append(video_frames) + if self.delta_indices is not None: + updates.append( + self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps) + ) + + result = item + for update in updates: + result.update(update) + result["task"] = self.meta.tasks.iloc[item["task_index"]].name + return result + + def _pool_delta_frames(self, rows: list[dict], frame_pos: int, num_rows: int) -> list[dict]: + """Exact delta windows by slicing the resident episode; clamped + padded at boundaries.""" + query_result: dict = {} + padding: dict = {} + for key, deltas in self.delta_indices.items(): + if key in self.meta.video_keys: + continue # visual frames are decoded separately + frames = [] + is_pad = [] + for delta in deltas: + j = frame_pos + delta + valid = 0 <= j < num_rows + frames.append(rows[min(max(j, 0), num_rows - 1)][key]) + is_pad.append(not valid) + query_result[key] = torch.stack(frames) + padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad) + return [query_result, padding] def _make_timestamps_from_indices( self, start_ts: float, indices: dict[str, list[int]] | None = None ) -> dict[str, list[float]]: if indices is not None: return { - key: ( - start_ts + torch.tensor(indices[key]) / self.fps - ).tolist() # NOTE: why not delta_timestamps directly? + key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist() for key in self.delta_timestamps } else: @@ -639,72 +677,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return padding_mask - def make_frame(self, dataset_iterator: Backtrackable) -> Generator: - """Makes a frame starting from a dataset iterator""" - item = next(dataset_iterator) - item = item_to_torch(item) - - updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features) - - # Get episode index from the item - ep_idx = item["episode_index"] - - # `timestamp` is episode-local (restarts at 0 each episode). The absolute in-file timestamp is - # `from_timestamp + timestamp`, applied per camera at decode time (see `_query_videos`), mirroring - # the map-style reader. Using `index / fps` here is a dataset-global value that only matches the - # file timeline when the whole dataset is a single video (e.g. small test fixtures), and otherwise - # decodes out-of-range frames on multi-file v3 datasets. - current_ts = float(item["timestamp"]) - - # Per-camera episode-local bounds [0, duration]. Query timestamps are clamped into this range so - # out-of-episode deltas pad rather than decode against a neighbouring episode in the same file. - episode_boundaries_ts = { - key: ( - 0.0, - self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"] - - self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"], - ) - for key in self.meta.video_keys - } - - # Apply delta querying logic if necessary - if self.delta_indices is not None: - query_result, padding = self._get_delta_frames(dataset_iterator, item) - updates.append(query_result) - updates.append(padding) - - # Load video frames, when needed - if len(self.meta.video_keys) > 0: - original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) - - # Some timestamps might not result available considering the episode's boundaries - query_timestamps = self._get_query_timestamps( - current_ts, self.delta_indices, episode_boundaries_ts - ) - video_frames = self._query_videos(query_timestamps, ep_idx) - - if self.image_transforms is not None: - image_keys = self.meta.camera_keys - for cam in image_keys: - video_frames[cam] = self.image_transforms(video_frames[cam]) - - updates.append(video_frames) - - if self.delta_indices is not None: - # We always return the same number of frames. Unavailable frames are padded. - padding_mask = self._get_video_frame_padding_mask( - video_frames, query_timestamps, original_timestamps - ) - updates.append(padding_mask) - - result = item.copy() - for update in updates: - result.update(update) - - result["task"] = self.meta.tasks.iloc[item["task_index"]].name - - yield result - def _get_query_timestamps( self, current_ts: float, @@ -738,13 +710,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset. from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] - if self.data_files_root is not None: - root = self.data_files_root - elif self.streaming and not self.streaming_from_local: - root = self.meta.url_root + rel_path = str(self.meta.get_video_file_path(ep_idx, video_key)) + local = self._prefetcher.wait_local(rel_path) if self._prefetcher is not None else None + if local is not None: + video_path = str(local) else: - root = self.root - video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" + if self.data_files_root is not None: + root = self.data_files_root + elif self.streaming and not self.streaming_from_local: + root = self.meta.url_root + else: + root = self.root + video_path = f"{root}/{rel_path}" frames = decode_video_frames_torchcodec( video_path, shifted_query_ts, @@ -757,116 +734,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return item - def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict): - # TODO(fracapuano): Modularize this function, refactor the code - """Get frames with delta offsets using the backtrackable iterator. - - Args: - current_item (dict): Current item from the iterator. - ep_idx (int): Episode index. - - Returns: - tuple: (query_result, padding) - frames at delta offsets and padding info. - """ - current_episode_idx = current_item["episode_index"] - - # Prepare results - query_result = {} - padding = {} - - for key, delta_indices in self.delta_indices.items(): - if key in self.meta.video_keys: - continue # visual frames are decoded separately - - target_frames = [] - is_pad = [] - - # Create a results dictionary to store frames in processing order, then reconstruct original order for stacking - delta_results = {} - - # Separate and sort deltas by difficulty (easier operations first) - negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...] - positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...] - zero_deltas = [d for d in delta_indices if d == 0] - - # Process zero deltas (current frame) - for delta in zero_deltas: - delta_results[delta] = ( - current_item[key], - False, - ) - - # Process negative deltas in order of increasing difficulty - lookback_failed = False - - last_successful_frame = current_item[key] - - for delta in negative_deltas: - if lookback_failed: - delta_results[delta] = (last_successful_frame, True) - continue - - try: - steps_back = abs(delta) - if dataset_iterator.can_peek_back(steps_back): - past_item = dataset_iterator.peek_back(steps_back) - past_item = item_to_torch(past_item) - - if past_item["episode_index"] == current_episode_idx: - delta_results[delta] = (past_item[key], False) - last_successful_frame = past_item[key] - - else: - raise LookBackError("Retrieved frame is from different episode!") - else: - raise LookBackError("Cannot go back further than the history buffer!") - - except LookBackError: - delta_results[delta] = (last_successful_frame, True) - lookback_failed = True # All subsequent negative deltas will also fail - - # Process positive deltas in order of increasing difficulty - lookahead_failed = False - last_successful_frame = current_item[key] - - for delta in positive_deltas: - if lookahead_failed: - delta_results[delta] = (last_successful_frame, True) - continue - - try: - if dataset_iterator.can_peek_ahead(delta): - future_item = dataset_iterator.peek_ahead(delta) - future_item = item_to_torch(future_item) - - if future_item["episode_index"] == current_episode_idx: - delta_results[delta] = (future_item[key], False) - last_successful_frame = future_item[key] - - else: - raise LookAheadError("Retrieved frame is from different episode!") - else: - raise LookAheadError("Cannot go ahead further than the lookahead buffer!") - - except LookAheadError: - delta_results[delta] = (last_successful_frame, True) - lookahead_failed = True # All subsequent positive deltas will also fail - - # Reconstruct original order for stacking - for delta in delta_indices: - frame, is_padded = delta_results[delta] - - # add batch dimension for stacking - target_frames.append(frame) # frame.unsqueeze(0)) - is_pad.append(is_padded) - - # Stack frames and add to results - if target_frames: - query_result[key] = torch.stack(target_frames) - padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad) - - return query_result, padding - def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None: """ Validate that all keys in delta_timestamps correspond to actual features in the dataset. diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index db167f657..1d39c5a9a 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import pytest import torch @@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION from tests.fixtures.constants import DUMMY_REPO_ID -def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]: - """Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices.""" - rng = np.random.default_rng(streaming_ds.seed) - buffer_size = streaming_ds.buffer_size - num_shards = streaming_ds.num_shards - - shards_indices = [] - for shard_idx in range(num_shards): - shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx) - shard_indices = [item["index"] for item in shard] - shards_indices.append(shard_indices) - - shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)} - - buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size) - - frames_buffer = [] - expected_indices = [] - - while shard_iterators: # While there are still available shards - available_shard_keys = list(shard_iterators.keys()) - if not available_shard_keys: - break - - # Call _infinite_generator_over_elements with current available shards (key difference!) - shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys)) - - try: - frame_index = next(shard_iterators[shard_key]) - - if len(frames_buffer) == buffer_size: - i = next(buffer_indices_generator) - expected_indices.append(frames_buffer[i]) - frames_buffer[i] = frame_index - else: - frames_buffer.append(frame_index) - - except StopIteration: - del shard_iterators[shard_key] # Remove exhausted shard - - rng.shuffle(frames_buffer) - expected_indices.extend(frames_buffer) - - return expected_indices - - def test_single_frame_consistency(tmp_path, lerobot_dataset_factory): """Test if are correctly accessed""" ds_num_frames = 400 @@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory): [False, True], ) def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle): - """Test if streamed frames correspond to shuffling operations over in-memory dataset.""" + """Each epoch covers every frame exactly once; shuffle reshuffles across epochs.""" ds_num_frames = 400 ds_num_episodes = 10 - buffer_size = 100 seed = 42 n_epochs = 3 @@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle): ) streaming_ds = StreamingLeRobotDataset( - repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle + repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle ) - first_epoch_indices = [frame["index"] for frame in streaming_ds] - expected_indices = get_frames_expected_order(streaming_ds) - - assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices" - - expected_indices = get_frames_expected_order(streaming_ds) - for _ in range(n_epochs): - streaming_indices = [frame["index"] for frame in streaming_ds] - frames_match = all( - s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True) - ) - - if shuffle: - assert not frames_match - else: - assert frames_match + epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)] + for epoch_indices in epochs: + assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once" + if shuffle: + assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs" + assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order" + else: + assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order" @pytest.mark.parametrize( @@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle): [False, True], ) def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle): - """Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards.""" + """Multi-shard streams keep exactly-once coverage and deterministic per-seed order.""" ds_num_frames = 100 ds_num_episodes = 10 - buffer_size = 10 - seed = 42 - n_epochs = 3 data_file_size_mb = 0.001 - chunks_size = 1 local_path = tmp_path / "test" @@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle): chunks_size=chunks_size, ) - streaming_ds = StreamingLeRobotDataset( - repo_id=repo_id, - root=local_path, - buffer_size=buffer_size, - seed=seed, - shuffle=shuffle, - max_num_shards=4, - ) - - first_epoch_indices = [frame["index"] for frame in streaming_ds] - expected_indices = get_frames_expected_order(streaming_ds) - - assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices" - - for _ in range(n_epochs): - streaming_indices = [ - frame["index"] for frame in streaming_ds - ] # NOTE: this is the same as first_epoch_indices - frames_match = all( - s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True) + def make_ds(): + return StreamingLeRobotDataset( + repo_id=repo_id, + root=local_path, + episode_pool_size=3, + seed=seed, + shuffle=shuffle, + max_num_shards=4, ) - if shuffle: - assert not frames_match - else: - assert frames_match + + first = [int(frame["index"]) for frame in make_ds()] + again = [int(frame["index"]) for frame in make_ds()] + + assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once" + assert first == again, "same seed must reproduce the same order" @pytest.mark.parametrize( diff --git a/tests/datasets/test_streaming_distributed.py b/tests/datasets/test_streaming_distributed.py index 10ffc5dca..de0093af6 100644 --- a/tests/datasets/test_streaming_distributed.py +++ b/tests/datasets/test_streaming_distributed.py @@ -40,7 +40,7 @@ from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3] state = PartialState() ds = StreamingLeRobotDataset( - repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8 + repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8 ) indices = [int(frame["index"]) for frame in ds] payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices} diff --git a/tests/datasets/test_streaming_native.py b/tests/datasets/test_streaming_native.py index fc3d35153..1577cae38 100644 --- a/tests/datasets/test_streaming_native.py +++ b/tests/datasets/test_streaming_native.py @@ -13,7 +13,8 @@ # limitations under the License. """Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding, -DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity.""" +DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video +prefetching, deterministic fast-forward resume, and schema parity.""" import pytest import torch @@ -75,7 +76,7 @@ def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory): repo_id=repo_id, root=tmp_path / "ds", shuffle=False, - buffer_size=8, + episode_pool_size=8, max_num_shards=8, rank=rank, world_size=world_size, @@ -101,7 +102,7 @@ def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_ ) ds = StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4 + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4 ) loader = DataLoader(ds, batch_size=None, num_workers=2) indices = [int(batch["index"]) for batch in loader] @@ -128,7 +129,7 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas repo_id=repo_id, root=tmp_path / "ds", shuffle=False, - buffer_size=1, + episode_pool_size=1, max_num_shards=1, delta_timestamps=delta_timestamps, ) @@ -147,8 +148,8 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas assert checked > 0, "test did not exercise any in-episode long-horizon frame" -def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory): - """state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start.""" +def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory): + """Resume replays the deterministic stream and continues at the exact sample.""" repo_id = f"{DUMMY_REPO_ID}-resume" total_frames = 100 _make_local_dataset( @@ -157,27 +158,93 @@ def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_f def fresh_ds(): return StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=True, + seed=7, + episode_pool_size=3, + max_num_shards=1, ) - ds = fresh_ds() - it = iter(ds) - stop_after = 40 - seen_before = [int(next(it)["index"]) for _ in range(stop_after)] - state = ds.state_dict() - assert set(state) == {"shards", "exhausted", "rng"} + full_epoch = _stream_indices(fresh_ds()) + assert sorted(full_epoch) == list(range(total_frames)) + batches_consumed, batch_size = 5, 4 # 20 samples in resumed_ds = fresh_ds() - resumed_ds.load_state_dict(state) + resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size}) resumed = _stream_indices(resumed_ds) - # Resume continues rather than replaying: the full first pass is not re-yielded. - assert len(resumed) < total_frames - overlap = set(seen_before) & set(resumed) - assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}" - # Together the two passes cover essentially the whole dataset (a few frames may be dropped by the - # ahead-read at the resume boundary -- documented non-bit-exact behaviour). - assert len(set(seen_before) | set(resumed)) >= total_frames - 2 + assert resumed == full_epoch[batches_consumed * batch_size :], ( + "fast-forward resume did not continue at the exact sample" + ) + + +def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory): + repo_id = f"{DUMMY_REPO_ID}-seeds" + _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120) + + def order(seed): + return _stream_indices( + StreamingLeRobotDataset( + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=True, + seed=seed, + episode_pool_size=4, + max_num_shards=2, + ) + ) + + assert order(0) == order(0), "same seed must reproduce the same order" + assert order(0) != order(1), "different seeds should give different orders" + + +def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory): + """Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage.""" + repo_id = f"{DUMMY_REPO_ID}-epochs" + total_frames = 120 + _make_local_dataset( + lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames + ) + ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2 + ) + epoch_0 = _stream_indices(ds) + epoch_1 = _stream_indices(ds) + assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames)) + assert epoch_0 != epoch_1, "epoch did not reshuffle" + + +def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory): + """Early samples should already come from several distinct episodes (the pool's purpose).""" + repo_id = f"{DUMMY_REPO_ID}-mix" + _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200) + ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4 + ) + episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)} + assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}" + + +def test_video_prefetcher_refcounted_lifecycle(tmp_path): + from lerobot.datasets.streaming_dataset import _VideoPrefetcher + + remote = tmp_path / "remote" + (remote / "videos").mkdir(parents=True) + payload = b"x" * 1024 + (remote / "videos" / "a.mp4").write_bytes(payload) + + prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1) + prefetcher.acquire("videos/a.mp4") + prefetcher.acquire("videos/a.mp4") # second pooled episode sharing the file + local = prefetcher.wait_local("videos/a.mp4") + assert local is not None and local.read_bytes() == payload + + prefetcher.release("videos/a.mp4") + assert local.exists(), "file deleted while still referenced" + prefetcher.release("videos/a.mp4") + assert not local.exists(), "file not deleted at refcount zero" + prefetcher.shutdown() def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory): @@ -187,7 +254,7 @@ def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory): root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True ) stream_ds = StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2 + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2 ) map_frame = map_ds[0] @@ -217,7 +284,7 @@ def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypa root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True ) ds = StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1 ) seen_paths = [] @@ -239,12 +306,12 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory): _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200) ordered = _stream_indices( StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1 ) ) shuffled = _stream_indices( StreamingLeRobotDataset( - repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0 + repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0 ) ) assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"