From d1fc8e298c8aacafa26cdcb2d121221da945b824 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 9 Jun 2026 13:37:30 +0200 Subject: [PATCH] feat(streaming): distributed + resumable HF-native StreamingLeRobotDataset Add the large-scale streaming pieces that were missing from the frame-streaming internals, keeping the existing Backtrackable + output-reservoir frame-shuffle: - split_dataset_by_node(rank, world_size) before the per-shard loop so each rank streams a disjoint set of shards (fixes duplicate data across GPUs). rank and world_size auto-resolve from Accelerate state / RANK,WORLD_SIZE env / (0, 1). - get_worker_info() shard splitting so DataLoader workers within a rank don't yield duplicate frames. - Dynamic Backtrackable window (dynamic_bounds=True) sized to the requested delta_timestamps, removing the fixed 100-frame ceiling so long horizons (e.g. a SARM window ~160 frames) reach real frames instead of silently padding. Fix the peek_back off-by-one: history = lookback + 1. - video_decoder_cache_size knob; default (active_shards + 1) x num_cameras so the live decoder working set does not thrash the VideoDecoderCache LRU. - state_dict()/load_state_dict() for resume (per-shard HF stream state + exhausted set + RNG). Reservoir is re-warmed, so resumption is not bit-exact (documented). - factory.py wires buffer_size from a new DatasetConfig.streaming_buffer_size field instead of repurposing max_num_shards as the worker count. Tests: tests/datasets/test_streaming_native.py covers distributed disjointness, worker de-duplication, the SARM-length window, resume, schema parity vs map-style, local video path resolution, and shuffle decorrelation. 21 passed (13 existing + 8). Co-Authored-By: Claude Opus 4.8 (1M context) --- src/lerobot/configs/default.py | 3 + src/lerobot/datasets/factory.py | 2 +- src/lerobot/datasets/streaming_dataset.py | 137 +++++++++++- tests/datasets/test_streaming_native.py | 246 ++++++++++++++++++++++ 4 files changed, 377 insertions(+), 11 deletions(-) create mode 100644 tests/datasets/test_streaming_native.py diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index b809e71d9..9de5e6c0e 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -39,6 +39,9 @@ class DatasetConfig: # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. return_uint8: bool = False streaming: bool = False + # Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost + # of host RAM. Ignored when streaming is False. + streaming_buffer_size: int = 1000 def __post_init__(self) -> None: if self.episodes is not None: diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index cbbe83dc8..47fe560e1 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas delta_timestamps=delta_timestamps, image_transforms=image_transforms, revision=cfg.dataset.revision, - max_num_shards=cfg.num_workers, + buffer_size=cfg.dataset.streaming_buffer_size, tolerance_s=cfg.tolerance_s, return_uint8=True, ) diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 3c1e4a73c..ff8097330 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -13,6 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import math +import os from collections import deque from collections.abc import Callable, Generator, Iterable, Iterator from pathlib import Path @@ -21,6 +24,7 @@ import datasets import numpy as np import torch from datasets import load_dataset +from datasets.distributed import split_dataset_by_node from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE @@ -38,6 +42,8 @@ from .video_utils import ( decode_video_frames_torchcodec, ) +logger = logging.getLogger(__name__) + class LookBackError(Exception): """ @@ -252,6 +258,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): rng: np.random.Generator | None = None, shuffle: bool = True, return_uint8: bool = False, + rank: int | None = None, + world_size: int | None = None, + video_decoder_cache_size: int | None = None, ): """Initialize a StreamingLeRobotDataset. @@ -272,6 +281,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): seed (int, optional): Reproducibility random seed. rng (np.random.Generator | None, optional): Random number generator. shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True. + rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training. + Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is + resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0. + world_size (int | None, optional): Total number of distributed processes. When omitted, resolved + from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding). + For an even per-rank split, ``num_shards % world_size == 0`` should hold. + video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain. + When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working + set of live decoders never thrashes. See :class:`VideoDecoderCache`. """ super().__init__() self.repo_id = repo_id @@ -289,10 +307,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.streaming = streaming self.buffer_size = buffer_size + self.max_num_shards = max_num_shards self._return_uint8 = return_uint8 + self.rank, self.world_size = self._resolve_distributed(rank, world_size) + self.video_decoder_cache_size = video_decoder_cache_size + # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None + # Resume state captured by load_state_dict() and consumed at the next __iter__. + self._resume_state: dict | None = None if self._requested_root is not None: self.root.mkdir(exist_ok=True, parents=True) @@ -348,22 +372,91 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): while True: yield rng.choice(elements) + @staticmethod + def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]: + """Resolve (rank, world_size) for distributed streaming. + + Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the + ``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1). + """ + if rank is not None and world_size is not None: + return rank, world_size + + try: + from accelerate.state import PartialState + + if PartialState._shared_state: # only read it if already initialized; never initialize here + state = PartialState() + return state.process_index, state.num_processes + except Exception: + logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.") + + env_rank = os.environ.get("RANK") + env_world = os.environ.get("WORLD_SIZE") + if env_rank is not None and env_world is not None: + return int(env_rank), int(env_world) + + return 0, 1 + + def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache: + """Size the decoder cache to the working set of live shards so it does not thrash. + + Each shard mid-episode keeps one open decoder per camera; with several shards iterated + concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of + margin so the round-robin never evicts a still-live decoder. + """ + if self.video_decoder_cache_size is not None: + return VideoDecoderCache(max_size=self.video_decoder_cache_size) + num_cameras = len(self.meta.video_keys) + if num_cameras == 0: + return VideoDecoderCache() + return VideoDecoderCache(max_size=(num_active_shards + 1) * num_cameras) + # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading. # The current sequential iteration is a bottleneck. A producer-consumer pattern # could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding) # in parallel, feeding a queue from which this iterator will yield processed items. def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: - if self.video_decoder_cache is None: - self.video_decoder_cache = VideoDecoderCache() + # Distributed correctness: each rank streams a disjoint set of shards (order preserved). + ds = self.hf_dataset + if self.world_size > 1: + ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size) + + num_shards = min(ds.num_shards, self.max_num_shards) + shard_indices = list(range(num_shards)) + + # DataLoader workers within this rank further split the shards so they don't yield duplicates. + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + shard_indices = shard_indices[worker_info.id :: worker_info.num_workers] + + self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices)) # keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng + # Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The + # shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable + # for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead. + resume = self._resume_state + self._resume_state = None + self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set() + if resume is not None: + rng.bit_generator.state = resume["rng"] + + self._shards: dict[int, datasets.IterableDataset] = {} + for idx in shard_indices: + shard = safe_shard(ds, idx, num_shards) + if resume is not None and str(idx) in resume["shards"]: + shard.load_state_dict(resume["shards"][str(idx)]) + self._shards[idx] = shard + buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size) idx_to_backtrack_dataset = { - idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards)) - for idx in range(self.num_shards) + idx: self._make_backtrackable_dataset(shard) + for idx, shard in self._shards.items() + if idx not in self._exhausted } # This buffer is populated while iterating on the dataset's shards @@ -389,11 +482,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): StopIteration, ): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7 del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard + self._exhausted.add(shard_key) # Once shards are all exhausted, shuffle the buffer and yield the remaining frames rng.shuffle(frames_buffer) yield from frames_buffer + def state_dict(self) -> dict: + """Capture resume state: per-shard HF stream position, exhausted shards, and RNG state. + + Must be called after iteration has started (so the shard streams exist). Restore the returned + dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so + resumption is not bit-exact — see :meth:`__iter__`. + """ + if not hasattr(self, "_shards"): + raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.") + return { + "shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()}, + "exhausted": sorted(self._exhausted), + "rng": self.rng.bit_generator.state, + } + + def load_state_dict(self, state_dict: dict) -> None: + """Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``.""" + self._resume_state = state_dict + def _get_window_steps( self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False ) -> tuple[int, int]: @@ -405,19 +518,23 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): lookback = LOOKBACK_BACKTRACKTABLE lookahead = LOOKAHEAD_BACKTRACKTABLE else: - # Dynamically adjust the windows based on the given delta_timesteps + # Dynamically size the windows to exactly cover the requested delta_timestamps (in frames). + # This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for + # long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20). all_timestamps = sum(delta_timestamps.values(), []) - lookback = min(all_timestamps) * self.fps - lookahead = max(all_timestamps) * self.fps + lookback = math.floor(min(all_timestamps) * self.fps) + lookahead = math.ceil(max(all_timestamps) * self.fps) # When lookback is >=0 it means no negative timesteps have been provided - lookback = 0 if lookback >= 0 else (lookback * -1) + lookback = 0 if lookback >= 0 else -lookback return lookback, lookahead def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable: - lookback, lookahead = self._get_window_steps(self.delta_timestamps) - return Backtrackable(dataset, history=lookback, lookahead=lookahead) + lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True) + # Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires + # history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1. + return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead)) def _make_timestamps_from_indices( self, start_ts: float, indices: dict[str, list[int]] | None = None diff --git a/tests/datasets/test_streaming_native.py b/tests/datasets/test_streaming_native.py new file mode 100644 index 000000000..cd0317317 --- /dev/null +++ b/tests/datasets/test_streaming_native.py @@ -0,0 +1,246 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding, +DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity.""" + +import pytest +import torch +from torch.utils.data import DataLoader + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.utils.constants import ACTION +from tests.fixtures.constants import DUMMY_REPO_ID + + +def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw): + factory( + root=root, + repo_id=repo_id, + total_episodes=total_episodes, + total_frames=total_frames, + use_videos=use_videos, + data_files_size_in_mb=0.001, + chunks_size=1, + **kw, + ) + + +def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]: + return [int(frame["index"]) for frame in ds] + + +def test_resolve_distributed_prefers_explicit_then_env(monkeypatch): + assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8) + + monkeypatch.delenv("RANK", raising=False) + monkeypatch.delenv("WORLD_SIZE", raising=False) + # No accelerate state, no env -> single process. + assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1) + + monkeypatch.setenv("RANK", "3") + monkeypatch.setenv("WORLD_SIZE", "4") + assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4) + + +def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory): + """Each rank must stream a disjoint set of frames, and the ranks together must cover every frame.""" + repo_id = f"{DUMMY_REPO_ID}-ranks" + total_frames, total_episodes = 200, 8 + _make_local_dataset( + lerobot_dataset_factory, + tmp_path / "ds", + repo_id, + total_episodes=total_episodes, + total_frames=total_frames, + ) + + world_size = 2 + per_rank = [] + for rank in range(world_size): + ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=False, + buffer_size=8, + max_num_shards=8, + rank=rank, + world_size=world_size, + ) + per_rank.append(set(_stream_indices(ds))) + + assert per_rank[0].isdisjoint(per_rank[1]), ( + "ranks streamed overlapping frames (duplicate data across GPUs)" + ) + assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames" + + +def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory): + """DataLoader workers within a rank must split shards so no frame is yielded twice.""" + repo_id = f"{DUMMY_REPO_ID}-workers" + total_frames, total_episodes = 120, 8 + _make_local_dataset( + lerobot_dataset_factory, + tmp_path / "ds", + repo_id, + total_episodes=total_episodes, + total_frames=total_frames, + ) + + ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4 + ) + loader = DataLoader(ds, batch_size=None, num_workers=2) + indices = [int(batch["index"]) for batch in loader] + + assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank" + + +def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory): + """A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them. + + SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100. + """ + repo_id = f"{DUMMY_REPO_ID}-sarm" + # Two episodes of 200 frames each -> a +150-frame lookahead stays inside an episode for early frames. + _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=2, total_frames=400) + + horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100 + delta_timestamps = {ACTION: [0.0, horizon_s]} + ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=False, + buffer_size=1, + max_num_shards=1, + delta_timestamps=delta_timestamps, + ) + + horizon_frames = int(round(horizon_s * ds.fps)) + checked = 0 + for frame in ds: + idx = int(frame["index"]) + # Only assert on frames whose +horizon target is still inside the same episode. + if int(frame["episode_index"]) == 0 and idx + horizon_frames < 200: + assert not bool(frame[f"{ACTION}_is_pad"][-1]), ( + f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it" + ) + checked += 1 + assert checked > 0, "test did not exercise any in-episode long-horizon frame" + + +def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory): + """state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start.""" + repo_id = f"{DUMMY_REPO_ID}-resume" + total_frames = 100 + _make_local_dataset( + lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames + ) + + def fresh_ds(): + return StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + ) + + ds = fresh_ds() + it = iter(ds) + stop_after = 40 + seen_before = [int(next(it)["index"]) for _ in range(stop_after)] + state = ds.state_dict() + assert set(state) == {"shards", "exhausted", "rng"} + + resumed_ds = fresh_ds() + resumed_ds.load_state_dict(state) + resumed = _stream_indices(resumed_ds) + + # Resume continues rather than replaying: the full first pass is not re-yielded. + assert len(resumed) < total_frames + overlap = set(seen_before) & set(resumed) + assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}" + # Together the two passes cover essentially the whole dataset (a few frames may be dropped by the + # ahead-read at the resume boundary -- documented non-bit-exact behaviour). + assert len(set(seen_before) | set(resumed)) >= total_frames - 2 + + +def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory): + """Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset.""" + repo_id = f"{DUMMY_REPO_ID}-parity" + map_ds = lerobot_dataset_factory( + root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True + ) + stream_ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2 + ) + + map_frame = map_ds[0] + stream_frame = next(iter(stream_ds)) + + assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame) + for key, value in stream_frame.items(): + ref = map_frame[key] + if isinstance(value, torch.Tensor): + assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, ( + f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}" + ) + elif isinstance(value, str): + assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}" + else: + # Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors + # (a long-standing, accepted difference). Compare by value rather than exact type. + assert float(value) == float(ref), f"{key}: {value} vs {ref}" + + +def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch): + """For a local (prewarmed) root, video decode must be issued against the local path, not hf://.""" + import lerobot.datasets.streaming_dataset as sd + + repo_id = f"{DUMMY_REPO_ID}-vpath" + lerobot_dataset_factory( + root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True + ) + ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + ) + + seen_paths = [] + + def fake_decode(video_path, query_ts, *args, **kwargs): + seen_paths.append(str(video_path)) + return torch.zeros(len(query_ts), 3, 64, 96) + + monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode) + next(iter(ds)) + + assert seen_paths, "no video decode was issued" + assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths + + +def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory): + """With shuffle on, streamed frame order must differ from the underlying sequential order.""" + repo_id = f"{DUMMY_REPO_ID}-shuf" + _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200) + ordered = _stream_indices( + StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1 + ) + ) + shuffled = _stream_indices( + StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0 + ) + ) + assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames" + assert shuffled != ordered, "shuffle did not decorrelate output order"