feat(streaming): episode-pool iteration with decode-on-exit, video prefetch, and exact resume

Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.

- Prefetch-on-admit: when streaming from a remote source, each pooled
  episode's video files download to a local cache in the background
  (refcounted, since v3 packs several episodes per file; deleted on
  eviction), so decode-on-exit reads local bytes instead of paying
  network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
  decorrelated, runs reproducible, epochs reshuffle automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
  {batches_consumed, batch_size}; each worker re-derives its own skip
  from the DataLoader's round-robin batch assignment and replays
  tabular-only (no decode). Exact within an epoch, works with
  num_workers > 0, and the same state file serves every rank. Replaces
  the per-shard HF state_dict approach, which lived in worker processes
  and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
  shard); runtime warnings when the shard count is not divisible by the
  world size (the `datasets` library then degrades to read-everything
  splitting) and when workers are left without shards.
- episode_pool_size replaces buffer_size (buffer_size is deprecated and
  ignored with a warning); decoder cache sized to the pool working set,
  capped at 128.

Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 15:02:15 +02:00
parent 66ac901632
commit 1050c2fb6c
11 changed files with 521 additions and 650 deletions
+3 -3
View File
@@ -62,7 +62,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--buffer_size", type=int, default=2000)
parser.add_argument("--episode_pool_size", type=int, default=64)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument(
"--video_decode_device",
@@ -86,7 +86,7 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
root=args.root,
data_files_root=args.data_files_root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
episode_pool_size=args.episode_pool_size,
video_decoder_cache_size=args.video_decoder_cache_size,
video_decode_device=args.video_decode_device,
tolerance_s=1e-3,
@@ -172,7 +172,7 @@ def main() -> None:
"mode": args.mode,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"buffer_size": args.buffer_size,
"episode_pool_size": args.episode_pool_size,
"num_cameras": len(meta.video_keys),
"fps": meta.fps,
"device": str(device),
+16 -10
View File
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
@@ -57,7 +57,10 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
"--episode_pool_size",
type=int,
default=64,
help="Whole episodes open per consumer (randomness knob).",
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
@@ -78,7 +81,7 @@ def make_dataloader(
args.repo_id,
root=args.root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
episode_pool_size=args.episode_pool_size,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
)
@@ -121,13 +124,13 @@ def main() -> None:
# of it). Batches are moved to the device manually in the loop.
model, optimizer = accelerator.prepare(model, optimizer)
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
# checkpoint this script wrote itself.
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
# worker re-derives its own skip. Same file works for every rank.
if args.resume_from is not None:
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
dataset.load_state_dict(state)
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
step = 0
frames_seen = 0
@@ -157,8 +160,11 @@ def main() -> None:
if step % args.save_freq == 0 and accelerator.is_main_process:
ckpt = output_dir / f"checkpoint-{step}"
ckpt.mkdir(parents=True, exist_ok=True)
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
# Save the consumed-batch counters so a restart fast-forwards to this position.
torch.save(
{"batches_consumed": step, "batch_size": args.batch_size},
ckpt / "dataset_state.pt",
)
if model is not None:
accelerator.unwrap_model(model).save_pretrained(ckpt)
+1 -1
View File
@@ -33,7 +33,7 @@ for MODE in single sarm; do
--mode $MODE \
--batch_size 64 \
--num_workers 12 \
--buffer_size 4000 \
--episode_pool_size 64 \
--num_batches 300 \
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
done
+1 -1
View File
@@ -83,7 +83,7 @@ for SOURCE in $SOURCES; do
$RUN benchmarks/streaming/benchmark_streaming.py \
--repo_id $REPO_ID $ROOTFLAG \
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
--batch_size $BATCH_SIZE --num_workers $W --episode_pool_size $B \
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
+1 -1
View File
@@ -42,7 +42,7 @@ accelerate launch \
--repo_id '"$REPO_ID"' \
--batch_size 64 \
--num_workers 12 \
--buffer_size 4000 \
--episode_pool_size 64 \
--steps 200000 \
--save_freq 2000 \
--log_freq 50
+4 -3
View File
@@ -39,9 +39,10 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
# of host RAM. Ignored when streaming is False.
streaming_buffer_size: int = 1000
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
# the pool holds tabular rows only. Ignored when streaming is False.
streaming_episode_pool_size: int = 64
def __post_init__(self) -> None:
if self.episodes is not None:
+1 -1
View File
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
buffer_size=cfg.dataset.streaming_buffer_size,
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
File diff suppressed because it is too large Load Diff
+25 -95
View File
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
rng = np.random.default_rng(streaming_ds.seed)
buffer_size = streaming_ds.buffer_size
num_shards = streaming_ds.num_shards
shards_indices = []
for shard_idx in range(num_shards):
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
shard_indices = [item["index"] for item in shard]
shards_indices.append(shard_indices)
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
frames_buffer = []
expected_indices = []
while shard_iterators: # While there are still available shards
available_shard_keys = list(shard_iterators.keys())
if not available_shard_keys:
break
# Call _infinite_generator_over_elements with current available shards (key difference!)
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
try:
frame_index = next(shard_iterators[shard_key])
if len(frames_buffer) == buffer_size:
i = next(buffer_indices_generator)
expected_indices.append(frames_buffer[i])
frames_buffer[i] = frame_index
else:
frames_buffer.append(frame_index)
except StopIteration:
del shard_iterators[shard_key] # Remove exhausted shard
rng.shuffle(frames_buffer)
expected_indices.extend(frames_buffer)
return expected_indices
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
"""Test if are correctly accessed"""
ds_num_frames = 400
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
[False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
ds_num_frames = 400
ds_num_episodes = 10
buffer_size = 100
seed = 42
n_epochs = 3
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
expected_indices = get_frames_expected_order(streaming_ds)
for _ in range(n_epochs):
streaming_indices = [frame["index"] for frame in streaming_ds]
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
if shuffle:
assert not frames_match
else:
assert frames_match
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
for epoch_indices in epochs:
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
if shuffle:
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
else:
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
@pytest.mark.parametrize(
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
[False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
ds_num_frames = 100
ds_num_episodes = 10
buffer_size = 10
seed = 42
n_epochs = 3
data_file_size_mb = 0.001
chunks_size = 1
local_path = tmp_path / "test"
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
chunks_size=chunks_size,
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
for _ in range(n_epochs):
streaming_indices = [
frame["index"] for frame in streaming_ds
] # NOTE: this is the same as first_epoch_indices
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
def make_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
episode_pool_size=3,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
if shuffle:
assert not frames_match
else:
assert frames_match
first = [int(frame["index"]) for frame in make_ds()]
again = [int(frame["index"]) for frame in make_ds()]
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
assert first == again, "same seed must reproduce the same order"
@pytest.mark.parametrize(
+1 -1
View File
@@ -40,7 +40,7 @@ from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
+92 -25
View File
@@ -13,7 +13,8 @@
# limitations under the License.
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
prefetching, deterministic fast-forward resume, and schema parity."""
import pytest
import torch
@@ -75,7 +76,7 @@ def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=8,
episode_pool_size=8,
max_num_shards=8,
rank=rank,
world_size=world_size,
@@ -101,7 +102,7 @@ def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
)
loader = DataLoader(ds, batch_size=None, num_workers=2)
indices = [int(batch["index"]) for batch in loader]
@@ -128,7 +129,7 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=1,
episode_pool_size=1,
max_num_shards=1,
delta_timestamps=delta_timestamps,
)
@@ -147,8 +148,8 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory):
"""Resume replays the deterministic stream and continues at the exact sample."""
repo_id = f"{DUMMY_REPO_ID}-resume"
total_frames = 100
_make_local_dataset(
@@ -157,27 +158,93 @@ def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_f
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=7,
episode_pool_size=3,
max_num_shards=1,
)
ds = fresh_ds()
it = iter(ds)
stop_after = 40
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
state = ds.state_dict()
assert set(state) == {"shards", "exhausted", "rng"}
full_epoch = _stream_indices(fresh_ds())
assert sorted(full_epoch) == list(range(total_frames))
batches_consumed, batch_size = 5, 4 # 20 samples in
resumed_ds = fresh_ds()
resumed_ds.load_state_dict(state)
resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size})
resumed = _stream_indices(resumed_ds)
# Resume continues rather than replaying: the full first pass is not re-yielded.
assert len(resumed) < total_frames
overlap = set(seen_before) & set(resumed)
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
assert resumed == full_epoch[batches_consumed * batch_size :], (
"fast-forward resume did not continue at the exact sample"
)
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
"""Shuffled stream order is reproducible for a given seed and differs between seeds."""
repo_id = f"{DUMMY_REPO_ID}-seeds"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
# Helper: construct a brand-new shuffled dataset for `seed` on every call and collect its
# stream via _stream_indices, so each comparison below starts from a fresh object.
def order(seed):
return _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=seed,
episode_pool_size=4,
max_num_shards=2,
)
)
# Two independent constructions with seed 0 must agree; seed 0 and seed 1 must not.
assert order(0) == order(0), "same seed must reproduce the same order"
assert order(0) != order(1), "different seeds should give different orders"
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
repo_id = f"{DUMMY_REPO_ID}-epochs"
total_frames = 120
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
)
# Stream the *same* dataset object twice: the second pass is the next epoch.
epoch_0 = _stream_indices(ds)
epoch_1 = _stream_indices(ds)
# Coverage: each pass, once sorted, is exactly 0..total_frames-1 (every frame once, no extras).
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
# Reshuffle: the two passes must not come out in the same order.
assert epoch_0 != epoch_1, "epoch did not reshuffle"
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
repo_id = f"{DUMMY_REPO_ID}-mix"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
)
# zip against range(20) stops after the first 20 frames, so only the head of the stream is read;
# strict=False because the dataset stream is expected to be longer than the range.
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
# At least 3 distinct episodes must appear among those first 20 frames.
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
def test_video_prefetcher_refcounted_lifecycle(tmp_path):
    """A prefetched file survives until its last reference is released, then is deleted.

    Two ``acquire`` calls on the same relative path model two pooled episodes that share one
    video file. The local copy must still exist after the first ``release`` and must be gone
    after the second.
    """
    from lerobot.datasets.streaming_dataset import _VideoPrefetcher

    # Fake "remote" source: a plain directory holding one small video payload.
    remote = tmp_path / "remote"
    (remote / "videos").mkdir(parents=True)
    payload = b"x" * 1024
    (remote / "videos" / "a.mp4").write_bytes(payload)

    prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1)
    try:
        prefetcher.acquire("videos/a.mp4")
        prefetcher.acquire("videos/a.mp4")  # second pooled episode sharing the file
        local = prefetcher.wait_local("videos/a.mp4")
        assert local is not None and local.read_bytes() == payload

        prefetcher.release("videos/a.mp4")
        assert local.exists(), "file deleted while still referenced"
        prefetcher.release("videos/a.mp4")
        assert not local.exists(), "file not deleted at refcount zero"
    finally:
        # Shut down even when an assertion above fails, so a failing run does not leave the
        # prefetcher (and whatever background resources it holds) alive for later tests.
        prefetcher.shutdown()
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
@@ -187,7 +254,7 @@ def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
)
stream_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
)
map_frame = map_ds[0]
@@ -217,7 +284,7 @@ def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypa
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
seen_paths = []
@@ -239,12 +306,12 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ordered = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
)
shuffled = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
)
)
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"