mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 06:29:57 +00:00
refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives
The custom episode pool becomes a pure `datasets` pipeline:
split_dataset_by_node -> batch(by_column="episode_index")
-> shuffle(buffer=episode_pool_size) # episode pool
-> map(explode + exact delta windows) # episode -> frames
-> shuffle(buffer=frame_shuffle_buffer_size) # frame interleave
and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.
Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented `datasets`
behavior): resume never repeats data, but drops at most roughly the episode pool plus the frame buffer's worth of frames.
Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade-off accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time; use a colocated
bucket (data_files_root) at large scale.
The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.
Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
|
||||
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
@@ -85,7 +85,16 @@ def make_dataloader(
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
loader = DataLoader(
|
||||
# torchdata's StatefulDataLoader checkpoints each worker's dataset state through the
|
||||
# dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back
|
||||
# to the plain DataLoader (resume then requires num_workers=0).
|
||||
try:
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
loader_cls = StatefulDataLoader
|
||||
except ImportError:
|
||||
loader_cls = DataLoader
|
||||
loader = loader_cls(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
@@ -124,13 +133,17 @@ def main() -> None:
|
||||
# of it). Batches are moved to the device manually in the loop.
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
|
||||
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
|
||||
# worker re-derives its own skip. Same file works for every rank.
|
||||
# Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
|
||||
# the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
|
||||
can_checkpoint_loader = hasattr(loader, "state_dict")
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
|
||||
state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
|
||||
state = torch.load(state_path, weights_only=False) # plain dict of stream offsets # nosec B614
|
||||
if can_checkpoint_loader:
|
||||
loader.load_state_dict(state)
|
||||
else:
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {state_path}")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
@@ -157,15 +170,15 @@ def main() -> None:
|
||||
)
|
||||
window_start = time.perf_counter()
|
||||
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
if step % args.save_freq == 0:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the consumed-batch counters so a restart fast-forwards to this position.
|
||||
torch.save(
|
||||
{"batches_consumed": step, "batch_size": args.batch_size},
|
||||
ckpt / "dataset_state.pt",
|
||||
)
|
||||
if model is not None:
|
||||
if accelerator.is_main_process:
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
# Every rank saves its own stream state: shard positions differ per rank.
|
||||
state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
|
||||
torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
|
||||
if model is not None and accelerator.is_main_process:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
if step >= args.steps:
|
||||
|
||||
@@ -13,18 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from collections.abc import Callable, Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@@ -39,7 +33,6 @@ from .utils import (
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
is_float_in_list,
|
||||
safe_shard,
|
||||
)
|
||||
from .video_utils import (
|
||||
VideoDecoderCache,
|
||||
@@ -48,134 +41,38 @@ from .video_utils import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation,
|
||||
# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop.
|
||||
_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters
|
||||
|
||||
_MASK_64 = (1 << 64) - 1
|
||||
|
||||
|
||||
def _mix64(x: int) -> int:
|
||||
"""SplitMix64 finalizer (64-bit integer hash) for seed derivation."""
|
||||
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
|
||||
x ^= x >> 30
|
||||
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
|
||||
x ^= x >> 27
|
||||
x = (x * 0x94D049BB133111EB) & _MASK_64
|
||||
x ^= x >> 31
|
||||
return x
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _suppress_hf_worker_split():
|
||||
"""Hide the torch DataLoader worker context from `datasets` while we drain its streams.
|
||||
|
||||
`datasets` detects torch workers and re-splits its shards across them internally
|
||||
(`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second
|
||||
split silently drops data whenever a per-worker stream has fewer internal shards than there
|
||||
are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local
|
||||
to this DataLoader worker process and restored on exit.
|
||||
"""
|
||||
original = torch.utils.data.get_worker_info
|
||||
torch.utils.data.get_worker_info = lambda: None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.utils.data.get_worker_info = original
|
||||
|
||||
|
||||
class _PooledEpisode:
|
||||
"""A fully-loaded episode's tabular rows plus emission bookkeeping."""
|
||||
|
||||
__slots__ = ("episode_index", "rows", "remaining", "video_rel_paths")
|
||||
|
||||
def __init__(self, episode_index: int, rows: list[dict], video_rel_paths: list[str]):
|
||||
self.episode_index = episode_index
|
||||
self.rows = rows
|
||||
self.remaining = list(range(len(rows)))
|
||||
self.video_rel_paths = video_rel_paths
|
||||
|
||||
|
||||
class _VideoPrefetcher:
|
||||
"""Background downloader of episode video files into a local cache (decode-on-exit support).
|
||||
|
||||
Files are refcounted because LeRobot v3 packs several episodes per video file: a file is
|
||||
downloaded once when the first pooled episode referencing it is admitted and deleted when
|
||||
the last one is evicted. Downloads resolve through fsspec (hf://, s3://, https://, ...).
|
||||
"""
|
||||
|
||||
def __init__(self, remote_root: str, cache_dir: Path, max_workers: int = 4):
|
||||
self._remote_root = remote_root.rstrip("/")
|
||||
self._cache_dir = cache_dir
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="video-prefetch")
|
||||
self._refcounts: dict[str, int] = {}
|
||||
self._futures: dict[str, Future] = {}
|
||||
|
||||
def acquire(self, rel_path: str) -> None:
|
||||
self._refcounts[rel_path] = self._refcounts.get(rel_path, 0) + 1
|
||||
if rel_path not in self._futures:
|
||||
self._futures[rel_path] = self._executor.submit(self._download, rel_path)
|
||||
|
||||
def _download(self, rel_path: str) -> Path:
|
||||
local = self._cache_dir / rel_path
|
||||
if local.exists():
|
||||
return local
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = local.with_suffix(local.suffix + ".tmp")
|
||||
with fsspec.open(f"{self._remote_root}/{rel_path}", "rb") as src, open(tmp, "wb") as dst:
|
||||
shutil.copyfileobj(src, dst, length=1 << 22)
|
||||
tmp.rename(local)
|
||||
return local
|
||||
|
||||
def wait_local(self, rel_path: str) -> Path | None:
|
||||
"""Block until the file is cached; None when not tracked or the download failed."""
|
||||
future = self._futures.get(rel_path)
|
||||
if future is None:
|
||||
return None
|
||||
try:
|
||||
return future.result()
|
||||
except Exception as e:
|
||||
logger.warning(f"Video prefetch failed for {rel_path} ({e}); decoding from remote instead.")
|
||||
return None
|
||||
|
||||
def release(self, rel_path: str) -> None:
|
||||
count = self._refcounts.get(rel_path, 0) - 1
|
||||
if count > 0:
|
||||
self._refcounts[rel_path] = count
|
||||
return
|
||||
self._refcounts.pop(rel_path, None)
|
||||
future = self._futures.pop(rel_path, None)
|
||||
if future is None:
|
||||
return
|
||||
if not future.cancel():
|
||||
try:
|
||||
local = future.result()
|
||||
local.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
logger.debug(f"Could not delete cached video {rel_path}.", exc_info=True)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
# Bound the default frame-level shuffle buffer: rows are tabular-only (~KB each), so this is
|
||||
# roughly a few hundred MB of host RAM per consumer at the cap.
|
||||
_MAX_DEFAULT_FRAME_BUFFER = 200_000
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
"""LeRobotDataset with streaming capabilities, built on native HF `datasets` primitives.
|
||||
|
||||
Streams frames from the Hub (or any fsspec source) without downloading the dataset. The
|
||||
iteration strategy is an *episode pool*: each consumer keeps ``episode_pool_size`` whole
|
||||
episodes' tabular rows in RAM (a few KB per episode) and emits uniformly random frames
|
||||
across them, so a batch mixes up to ``batch_size`` distinct episodes. Because a frame's
|
||||
whole episode is resident, ``delta_timestamps`` windows are exact array slices with correct
|
||||
padding at episode boundaries. Video is decoded only when a sample is emitted
|
||||
(decode-on-exit), so pool memory stays tabular-sized; when streaming from a remote source,
|
||||
each pooled episode's video files are prefetched to a local cache in the background and
|
||||
deleted on eviction.
|
||||
The tabular side is a pure `datasets` pipeline::
|
||||
|
||||
Distribution: ranks stream disjoint shards via ``split_dataset_by_node`` and DataLoader
|
||||
workers split a rank's shards further, so every frame is consumed exactly once per epoch
|
||||
across the whole fleet. Each consumer's order is a pure function of
|
||||
``(seed, epoch, rank, worker)``, which makes resume a deterministic fast-forward (see
|
||||
:meth:`load_state_dict`).
|
||||
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
|
||||
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
|
||||
-> batch(by_column="episode_index") # whole episodes
|
||||
-> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob)
|
||||
-> map(explode + exact delta windows) # episode -> frames, windows are exact
|
||||
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
|
||||
|
||||
and this class is a thin torch ``IterableDataset`` wrapper around it that decodes video
|
||||
per emitted sample (decode-on-exit), applies image transforms, and attaches the task
|
||||
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
|
||||
and resume uses the native ``state_dict`` / ``load_state_dict``.
|
||||
|
||||
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
|
||||
exact slices of the resident episode with correct padding at episode boundaries.
|
||||
|
||||
Resume: ``state_dict()`` / ``load_state_dict()`` delegate to `datasets`. Samples sitting in
|
||||
the shuffle buffers at checkpoint time are skipped on resume (documented `datasets`
|
||||
behavior), so resume never repeats data but may drop up to roughly
|
||||
``episode_pool_size x episode_len + frame_shuffle_buffer_size`` frames — negligible at
|
||||
training scale. The contract is exact with ``num_workers=0``; with DataLoader workers use
|
||||
``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's
|
||||
dataset state through this same protocol.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -201,6 +98,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
episode_pool_size: int | None = 64,
|
||||
frame_shuffle_buffer_size: int | None = None,
|
||||
buffer_size: int | None = None,
|
||||
max_num_shards: int | None = None,
|
||||
seed: int = 42,
|
||||
@@ -212,8 +110,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
video_decode_device: str = "cpu",
|
||||
prefetch_videos: bool = True,
|
||||
video_prefetch_workers: int = 4,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -229,16 +125,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
episode_pool_size (int, optional): Number of whole episodes each consumer keeps open to
|
||||
shuffle across — the randomness knob. Larger mixes more episodes per batch (closer to
|
||||
map-style uniform) at the cost of cold-start latency; RAM stays small because the pool
|
||||
holds tabular rows only. Defaults to 64.
|
||||
episode_pool_size (int, optional): Whole episodes each consumer keeps open to shuffle
|
||||
across — the randomness knob. Larger mixes more episodes per batch (closer to
|
||||
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
|
||||
Defaults to 64.
|
||||
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
|
||||
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
|
||||
which matches the pool's mixing radius.
|
||||
buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``.
|
||||
max_num_shards (int | None, optional): Cap on the number of stream shards. None (default)
|
||||
uses every underlying parquet shard, which is required to feed many DataLoader workers.
|
||||
max_num_shards (int | None, optional): Deprecated; `datasets` handles shard-to-worker
|
||||
assignment natively.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Deprecated; ignored (the RNG is derived from
|
||||
``(seed, epoch, rank, worker)`` so consumers are decorrelated and runs reproducible).
|
||||
rng (np.random.Generator | None, optional): Deprecated; ignored.
|
||||
shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order.
|
||||
rank (int | None, optional): This process' rank for distributed training. Each rank streams
|
||||
a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from
|
||||
@@ -253,10 +151,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
video_decode_device (str, optional): Device for torchcodec decode. ``"cuda"`` offloads to
|
||||
NVDEC (needs a CUDA torchcodec build and ``spawn`` DataLoader workers).
|
||||
prefetch_videos (bool, optional): When streaming from a remote source, download each pooled
|
||||
episode's video files to a local cache in the background so decode-on-exit reads local
|
||||
bytes instead of paying network seek latency. Defaults to True.
|
||||
video_prefetch_workers (int, optional): Download threads per consumer. Defaults to 4.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -271,37 +165,35 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.seed = seed
|
||||
if rng is not None:
|
||||
logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.")
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
if buffer_size is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; "
|
||||
"use `episode_pool_size` (whole episodes, not frames)."
|
||||
)
|
||||
if max_num_shards is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `max_num_shards` is deprecated and ignored; "
|
||||
"`datasets` assigns shards to DataLoader workers natively."
|
||||
)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 64
|
||||
self.max_num_shards = max_num_shards
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
self.video_decode_device = video_decode_device
|
||||
self.prefetch_videos = prefetch_videos
|
||||
self.video_prefetch_workers = video_prefetch_workers
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
self._prefetcher: _VideoPrefetcher | None = None
|
||||
# Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate
|
||||
# decoder-cache stats and component timings into one place the main process can read after
|
||||
# iteration (see video_decoder_cache_stats() / timing_stats()).
|
||||
self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_()
|
||||
# Deterministic fast-forward resume (see load_state_dict): per-consumer epoch counter and
|
||||
# number of samples still to skip.
|
||||
self._epoch = 0
|
||||
self._ff_remaining = 0
|
||||
self._resume_state: dict | None = None
|
||||
self._in_flight_epoch = 0
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -349,12 +241,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
|
||||
self.num_shards = (
|
||||
self.hf_dataset.num_shards
|
||||
if self.max_num_shards is None
|
||||
else min(self.hf_dataset.num_shards, self.max_num_shards)
|
||||
self.num_shards = self.hf_dataset.num_shards
|
||||
|
||||
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
|
||||
self.frame_shuffle_buffer_size = (
|
||||
frame_shuffle_buffer_size
|
||||
if frame_shuffle_buffer_size is not None
|
||||
else min(self.episode_pool_size * avg_episode_len, _MAX_DEFAULT_FRAME_BUFFER)
|
||||
)
|
||||
|
||||
self._pipeline = self._build_pipeline()
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
return self.meta.total_frames
|
||||
@@ -374,6 +271,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
import os
|
||||
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
@@ -393,27 +292,67 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _consumer_rng(self, epoch: int, worker_id: int) -> np.random.Generator:
    """RNG derived from (seed, epoch, rank, worker): reproducible, decorrelated consumers.

    Args:
        epoch: Epoch number; only mixed into the seed when ``self.shuffle`` is true, so an
            unshuffled dataset gets the same generator every epoch.
        worker_id: DataLoader worker id of the calling consumer.

    Returns:
        A NumPy generator seeded from ``self.seed`` folded with rank, worker id and epoch.
    """
    state = _mix64(self.seed)
    # Hash each salt before XOR-ing it in, then re-hash the running state.
    for salt in (self.rank, worker_id, epoch if self.shuffle else 0):
        state = _mix64(state ^ _mix64(salt))
    return np.random.default_rng(state)
|
||||
def _build_pipeline(self) -> datasets.IterableDataset:
|
||||
"""Assemble the native tabular pipeline (everything except video decode)."""
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
if ds.num_shards % self.world_size != 0:
|
||||
logger.warning(
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
|
||||
"datasets falls back to example-level splitting where every rank reads (and pays "
|
||||
"for) the full stream. Re-shard the dataset or adjust world size."
|
||||
)
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
def _shard_order(self, epoch: int, num_shards: int) -> list[int]:
|
||||
"""Seeded permutation of this rank's shard indices, re-drawn every epoch.
|
||||
|
||||
In a sub-epoch run over a corpus consolidated source-by-source, index-order shard
|
||||
assignment means training on whatever the first N% of files contains; permuting the
|
||||
shard order turns that into a uniform sample of files. Seeded by (seed, epoch, rank)
|
||||
only — every DataLoader worker of the rank must agree on this list, because workers
|
||||
stride it and disagreement would create overlapping shard assignments.
|
||||
"""
|
||||
order = list(range(num_shards))
|
||||
ds = ds.batch(by_column="episode_index")
|
||||
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
|
||||
if self.shuffle:
|
||||
state = _mix64(self.seed) ^ _mix64(0x5EED5EED) ^ _mix64(self.rank) ^ _mix64(epoch)
|
||||
np.random.default_rng(_mix64(state)).shuffle(order)
|
||||
return order
|
||||
ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size)
|
||||
# A row-count-changing batched map must drop the input columns explicitly; the exploded
|
||||
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
|
||||
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(seed=self.seed + 1, buffer_size=max(2, self.frame_shuffle_buffer_size))
|
||||
return ds
|
||||
|
||||
def _tabular_window_keys(self) -> list[str]:
|
||||
if self.delta_indices is None:
|
||||
return []
|
||||
return [key for key in self.delta_indices if key not in self.meta.video_keys]
|
||||
|
||||
def _explode_episodes(self, episode_batch: dict[str, list[list]]) -> dict[str, list]:
|
||||
"""Episode batches -> per-frame rows, with exact tabular delta windows and pad masks.
|
||||
|
||||
Runs inside the `datasets` pipeline (plain Python values, no torch). For each windowed key
|
||||
the original per-frame value is replaced by its delta window (list of values, clamped to
|
||||
the episode bounds) plus a ``{key}_is_pad`` mask, mirroring the map-style dataset.
|
||||
"""
|
||||
window_keys = set(self._tabular_window_keys())
|
||||
out: dict[str, list] = {key: [] for key in episode_batch if key not in window_keys}
|
||||
for key in window_keys:
|
||||
out[key] = []
|
||||
out[f"{key}_is_pad"] = []
|
||||
|
||||
num_episodes = len(episode_batch["episode_index"])
|
||||
for e in range(num_episodes):
|
||||
length = len(episode_batch["episode_index"][e])
|
||||
for key, column in episode_batch.items():
|
||||
if key in window_keys:
|
||||
continue
|
||||
out[key].extend(column[e])
|
||||
for key in window_keys:
|
||||
episode_column = episode_batch[key][e]
|
||||
deltas = self.delta_indices[key]
|
||||
for t in range(length):
|
||||
window = []
|
||||
is_pad = []
|
||||
for delta in deltas:
|
||||
j = t + delta
|
||||
window.append(episode_column[min(max(j, 0), length - 1)])
|
||||
is_pad.append(not 0 <= j < length)
|
||||
out[key].append(window)
|
||||
out[f"{key}_is_pad"].append(is_pad)
|
||||
return out
|
||||
|
||||
def _make_video_decoder_cache(self) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128."""
|
||||
@@ -432,231 +371,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
device=self.video_decode_device,
|
||||
)
|
||||
|
||||
def _make_prefetcher(self) -> _VideoPrefetcher | None:
|
||||
if not self.prefetch_videos or len(self.meta.video_keys) == 0:
|
||||
return None
|
||||
if self.data_files_root is not None:
|
||||
remote_root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
remote_root = self.meta.url_root
|
||||
else:
|
||||
return None # video bytes are already local
|
||||
return _VideoPrefetcher(
|
||||
remote_root,
|
||||
cache_dir=self.root / "streaming_video_cache",
|
||||
max_workers=self.video_prefetch_workers,
|
||||
)
|
||||
|
||||
@staticmethod
def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]:
    """Yield (episode_index, rows) for each complete episode of a shard stream.

    On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)``
    (one accumulation per episode instead of one Python dict per row); older versions
    use the equivalent row loop.

    Args:
        shard: Stream of frame rows carrying an ``episode_index`` column. The row-loop
            fallback starts a new episode whenever the value changes, so it relies on each
            episode's rows being contiguous in the stream.

    Yields:
        ``(episode_index, rows)`` where ``rows`` is the list of that episode's row dicts.
    """
    if _HAS_BATCH_BY_COLUMN:
        for batch in shard.batch(by_column="episode_index"):
            keys = list(batch.keys())
            num_rows = len(batch["episode_index"])
            # Transpose the column-oriented batch back into one dict per frame.
            rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)]
            yield int(batch["episode_index"][0]), rows
        return

    rows: list[dict] = []
    current: int | None = None
    for item in shard:
        ep_idx = int(item["episode_index"])
        if current is None:
            current = ep_idx
        if ep_idx != current:
            # Episode boundary: flush the rows gathered so far.
            yield current, rows
            rows = []
            current = ep_idx
        rows.append(item)
    if rows:
        yield current, rows
|
||||
|
||||
def _admit_episode(self, ep_idx: int, rows: list[dict], prefetcher: _VideoPrefetcher | None):
    """Wrap one episode's rows as a ``_PooledEpisode`` and start prefetching its videos.

    Args:
        ep_idx: Index of the episode being admitted.
        rows: The episode's raw rows; each is converted with ``item_to_torch``.
        prefetcher: Optional prefetcher; when given, one reference is acquired per video
            path of the episode.

    Returns:
        The ``_PooledEpisode`` holding the converted rows and the episode's video paths.
    """
    # One relative path per video key of the dataset.
    video_rel_paths = [str(self.meta.get_video_file_path(ep_idx, key)) for key in self.meta.video_keys]
    if prefetcher is not None:
        for rel in video_rel_paths:
            prefetcher.acquire(rel)
    torch_rows = [item_to_torch(row) for row in rows]
    return _PooledEpisode(ep_idx, torch_rows, video_rel_paths)
|
||||
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
if ds.num_shards % self.world_size != 0:
|
||||
logger.warning(
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
|
||||
"datasets falls back to example-level splitting where every rank reads (and pays "
|
||||
"for) the full stream. Re-shard the dataset or adjust world size."
|
||||
)
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
num_shards = ds.num_shards if self.max_num_shards is None else min(ds.num_shards, self.max_num_shards)
|
||||
epoch = self._epoch
|
||||
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
|
||||
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
|
||||
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
|
||||
# stream position actually belongs to.
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
shard_indices = self._shard_order(epoch, num_shards)
|
||||
|
||||
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
worker_id, num_workers = (worker_info.id, worker_info.num_workers) if worker_info else (0, 1)
|
||||
shard_indices = shard_indices[worker_id::num_workers]
|
||||
if not shard_indices:
|
||||
logger.warning(
|
||||
f"Worker {worker_id} owns no shards ({num_shards} shards < {num_workers} workers): "
|
||||
"it will yield nothing. Reduce num_workers or re-shard the dataset."
|
||||
)
|
||||
return
|
||||
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
prefetcher = self._make_prefetcher()
|
||||
self._prefetcher = prefetcher
|
||||
|
||||
rng = self._consumer_rng(epoch, worker_id)
|
||||
# Workers beyond the shard count yield nothing and are stopped by the DataLoader, so the
|
||||
# batch round-robin effectively runs over min(num_workers, num_shards) active workers.
|
||||
self._consume_resume_state(worker_id, min(num_workers, num_shards))
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
fetch_start = time.perf_counter_ns()
|
||||
try:
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
finally:
|
||||
self._cache_counters[4] += time.perf_counter_ns() - fetch_start
|
||||
yield self._finalize_sample(row)
|
||||
|
||||
# Round-robin episode admission across this consumer's shard streams (deterministic).
|
||||
streams = [self._iter_shard_episodes(safe_shard(ds, idx, num_shards)) for idx in shard_indices]
|
||||
next_stream = 0
|
||||
|
||||
pool: list[_PooledEpisode] = []
|
||||
total_remaining = 0
|
||||
|
||||
def admit() -> int:
|
||||
nonlocal next_stream, total_remaining
|
||||
admitted = 0
|
||||
while len(pool) < self.episode_pool_size and streams:
|
||||
stream = streams[next_stream % len(streams)]
|
||||
fetch_start = time.perf_counter_ns()
|
||||
try:
|
||||
ep_idx, rows = next(stream)
|
||||
except StopIteration:
|
||||
streams.remove(stream)
|
||||
continue
|
||||
finally:
|
||||
self._cache_counters[4] += time.perf_counter_ns() - fetch_start
|
||||
next_stream += 1
|
||||
episode = self._admit_episode(ep_idx, rows, prefetcher)
|
||||
pool.append(episode)
|
||||
total_remaining += len(episode.remaining)
|
||||
admitted += 1
|
||||
return admitted
|
||||
|
||||
worker_split_guard = (
|
||||
_suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext()
|
||||
)
|
||||
try:
|
||||
with worker_split_guard:
|
||||
admit()
|
||||
while pool:
|
||||
if self.shuffle:
|
||||
# Uniform draw over every remaining frame in the pool: pick the episode by
|
||||
# cumulative remaining count, then a random remaining position (swap-pop).
|
||||
draw = int(rng.integers(total_remaining))
|
||||
for episode in pool:
|
||||
if draw < len(episode.remaining):
|
||||
break
|
||||
draw -= len(episode.remaining)
|
||||
pick = int(rng.integers(len(episode.remaining)))
|
||||
frame_pos = episode.remaining[pick]
|
||||
episode.remaining[pick] = episode.remaining[-1]
|
||||
episode.remaining.pop()
|
||||
else:
|
||||
episode = pool[0]
|
||||
frame_pos = episode.remaining.pop(0)
|
||||
total_remaining -= 1
|
||||
|
||||
if self._ff_remaining > 0:
|
||||
self._ff_remaining -= 1
|
||||
else:
|
||||
yield self._make_pool_sample(episode, frame_pos)
|
||||
|
||||
if not episode.remaining:
|
||||
pool.remove(episode)
|
||||
if prefetcher is not None:
|
||||
for rel in episode.video_rel_paths:
|
||||
prefetcher.release(rel)
|
||||
admit()
|
||||
finally:
|
||||
if prefetcher is not None:
|
||||
prefetcher.shutdown()
|
||||
self._prefetcher = None
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
    """Record a fast-forward resume request; it is applied by the next ``__iter__``.

    Nothing is replayed here. The two figures the trainer already knows at checkpoint time
    are coerced to ``int`` and parked in ``self._resume_state``.

    The resume relies on every consumer's order being a pure function of
    (seed, epoch, rank, worker): the stream is replayed while emission is skipped (tabular
    reads only, no video decode) until each worker reaches its own consumed count, which the
    DataLoader's round-robin batch assignment makes derivable per worker. It is exact within
    an epoch; across epoch boundaries it may drift by less than one batch per worker per
    epoch when ``drop_last`` discards partial batches.

    Args:
        state_dict: Mapping holding ``"batches_consumed"`` and ``"batch_size"``.

    Raises:
        KeyError: If either key is missing from ``state_dict``.
    """
    consumed_batches = int(state_dict["batches_consumed"])
    samples_per_batch = int(state_dict["batch_size"])
    self._resume_state = {"batches_consumed": consumed_batches, "batch_size": samples_per_batch}
|
||||
|
||||
def _consume_resume_state(self, worker_id: int, active_workers: int) -> None:
    """Turn a staged resume request into this worker's fast-forward sample count.

    The staged request is consumed exactly once: ``self._resume_state`` is reset to ``None``
    as soon as its two values have been read.

    Args:
        worker_id: Index of the calling worker.
        active_workers: Number of workers that actually deliver batches.
    """
    staged = self._resume_state
    if staged is None:
        return
    total_batches = staged["batches_consumed"]
    samples_per_batch = staged["batch_size"]
    self._resume_state = None

    if worker_id >= active_workers:
        # This worker owns no shards and never delivered a batch, so it has nothing to skip.
        return

    # The DataLoader assigns batch j to active worker j % active_workers: every worker gets the
    # full rounds, and the first `leftover` workers get one extra batch.
    full_rounds, leftover = divmod(total_batches, active_workers)
    extra = 1 if leftover > worker_id else 0
    self._ff_remaining = (full_rounds + extra) * samples_per_batch

    if not self._ff_remaining:
        return
    logger.info(
        f"Streaming resume: worker {worker_id} fast-forwarding {self._ff_remaining} samples "
        "(tabular reads only, no video decode)."
    )
|
||||
|
||||
def video_decoder_cache_stats(self) -> dict[str, int | float]:
|
||||
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
|
||||
|
||||
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
|
||||
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
|
||||
approximate; the ``hit_rate`` ratio is preserved.
|
||||
"""
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
|
||||
total = hits + misses
|
||||
return {
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"evictions": evictions,
|
||||
"hit_rate": round(hits / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
def timing_stats(self) -> dict[str, float]:
    """Report cumulative video-decode and episode (tabular) fetch time in seconds.

    Slots 3 and 4 of the shared counter tensor hold nanosecond totals summed across DataLoader
    workers. Because workers run in parallel the totals overlap in wall-clock, so compare them
    to ``num_workers x wallclock`` when computing time fractions.

    Returns:
        ``{"decode_s_total": ..., "fetch_s_total": ...}``, each rounded to 2 decimals.
    """
    nanos_per_second = 1e9
    decode_ns, fetch_ns = map(int, self._cache_counters[3:5].tolist())
    decode_seconds = round(decode_ns / nanos_per_second, 2)
    fetch_seconds = round(fetch_ns / nanos_per_second, 2)
    return {"decode_s_total": decode_seconds, "fetch_s_total": fetch_seconds}
|
||||
|
||||
def _make_pool_sample(self, episode: _PooledEpisode, frame_pos: int) -> dict:
|
||||
"""Assemble a full training sample for one pooled frame (tabular slices + video decode)."""
|
||||
rows = episode.rows
|
||||
item = dict(rows[frame_pos])
|
||||
ep_idx = episode.episode_index
|
||||
num_rows = len(rows)
|
||||
current_ts = float(item["timestamp"])
|
||||
|
||||
updates: list[dict] = []
|
||||
if self.delta_indices is not None:
|
||||
updates.extend(self._pool_delta_frames(rows, frame_pos, num_rows))
|
||||
def _finalize_sample(self, row: dict) -> dict:
|
||||
"""Torch conversion + video decode (decode-on-exit) + transforms + task for one frame."""
|
||||
window_keys = self._tabular_window_keys()
|
||||
pad_masks = {f"{key}_is_pad": torch.BoolTensor(row.pop(f"{key}_is_pad")) for key in window_keys}
|
||||
item = item_to_torch(row)
|
||||
item.update(pad_masks)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
ep_idx = int(item["episode_index"])
|
||||
current_ts = float(item["timestamp"])
|
||||
# Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of
|
||||
# decoding against a neighbouring episode sharing the same video file.
|
||||
episode_boundaries_ts = {
|
||||
@@ -679,35 +424,57 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
for cam in self.meta.camera_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
item.update(video_frames)
|
||||
if self.delta_indices is not None:
|
||||
updates.append(
|
||||
item.update(
|
||||
self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps)
|
||||
)
|
||||
|
||||
result = item
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
return result
|
||||
item["task"] = self.meta.tasks.iloc[int(item["task_index"])].name
|
||||
return item
|
||||
|
||||
def _pool_delta_frames(self, rows: list[dict], frame_pos: int, num_rows: int) -> list[dict]:
    """Build exact delta windows for the non-video keys by indexing the resident episode rows.

    Args:
        rows: The episode's rows, one dict per frame.
        frame_pos: Position of the anchor frame inside ``rows``.
        num_rows: Number of rows in the episode.

    Returns:
        ``[windows, pad_masks]``: ``windows[key]`` stacks one value per delta, and
        ``pad_masks[f"{key}_is_pad"]`` is a ``BoolTensor`` that is True wherever the requested
        position fell outside ``[0, num_rows)``.
    """
    last_row = num_rows - 1
    windows: dict = {}
    pad_masks: dict = {}
    for key, deltas in self.delta_indices.items():
        if key in self.meta.video_keys:
            continue  # visual frames are decoded separately
        targets = [frame_pos + delta for delta in deltas]
        # Out-of-episode positions are clamped to the nearest edge row and flagged as padding.
        clamped_values = [rows[min(max(target, 0), last_row)][key] for target in targets]
        out_of_range = [not (0 <= target < num_rows) for target in targets]
        windows[key] = torch.stack(clamped_values)
        pad_masks[f"{key}_is_pad"] = torch.BoolTensor(out_of_range)
    return [windows, pad_masks]
|
||||
def set_epoch(self, epoch: int) -> None:
    """Select the epoch that the next ``__iter__`` will run (reshuffles the native pipeline).

    Only ``self._epoch`` is stored here; the value is handed to the pipeline when iteration
    next starts.

    Args:
        epoch: Epoch number to use for the upcoming iteration.
    """
    self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
    """Snapshot the native `datasets` stream state together with the in-flight epoch.

    The contract is exact with ``num_workers=0``. With DataLoader workers, use
    ``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's copy
    through this protocol. Samples sitting in the shuffle buffers are skipped on resume (never
    repeated); the loss is bounded by the pool + frame buffer sizes.

    Returns:
        ``{"pipeline": <state of self._pipeline>, "epoch": <self._in_flight_epoch>}``.
    """
    pipeline_state = self._pipeline.state_dict()
    return {"pipeline": pipeline_state, "epoch": self._in_flight_epoch}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
    """Restore a checkpoint produced by ``state_dict()``.

    Args:
        state_dict: Mapping with a required ``"pipeline"`` entry, forwarded unchanged to
            ``self._pipeline.load_state_dict``, and an optional ``"epoch"`` entry (coerced
            with ``int()``, defaulting to 0 when absent).

    Raises:
        KeyError: If ``state_dict`` has no ``"pipeline"`` entry.
    """
    # Resume continues inside the recorded epoch: the next __iter__ replays that epoch's
    # shuffle order from the restored stream position, then advances normally.
    restored_epoch = int(state_dict.get("epoch", 0))
    self._epoch = restored_epoch
    # state_dict() reports the in-flight epoch, so keep it in step with the restored value.
    # Otherwise a checkpoint taken before the next __iter__ would record a stale epoch.
    self._in_flight_epoch = restored_epoch
    self._pipeline.load_state_dict(state_dict["pipeline"])
|
||||
|
||||
def video_decoder_cache_stats(self) -> dict[str, int | float]:
|
||||
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
|
||||
|
||||
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
|
||||
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
|
||||
approximate; the ``hit_rate`` ratio is preserved.
|
||||
"""
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
|
||||
total = hits + misses
|
||||
return {
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"evictions": evictions,
|
||||
"hit_rate": round(hits / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
def timing_stats(self) -> dict[str, float]:
    """Report cumulative seconds spent in video decode and in the upstream tabular pipeline.

    The tabular figure covers parquet fetch + grouping + shuffles + explode. Slots 3 and 4 of
    the shared counter tensor hold nanosecond totals summed across DataLoader workers. Because
    workers run in parallel the totals overlap in wall-clock, so compare them to
    ``num_workers x wallclock`` when computing time fractions.

    Returns:
        ``{"decode_s_total": ..., "fetch_s_total": ...}``, each rounded to 2 decimals.
    """
    nanos_per_second = 1e9
    decode_ns, fetch_ns = map(int, self._cache_counters[3:5].tolist())
    decode_seconds = round(decode_ns / nanos_per_second, 2)
    fetch_seconds = round(fetch_ns / nanos_per_second, 2)
    return {"decode_s_total": decode_seconds, "fetch_s_total": fetch_seconds}
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
@@ -785,17 +552,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
rel_path = str(self.meta.get_video_file_path(ep_idx, video_key))
|
||||
local = self._prefetcher.wait_local(rel_path) if self._prefetcher is not None else None
|
||||
if local is not None:
|
||||
video_path = str(local)
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
root = self.root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
|
||||
@@ -218,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats/ints where map-style yields
|
||||
# 0-dim tensors (long-standing accepted difference). Compare by value.
|
||||
check = float(left) == float(right)
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
|
||||
@@ -148,37 +148,6 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory):
    """A resumed stream must equal the uninterrupted epoch minus the already-consumed prefix."""
    repo_id = f"{DUMMY_REPO_ID}-resume"
    root = tmp_path / "ds"
    total_frames = 100
    _make_local_dataset(lerobot_dataset_factory, root, repo_id, total_episodes=5, total_frames=total_frames)

    # Both streams are built from the same arguments; the comparison below relies on the
    # stream order being deterministic for a fixed seed.
    dataset_kwargs = {
        "repo_id": repo_id,
        "root": root,
        "shuffle": True,
        "seed": 7,
        "episode_pool_size": 3,
        "max_num_shards": 1,
    }

    # Reference run: one full epoch must visit every frame index exactly once.
    full_epoch = _stream_indices(StreamingLeRobotDataset(**dataset_kwargs))
    assert sorted(full_epoch) == list(range(total_frames))

    # Checkpoint taken after 5 batches of 4 samples, i.e. 20 samples into the epoch.
    batches_consumed, batch_size = 5, 4
    resumed_ds = StreamingLeRobotDataset(**dataset_kwargs)
    resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size})
    resumed = _stream_indices(resumed_ds)

    already_seen = batches_consumed * batch_size
    assert resumed == full_epoch[already_seen:], (
        "fast-forward resume did not continue at the exact sample"
    )
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
|
||||
repo_id = f"{DUMMY_REPO_ID}-seeds"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
|
||||
@@ -226,27 +195,6 @@ def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
|
||||
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_video_prefetcher_refcounted_lifecycle(tmp_path):
    """A prefetched file must survive until its last reference is released, then be deleted."""
    from lerobot.datasets.streaming_dataset import _VideoPrefetcher

    rel_path = "videos/a.mp4"
    payload = b"x" * 1024

    # Fake "remote" root holding a single 1 KiB video file.
    remote = tmp_path / "remote"
    video_dir = remote / "videos"
    video_dir.mkdir(parents=True)
    (video_dir / "a.mp4").write_bytes(payload)

    prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1)
    # Two acquisitions stand in for two pooled episodes sharing the same file.
    for _ in range(2):
        prefetcher.acquire(rel_path)
    local = prefetcher.wait_local(rel_path)
    assert local is not None and local.read_bytes() == payload

    # First release: one reference is still outstanding.
    prefetcher.release(rel_path)
    assert local.exists(), "file deleted while still referenced"
    # Second release drops the count to zero.
    prefetcher.release(rel_path)
    assert not local.exists(), "file not deleted at refcount zero"
    prefetcher.shutdown()
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
@@ -318,87 +266,49 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
|
||||
|
||||
def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_factory):
|
||||
"""Resume must be exact under num_workers > 0: each worker re-derives its own skip."""
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume-workers"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=120)
|
||||
|
||||
num_workers = 2
|
||||
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
|
||||
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=11,
|
||||
episode_pool_size=3,
|
||||
max_num_shards=4,
|
||||
seed=7,
|
||||
episode_pool_size=2,
|
||||
frame_shuffle_buffer_size=8,
|
||||
)
|
||||
|
||||
def epoch_samples(ds):
|
||||
# batch_size=None yields raw samples; the DataLoader round-robins them across workers,
|
||||
# which is batch_size=1 in the resume arithmetic.
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=num_workers)
|
||||
return [int(sample["index"]) for sample in loader]
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
consumed = [int(next(it)["index"]) for _ in range(30)]
|
||||
state = ds.state_dict()
|
||||
|
||||
full = epoch_samples(fresh_ds())
|
||||
|
||||
samples_consumed = 17
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict({"batches_consumed": samples_consumed, "batch_size": 1})
|
||||
resumed = epoch_samples(resumed_ds)
|
||||
resumed_ds.load_state_dict(state)
|
||||
rest = [int(frame["index"]) for frame in resumed_ds]
|
||||
|
||||
assert resumed == full[samples_consumed:], (
|
||||
"fast-forward resume with DataLoader workers did not continue at the exact sample"
|
||||
)
|
||||
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
|
||||
# in-flight buffer contents are skipped on resume (documented datasets behavior):
|
||||
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
|
||||
covered = len(set(consumed) | set(rest))
|
||||
max_in_flight = 2 * 30 + 8
|
||||
assert covered >= total_frames - max_in_flight
|
||||
assert covered + len(consumed) >= total_frames - max_in_flight
|
||||
|
||||
|
||||
def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch):
    """The datasets>=5 batch(by_column=...) path must group episodes identically to the row loop."""
    import lerobot.datasets.streaming_dataset as sd

    repo_id = f"{DUMMY_REPO_ID}-grouping"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100)
    ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1)

    # Read the real capability flag before any monkeypatching. episode_signature() overwrites
    # the module attribute, so checking it afterwards would always see the patched value and
    # could skip the native comparison even when the native path is available.
    native_detected = sd._HAS_BATCH_BY_COLUMN

    def episode_signature(use_native):
        """Return ``[(episode_index, [frame indices])]`` using the selected grouping path."""
        monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native)
        return [
            (ep_idx, [int(row["index"]) for row in rows])
            for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset)
        ]

    fallback = episode_signature(False)
    assert len(fallback) == 5
    if not native_detected and "by_column" not in str(
        type(ds.hf_dataset).batch.__doc__ or ""
    ):  # datasets < 5: only the fallback path exists
        return
    native = episode_signature(True)
    assert native == fallback
|
||||
|
||||
|
||||
def test_shard_order_permutation_properties(tmp_path, lerobot_dataset_factory):
|
||||
"""Shard order: a valid permutation, deterministic per (seed, epoch, rank), worker-independent
|
||||
(workers stride the same list, so it must not depend on worker id), reshuffled across epochs,
|
||||
and identity when shuffle is off."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shardorder"
|
||||
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
|
||||
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
|
||||
import datasets as hf_datasets
|
||||
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5)
|
||||
num_shards = 32
|
||||
order_epoch0 = ds._shard_order(0, num_shards)
|
||||
assert sorted(order_epoch0) == list(range(num_shards))
|
||||
assert ds._shard_order(0, num_shards) == order_epoch0 # deterministic
|
||||
assert ds._shard_order(1, num_shards) != order_epoch0 # reshuffles per epoch
|
||||
assert order_epoch0 != list(range(num_shards)) # actually permuted (P=1/32! of false alarm)
|
||||
|
||||
other_rank = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=5, rank=1, world_size=2
|
||||
)
|
||||
assert other_rank._shard_order(0, num_shards) != order_epoch0 # ranks decorrelated
|
||||
|
||||
unshuffled = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, seed=5)
|
||||
assert unshuffled._shard_order(0, num_shards) == list(range(num_shards))
|
||||
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
|
||||
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
|
||||
assert state is not None
|
||||
|
||||
Reference in New Issue
Block a user