feat(streaming): distributed + resumable HF-native StreamingLeRobotDataset

Add the large-scale streaming pieces that were missing from the frame-streaming
internals, keeping the existing Backtrackable + output-reservoir frame-shuffle:

- split_dataset_by_node(rank, world_size) before the per-shard loop so each rank
  streams a disjoint set of shards (fixes duplicate data across GPUs). rank and
  world_size auto-resolve from Accelerate state / RANK,WORLD_SIZE env / (0, 1).
- get_worker_info() shard splitting so DataLoader workers within a rank don't
  yield duplicate frames.
- Dynamic Backtrackable window (dynamic_bounds=True) sized to the requested
  delta_timestamps, removing the fixed 100-frame ceiling so long horizons (e.g. a
  SARM window ~160 frames) reach real frames instead of silently padding. Fix the
  peek_back off-by-one: history = lookback + 1.
- video_decoder_cache_size knob; default (active_shards + 1) x num_cameras so the
  live decoder working set does not thrash the VideoDecoderCache LRU.
- state_dict()/load_state_dict() for resume (per-shard HF stream state + exhausted
  set + RNG). Reservoir is re-warmed, so resumption is not bit-exact (documented).
- factory.py wires buffer_size from a new DatasetConfig.streaming_buffer_size field
  instead of repurposing max_num_shards as the worker count.

Tests: tests/datasets/test_streaming_native.py covers distributed disjointness,
worker de-duplication, the SARM-length window, resume, schema parity vs map-style,
local video path resolution, and shuffle decorrelation. 21 passed (13 existing + 8).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-09 13:37:30 +02:00
parent 49755a3d9e
commit d1fc8e298c
4 changed files with 377 additions and 11 deletions
+3
View File
@@ -39,6 +39,9 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
# of host RAM. Ignored when streaming is False.
streaming_buffer_size: int = 1000
def __post_init__(self) -> None:
if self.episodes is not None:
+1 -1
View File
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
buffer_size=cfg.dataset.streaming_buffer_size,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
+127 -10
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from collections import deque
from collections.abc import Callable, Generator, Iterable, Iterator
from pathlib import Path
@@ -21,6 +24,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
@@ -38,6 +42,8 @@ from .video_utils import (
decode_video_frames_torchcodec,
)
logger = logging.getLogger(__name__)
class LookBackError(Exception):
"""
@@ -252,6 +258,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
rng: np.random.Generator | None = None,
shuffle: bool = True,
return_uint8: bool = False,
rank: int | None = None,
world_size: int | None = None,
video_decoder_cache_size: int | None = None,
):
"""Initialize a StreamingLeRobotDataset.
@@ -272,6 +281,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
seed (int, optional): Reproducibility random seed.
rng (np.random.Generator | None, optional): Random number generator.
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training.
Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is
resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
world_size (int | None, optional): Total number of distributed processes. When omitted, resolved
from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding).
For an even per-rank split, ``num_shards % world_size == 0`` should hold.
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
"""
super().__init__()
self.repo_id = repo_id
@@ -289,10 +307,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming
self.buffer_size = buffer_size
self.max_num_shards = max_num_shards
self._return_uint8 = return_uint8
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
self.video_decoder_cache_size = video_decoder_cache_size
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
# Resume state captured by load_state_dict() and consumed at the next __iter__.
self._resume_state: dict | None = None
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
@@ -348,22 +372,91 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
while True:
yield rng.choice(elements)
@staticmethod
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
"""Resolve (rank, world_size) for distributed streaming.
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
"""
if rank is not None and world_size is not None:
return rank, world_size
try:
from accelerate.state import PartialState
if PartialState._shared_state: # only read it if already initialized; never initialize here
state = PartialState()
return state.process_index, state.num_processes
except Exception:
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
env_rank = os.environ.get("RANK")
env_world = os.environ.get("WORLD_SIZE")
if env_rank is not None and env_world is not None:
return int(env_rank), int(env_world)
return 0, 1
def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache:
"""Size the decoder cache to the working set of live shards so it does not thrash.
Each shard mid-episode keeps one open decoder per camera; with several shards iterated
concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of
margin so the round-robin never evicts a still-live decoder.
"""
if self.video_decoder_cache_size is not None:
return VideoDecoderCache(max_size=self.video_decoder_cache_size)
num_cameras = len(self.meta.video_keys)
if num_cameras == 0:
return VideoDecoderCache()
return VideoDecoderCache(max_size=(num_active_shards + 1) * num_cameras)
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
# The current sequential iteration is a bottleneck. A producer-consumer pattern
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
# in parallel, feeding a queue from which this iterator will yield processed items.
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
if self.video_decoder_cache is None:
self.video_decoder_cache = VideoDecoderCache()
# Distributed correctness: each rank streams a disjoint set of shards (order preserved).
ds = self.hf_dataset
if self.world_size > 1:
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
num_shards = min(ds.num_shards, self.max_num_shards)
shard_indices = list(range(num_shards))
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
shard_indices = shard_indices[worker_info.id :: worker_info.num_workers]
self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices))
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
# Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The
# shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable
# for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead.
resume = self._resume_state
self._resume_state = None
self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set()
if resume is not None:
rng.bit_generator.state = resume["rng"]
self._shards: dict[int, datasets.IterableDataset] = {}
for idx in shard_indices:
shard = safe_shard(ds, idx, num_shards)
if resume is not None and str(idx) in resume["shards"]:
shard.load_state_dict(resume["shards"][str(idx)])
self._shards[idx] = shard
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
idx_to_backtrack_dataset = {
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
for idx in range(self.num_shards)
idx: self._make_backtrackable_dataset(shard)
for idx, shard in self._shards.items()
if idx not in self._exhausted
}
# This buffer is populated while iterating on the dataset's shards
@@ -389,11 +482,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
StopIteration,
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
self._exhausted.add(shard_key)
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
rng.shuffle(frames_buffer)
yield from frames_buffer
def state_dict(self) -> dict:
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
Must be called after iteration has started (so the shard streams exist). Restore the returned
dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so
resumption is not bit-exact — see :meth:`__iter__`.
"""
if not hasattr(self, "_shards"):
raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.")
return {
"shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()},
"exhausted": sorted(self._exhausted),
"rng": self.rng.bit_generator.state,
}
def load_state_dict(self, state_dict: dict) -> None:
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
self._resume_state = state_dict
def _get_window_steps(
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
) -> tuple[int, int]:
@@ -405,19 +518,23 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
lookback = LOOKBACK_BACKTRACKTABLE
lookahead = LOOKAHEAD_BACKTRACKTABLE
else:
# Dynamically adjust the windows based on the given delta_timesteps
# Dynamically size the windows to exactly cover the requested delta_timestamps (in frames).
# This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for
# long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20).
all_timestamps = sum(delta_timestamps.values(), [])
lookback = min(all_timestamps) * self.fps
lookahead = max(all_timestamps) * self.fps
lookback = math.floor(min(all_timestamps) * self.fps)
lookahead = math.ceil(max(all_timestamps) * self.fps)
# When lookback is >=0 it means no negative timesteps have been provided
lookback = 0 if lookback >= 0 else (lookback * -1)
lookback = 0 if lookback >= 0 else -lookback
return lookback, lookahead
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True)
# Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires
# history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1.
return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead))
def _make_timestamps_from_indices(
self, start_ts: float, indices: dict[str, list[int]] | None = None
+246
View File
@@ -0,0 +1,246 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
import pytest
import torch
from torch.utils.data import DataLoader
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
factory(
root=root,
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
use_videos=use_videos,
data_files_size_in_mb=0.001,
chunks_size=1,
**kw,
)
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
return [int(frame["index"]) for frame in ds]
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
monkeypatch.delenv("RANK", raising=False)
monkeypatch.delenv("WORLD_SIZE", raising=False)
# No accelerate state, no env -> single process.
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
monkeypatch.setenv("RANK", "3")
monkeypatch.setenv("WORLD_SIZE", "4")
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
repo_id = f"{DUMMY_REPO_ID}-ranks"
total_frames, total_episodes = 200, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
world_size = 2
per_rank = []
for rank in range(world_size):
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=8,
max_num_shards=8,
rank=rank,
world_size=world_size,
)
per_rank.append(set(_stream_indices(ds)))
assert per_rank[0].isdisjoint(per_rank[1]), (
"ranks streamed overlapping frames (duplicate data across GPUs)"
)
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
repo_id = f"{DUMMY_REPO_ID}-workers"
total_frames, total_episodes = 120, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
)
loader = DataLoader(ds, batch_size=None, num_workers=2)
indices = [int(batch["index"]) for batch in loader]
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
"""
repo_id = f"{DUMMY_REPO_ID}-sarm"
# Two episodes of 200 frames each -> a +150-frame lookahead stays inside an episode for early frames.
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=2, total_frames=400)
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
delta_timestamps = {ACTION: [0.0, horizon_s]}
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=1,
max_num_shards=1,
delta_timestamps=delta_timestamps,
)
horizon_frames = int(round(horizon_s * ds.fps))
checked = 0
for frame in ds:
idx = int(frame["index"])
# Only assert on frames whose +horizon target is still inside the same episode.
if int(frame["episode_index"]) == 0 and idx + horizon_frames < 200:
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
)
checked += 1
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
repo_id = f"{DUMMY_REPO_ID}-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
ds = fresh_ds()
it = iter(ds)
stop_after = 40
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
state = ds.state_dict()
assert set(state) == {"shards", "exhausted", "rng"}
resumed_ds = fresh_ds()
resumed_ds.load_state_dict(state)
resumed = _stream_indices(resumed_ds)
# Resume continues rather than replaying: the full first pass is not re-yielded.
assert len(resumed) < total_frames
overlap = set(seen_before) & set(resumed)
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
repo_id = f"{DUMMY_REPO_ID}-parity"
map_ds = lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
)
stream_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
)
map_frame = map_ds[0]
stream_frame = next(iter(stream_ds))
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
for key, value in stream_frame.items():
ref = map_frame[key]
if isinstance(value, torch.Tensor):
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
)
elif isinstance(value, str):
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
else:
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
# (a long-standing, accepted difference). Compare by value rather than exact type.
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
import lerobot.datasets.streaming_dataset as sd
repo_id = f"{DUMMY_REPO_ID}-vpath"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
seen_paths = []
def fake_decode(video_path, query_ts, *args, **kwargs):
seen_paths.append(str(video_path))
return torch.zeros(len(query_ts), 3, 64, 96)
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
next(iter(ds))
assert seen_paths, "no video decode was issued"
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
repo_id = f"{DUMMY_REPO_ID}-shuf"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ordered = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
)
shuffled = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
)
)
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
assert shuffled != ordered, "shuffle did not decorrelate output order"