mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
feat(streaming): distributed + resumable HF-native StreamingLeRobotDataset
Add the large-scale streaming pieces that were missing from the frame-streaming internals, keeping the existing Backtrackable + output-reservoir frame-shuffle: - split_dataset_by_node(rank, world_size) before the per-shard loop so each rank streams a disjoint set of shards (fixes duplicate data across GPUs). rank and world_size auto-resolve from Accelerate state / RANK,WORLD_SIZE env / (0, 1). - get_worker_info() shard splitting so DataLoader workers within a rank don't yield duplicate frames. - Dynamic Backtrackable window (dynamic_bounds=True) sized to the requested delta_timestamps, removing the fixed 100-frame ceiling so long horizons (e.g. a SARM window ~160 frames) reach real frames instead of silently padding. Fix the peek_back off-by-one: history = lookback + 1. - video_decoder_cache_size knob; default (active_shards + 1) x num_cameras so the live decoder working set does not thrash the VideoDecoderCache LRU. - state_dict()/load_state_dict() for resume (per-shard HF stream state + exhausted set + RNG). Reservoir is re-warmed, so resumption is not bit-exact (documented). - factory.py wires buffer_size from a new DatasetConfig.streaming_buffer_size field instead of repurposing max_num_shards as the worker count. Tests: tests/datasets/test_streaming_native.py covers distributed disjointness, worker de-duplication, the SARM-length window, resume, schema parity vs map-style, local video path resolution, and shuffle decorrelation. 21 passed (13 existing + 8). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -39,6 +39,9 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
|
||||
# of host RAM. Ignored when streaming is False.
|
||||
streaming_buffer_size: int = 1000
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
buffer_size=cfg.dataset.streaming_buffer_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
@@ -21,6 +24,7 @@ import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
@@ -38,6 +42,8 @@ from .video_utils import (
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
@@ -252,6 +258,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
rank: int | None = None,
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -272,6 +281,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training.
|
||||
Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is
|
||||
resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
|
||||
world_size (int | None, optional): Total number of distributed processes. When omitted, resolved
|
||||
from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding).
|
||||
For an even per-rank split, ``num_shards % world_size == 0`` should hold.
|
||||
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
|
||||
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
|
||||
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -289,10 +307,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self.max_num_shards = max_num_shards
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
# Resume state captured by load_state_dict() and consumed at the next __iter__.
|
||||
self._resume_state: dict | None = None
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -348,22 +372,91 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
|
||||
"""Resolve (rank, world_size) for distributed streaming.
|
||||
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
try:
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if PartialState._shared_state: # only read it if already initialized; never initialize here
|
||||
state = PartialState()
|
||||
return state.process_index, state.num_processes
|
||||
except Exception:
|
||||
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
|
||||
|
||||
env_rank = os.environ.get("RANK")
|
||||
env_world = os.environ.get("WORLD_SIZE")
|
||||
if env_rank is not None and env_world is not None:
|
||||
return int(env_rank), int(env_world)
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the working set of live shards so it does not thrash.
|
||||
|
||||
Each shard mid-episode keeps one open decoder per camera; with several shards iterated
|
||||
concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of
|
||||
margin so the round-robin never evicts a still-live decoder.
|
||||
"""
|
||||
if self.video_decoder_cache_size is not None:
|
||||
return VideoDecoderCache(max_size=self.video_decoder_cache_size)
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if num_cameras == 0:
|
||||
return VideoDecoderCache()
|
||||
return VideoDecoderCache(max_size=(num_active_shards + 1) * num_cameras)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
# Distributed correctness: each rank streams a disjoint set of shards (order preserved).
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
num_shards = min(ds.num_shards, self.max_num_shards)
|
||||
shard_indices = list(range(num_shards))
|
||||
|
||||
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is not None:
|
||||
shard_indices = shard_indices[worker_info.id :: worker_info.num_workers]
|
||||
|
||||
self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices))
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
# Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The
|
||||
# shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable
|
||||
# for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead.
|
||||
resume = self._resume_state
|
||||
self._resume_state = None
|
||||
self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set()
|
||||
if resume is not None:
|
||||
rng.bit_generator.state = resume["rng"]
|
||||
|
||||
self._shards: dict[int, datasets.IterableDataset] = {}
|
||||
for idx in shard_indices:
|
||||
shard = safe_shard(ds, idx, num_shards)
|
||||
if resume is not None and str(idx) in resume["shards"]:
|
||||
shard.load_state_dict(resume["shards"][str(idx)])
|
||||
self._shards[idx] = shard
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
idx: self._make_backtrackable_dataset(shard)
|
||||
for idx, shard in self._shards.items()
|
||||
if idx not in self._exhausted
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
@@ -389,11 +482,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
self._exhausted.add(shard_key)
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
|
||||
|
||||
Must be called after iteration has started (so the shard streams exist). Restore the returned
|
||||
dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so
|
||||
resumption is not bit-exact — see :meth:`__iter__`.
|
||||
"""
|
||||
if not hasattr(self, "_shards"):
|
||||
raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.")
|
||||
return {
|
||||
"shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()},
|
||||
"exhausted": sorted(self._exhausted),
|
||||
"rng": self.rng.bit_generator.state,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
|
||||
self._resume_state = state_dict
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
@@ -405,19 +518,23 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
# Dynamically size the windows to exactly cover the requested delta_timestamps (in frames).
|
||||
# This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for
|
||||
# long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20).
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
lookback = math.floor(min(all_timestamps) * self.fps)
|
||||
lookahead = math.ceil(max(all_timestamps) * self.fps)
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
lookback = 0 if lookback >= 0 else -lookback
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True)
|
||||
# Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires
|
||||
# history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1.
|
||||
return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead))
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# Two episodes of 200 frames each -> a +150-frame lookahead stays inside an episode for early frames.
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=2, total_frames=400)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# Only assert on frames whose +horizon target is still inside the same episode.
|
||||
if int(frame["episode_index"]) == 0 and idx + horizon_frames < 200:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
|
||||
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
stop_after = 40
|
||||
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
|
||||
state = ds.state_dict()
|
||||
assert set(state) == {"shards", "exhausted", "rng"}
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
resumed = _stream_indices(resumed_ds)
|
||||
|
||||
# Resume continues rather than replaying: the full first pass is not re-yielded.
|
||||
assert len(resumed) < total_frames
|
||||
overlap = set(seen_before) & set(resumed)
|
||||
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
|
||||
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
|
||||
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
|
||||
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
Reference in New Issue
Block a user