mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-15 23:39:50 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4940281120 | |||
| 3ec60da82b | |||
| 7bcd5a1502 | |||
| 674c990a39 | |||
| 38106ea6b4 | |||
| 894fc6bfb5 | |||
| 984b400e5c | |||
| 4e056081cb | |||
| a164bb97bd | |||
| 79b547de32 | |||
| a7b7f4964e | |||
| 1050c2fb6c | |||
| 66ac901632 | |||
| ce326207e6 | |||
| 2ab71231cd | |||
| 42d4788e4a | |||
| 2d1c17d971 | |||
| 7241f029c6 | |||
| 06ddc59913 | |||
| 23c58f5f9e | |||
| b0ab57cedc | |||
| afdc084677 | |||
| a32a2c647b | |||
| 343ecd7980 | |||
| f7c8a526e8 | |||
| 77af66a29c | |||
| 68fa5d80b0 | |||
| d1fc8e298c |
@@ -0,0 +1,547 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).
|
||||
|
||||
This one file is both the orchestrator and the worker:
|
||||
|
||||
* Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
|
||||
scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
|
||||
* Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
|
||||
job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
|
||||
streaming dataset splits shards per node.
|
||||
|
||||
Scenarios (all single-frame / non-SARM):
|
||||
1. ``mmap_local`` map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
|
||||
2. ``mmap_local_maxworkers`` same, but workers scaled to saturate the node's cores (decode-bound).
|
||||
3. ``stream_hub`` StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
|
||||
4. ``stream_bucket`` StreamingLeRobotDataset from a warmed storage bucket (1 node).
|
||||
5. ``stream_bucket_2node`` same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).
|
||||
|
||||
Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
|
||||
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
|
||||
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
|
||||
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
|
||||
sample latency. Results are written as JSON + CSV under ``--out_dir``.
|
||||
|
||||
Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
|
||||
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:
|
||||
ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
|
||||
python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes
|
||||
|
||||
ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
|
||||
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
|
||||
# it; "main" pins the branch directly and skips that check.
|
||||
MOLMO_REVISION = "main"
|
||||
|
||||
# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
|
||||
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
|
||||
# ``cpus`` override the CLI defaults for that leg.
|
||||
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
|
||||
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
|
||||
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
|
||||
SCENARIOS = {
|
||||
"mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
|
||||
"mmap_local_maxworkers": {
|
||||
"kind": "map",
|
||||
"nodes": 1,
|
||||
"mem": "128G",
|
||||
"time": "01:00:00",
|
||||
"num_workers": 30,
|
||||
"cpus": 92,
|
||||
},
|
||||
"stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
|
||||
}
|
||||
|
||||
|
||||
def _tree_rss_bytes() -> int:
|
||||
"""Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
|
||||
try:
|
||||
children: dict[int, list[int]] = {}
|
||||
for entry in os.listdir("/proc"):
|
||||
if not entry.isdigit():
|
||||
continue
|
||||
try:
|
||||
with open(f"/proc/{entry}/stat") as f:
|
||||
ppid = int(f.read().split(") ", 1)[1].split()[1])
|
||||
children.setdefault(ppid, []).append(int(entry))
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
total, stack = 0, [os.getpid()]
|
||||
while stack:
|
||||
cur = stack.pop()
|
||||
try:
|
||||
with open(f"/proc/{cur}/statm") as f:
|
||||
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
stack.extend(children.get(cur, []))
|
||||
return total
|
||||
except OSError:
|
||||
return 0
|
||||
|
||||
|
||||
class PeakRSSSampler:
|
||||
"""Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""
|
||||
|
||||
def __init__(self, interval_s: float = 0.5):
|
||||
self.interval_s = interval_s
|
||||
self.peak_bytes = 0
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
|
||||
self._stop.wait(self.interval_s)
|
||||
|
||||
def __enter__(self) -> "PeakRSSSampler":
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._stop.set()
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
class _TimedStreaming(StreamingLeRobotDataset):
|
||||
"""StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
|
||||
decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
|
||||
can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
|
||||
library, to keep the dataset itself instrumentation-free."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fetch_s = 0.0
|
||||
self.decode_s = 0.0
|
||||
|
||||
def __iter__(self):
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
t1 = time.perf_counter()
|
||||
sample = self._finalize_sample(row)
|
||||
t2 = time.perf_counter()
|
||||
self.fetch_s += t1 - t0
|
||||
self.decode_s += t2 - t1
|
||||
yield sample
|
||||
|
||||
|
||||
def select_node_episodes(
|
||||
meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
|
||||
) -> list[int]:
|
||||
"""This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
|
||||
shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
|
||||
episodes = list(range(meta.total_episodes))
|
||||
from_idx = meta.episodes["dataset_from_index"]
|
||||
to_idx = meta.episodes["dataset_to_index"]
|
||||
lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
|
||||
if meta.video_keys:
|
||||
file_columns = {
|
||||
key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
|
||||
for key in meta.video_keys
|
||||
}
|
||||
else:
|
||||
file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
|
||||
episode_file_ids = [
|
||||
[(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
|
||||
]
|
||||
groups = group_episodes_by_files(episode_file_ids)
|
||||
if len(groups) < num_partitions:
|
||||
groups = [[i] for i in range(len(episodes))]
|
||||
group_lengths = [sum(lengths[i] for i in g) for g in groups]
|
||||
bins = partition_episodes(group_lengths, num_partitions)
|
||||
chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
|
||||
return chosen[:cap] if cap and len(chosen) > cap else chosen
|
||||
|
||||
|
||||
def build_dataset(scenario: str, args: argparse.Namespace):
|
||||
"""Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
|
||||
if scenario.startswith("mmap_local"):
|
||||
if not args.local_root:
|
||||
raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
|
||||
meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
|
||||
episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
|
||||
dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
|
||||
return dataset, meta, True, {"loaded_episodes": len(episodes)}
|
||||
|
||||
data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
|
||||
meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
|
||||
dataset = _TimedStreaming(
|
||||
MOLMO_REPO,
|
||||
revision=MOLMO_REVISION,
|
||||
data_files_root=data_files_root,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
max_buffer_input_shards=args.max_buffer_input_shards,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
|
||||
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
|
||||
validate_row_groups=False,
|
||||
)
|
||||
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
|
||||
|
||||
|
||||
def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
|
||||
stage = fetch_s + decode_s
|
||||
return {
|
||||
"single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
|
||||
"fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
|
||||
"decode_pct": round(100 * decode_s / stage, 1) if stage else None,
|
||||
}
|
||||
|
||||
|
||||
def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
|
||||
"""Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
|
||||
it = iter(dataset)
|
||||
for _ in range(warmup): # exclude the cold shuffle-buffer fill from the ratio
|
||||
next(it)
|
||||
dataset.fetch_s = dataset.decode_s = 0.0
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(n_probe):
|
||||
next(it)
|
||||
return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)
|
||||
|
||||
|
||||
def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
|
||||
"""Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
|
||||
of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.
|
||||
|
||||
Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
|
||||
abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
|
||||
rng = random.Random(0)
|
||||
n = len(dataset)
|
||||
fetch_s = getitem_s = 0.0
|
||||
warmed = measured = skipped = attempts = 0
|
||||
while measured < n_probe and attempts < (warmup + n_probe) * 10:
|
||||
attempts += 1
|
||||
i = rng.randrange(n)
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
dataset.get_raw_item(i)
|
||||
t1 = time.perf_counter()
|
||||
dataset[i]
|
||||
t2 = time.perf_counter()
|
||||
except Exception:
|
||||
skipped += 1
|
||||
continue
|
||||
if warmed < warmup:
|
||||
warmed += 1
|
||||
continue
|
||||
fetch_s += t1 - t0
|
||||
getitem_s += t2 - t1
|
||||
measured += 1
|
||||
if skipped:
|
||||
print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
|
||||
return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)
|
||||
|
||||
|
||||
def run_scenario(scenario: str, args: argparse.Namespace) -> None:
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
device = torch.device(args.device)
|
||||
|
||||
dataset, meta, is_map_style, info = build_dataset(scenario, args)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=is_map_style, # map-style: global random shuffle; streaming: shuffled inside the dataset
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=True,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
persistent_workers=args.num_workers > 0,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
episodes_per_batch: list[int] = []
|
||||
samples = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
with PeakRSSSampler() as rss:
|
||||
for i, batch in enumerate(loader):
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
if i == args.warmup_batches:
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
samples += args.batch_size
|
||||
ep = batch.get("episode_index")
|
||||
if torch.is_tensor(ep):
|
||||
episodes_per_batch.append(int(torch.unique(ep).numel()))
|
||||
t_prev = now
|
||||
# Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
|
||||
# compared over the same duration regardless of its speed; num_batches is only a safety cap.
|
||||
if steady_start is not None and (now - steady_start) >= args.duration_s:
|
||||
break
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
|
||||
if samples == 0:
|
||||
raise SystemExit(
|
||||
f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
|
||||
"(inspect worker logs; try --num_workers 0 to surface the exception)."
|
||||
)
|
||||
|
||||
# Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
|
||||
# decodes video in the main process, which must stay decode-clean until the workers have forked
|
||||
# (decoding before fork corrupts the workers' torchcodec state).
|
||||
del loader
|
||||
if is_map_style:
|
||||
fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
|
||||
else:
|
||||
fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)
|
||||
|
||||
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
|
||||
num_cameras = len(meta.video_keys)
|
||||
results = {
|
||||
"scenario": scenario,
|
||||
"rank": rank,
|
||||
"world_size": world_size,
|
||||
"loader": "map_style" if is_map_style else "streaming",
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"episode_pool_size": None if is_map_style else args.episode_pool_size,
|
||||
"max_buffer_input_shards": None
|
||||
if is_map_style
|
||||
else (args.max_buffer_input_shards or args.episode_pool_size),
|
||||
**info,
|
||||
"num_cameras": num_cameras,
|
||||
"image_shape": image_shape,
|
||||
"fps": meta.fps,
|
||||
"peak_rss_gb": peak_rss_gb,
|
||||
"samples_measured": samples,
|
||||
"steady_window_s": round(steady_elapsed_s, 2),
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
|
||||
# Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
|
||||
# A sample is one timestep (one dataset item); it decodes num_cameras video frames.
|
||||
"samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
|
||||
if steady_elapsed_s
|
||||
else 0.0,
|
||||
**fetch_decode,
|
||||
# Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
|
||||
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
|
||||
if episodes_per_batch
|
||||
else None,
|
||||
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
|
||||
if sample_latencies_ms
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"total_time_s": round(elapsed, 2),
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=list(flat))
|
||||
writer.writeheader()
|
||||
writer.writerow(flat)
|
||||
print(json.dumps(results, indent=2), flush=True)
|
||||
print(f"Wrote {out_dir / tag}.json and .csv", flush=True)
|
||||
|
||||
|
||||
def submit_chain(args: argparse.Namespace) -> None:
|
||||
"""Submit every scenario as a serial sbatch chain (one network-bound job at a time).
|
||||
|
||||
Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
|
||||
``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
|
||||
"""
|
||||
this_file = Path(__file__).resolve()
|
||||
repo_dir = str(this_file.parents[2]) # <repo>/examples/scaling/<this file>
|
||||
logs = Path(repo_dir) / "logs"
|
||||
logs.mkdir(exist_ok=True)
|
||||
run = f"conda run --no-capture-output -n {args.conda_env} python"
|
||||
common = (
|
||||
f"--batch_size {args.batch_size} "
|
||||
f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
|
||||
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
|
||||
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
|
||||
)
|
||||
if args.max_buffer_input_shards is not None:
|
||||
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
|
||||
if args.local_root:
|
||||
common += f" --local_root {args.local_root}"
|
||||
env_prefix = "export TOKENIZERS_PARALLELISM=false"
|
||||
sched = []
|
||||
for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
|
||||
if os.environ.get(env):
|
||||
sched.append(f"{opt}={os.environ[env]}")
|
||||
|
||||
selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
|
||||
prev = ""
|
||||
for scenario in selected:
|
||||
cfg = SCENARIOS[scenario]
|
||||
nw = cfg.get("num_workers", args.num_workers)
|
||||
cpus = cfg.get("cpus", nw + 4)
|
||||
worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
|
||||
if cfg["nodes"] > 1:
|
||||
# One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
|
||||
inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
|
||||
body = f"srun --export=ALL bash -c '{inner}'"
|
||||
node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
|
||||
else:
|
||||
body = f"cd {repo_dir} && {env_prefix} && {worker}"
|
||||
node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
|
||||
cmd = [
|
||||
"sbatch",
|
||||
"--parsable",
|
||||
f"--job-name=dlbench_{scenario}",
|
||||
*node_flags,
|
||||
f"--cpus-per-task={cpus}",
|
||||
f"--mem={cfg['mem']}",
|
||||
f"--time={cfg['time']}",
|
||||
f"--output={logs}/%x-%j.out",
|
||||
*sched,
|
||||
]
|
||||
if prev:
|
||||
cmd.append(f"--dependency=afterany:{prev}")
|
||||
cmd += ["--wrap", body]
|
||||
jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
|
||||
print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
|
||||
prev = jid
|
||||
|
||||
print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument(
|
||||
"--scenario",
|
||||
choices=list(SCENARIOS),
|
||||
default=None,
|
||||
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--scenarios",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
|
||||
)
|
||||
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
|
||||
p.add_argument(
|
||||
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
|
||||
)
|
||||
p.add_argument("--partition_index", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=64)
|
||||
p.add_argument("--num_workers", type=int, default=8)
|
||||
p.add_argument("--prefetch_factor", type=int, default=2)
|
||||
p.add_argument(
|
||||
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_buffer_input_shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Concurrently-live random episodes feeding the pool after reshard() "
|
||||
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
|
||||
)
|
||||
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
|
||||
p.add_argument(
|
||||
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
|
||||
)
|
||||
p.add_argument(
|
||||
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
|
||||
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.scenario is None:
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
|
||||
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
submit_chain(args)
|
||||
else:
|
||||
run_scenario(args.scenario, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+6
-1
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -333,6 +333,11 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next
|
||||
# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
@@ -39,6 +39,10 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 1024
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -13,16 +13,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from .feature_utils import get_delta_indices
|
||||
@@ -31,207 +32,70 @@ from .utils import (
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
is_float_in_list,
|
||||
safe_shard,
|
||||
)
|
||||
from .video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
# Bound the default frame-level shuffle buffer: rows are tabular-only (~KB each), so this is
|
||||
# roughly a few hundred MB of host RAM per consumer at the cap.
|
||||
_MAX_DEFAULT_FRAME_BUFFER = 200_000
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
"""LeRobotDataset with streaming capabilities, built on native HF `datasets` primitives.
|
||||
|
||||
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||
The tabular side is a pure `datasets` pipeline::
|
||||
|
||||
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||
dataset into memory.
|
||||
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
|
||||
-> reshard() # 1 shard == 1 row group == 1 episode
|
||||
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
|
||||
-> batch(by_column="episode_index") # whole episodes (one per shard)
|
||||
-> shuffle(episode_pool_size, max_buffer_input_shards) # K random episodes, global perm
|
||||
-> map(explode + exact delta windows) # episode -> frames, windows are exact
|
||||
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
|
||||
|
||||
and this class is a thin torch ``IterableDataset`` wrapper around it that decodes video
|
||||
per emitted sample (decode-on-exit), applies image transforms, and attaches the task
|
||||
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
|
||||
and resume uses the native ``state_dict`` / ``load_state_dict``.
|
||||
|
||||
Random-episode admission (Plan B): the LeRobot writer stores one Parquet row group per
|
||||
episode, so ``datasets.IterableDataset.reshard()`` makes one shard == one episode (no new
|
||||
files; shards are (file, row_group) pairs). ``shuffle`` then permutes shard order globally and
|
||||
fills its buffer from ``max_buffer_input_shards`` shards concurrently, so the episode pool is a
|
||||
uniformly-random sample of the corpus regardless of how many episodes are packed per file.
|
||||
``max_buffer_input_shards`` is the number of concurrently-live random episodes; set it
|
||||
``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
|
||||
|
||||
Requirement: ONE ROW GROUP PER EPISODE. Recorded datasets satisfy this; bulk
|
||||
``df.to_parquet`` / ``push_to_hub`` / aggregate paths collapse row groups and are rejected at
|
||||
init (see ``validate_row_groups``). Old collapsed datasets still load fine for the map-style
|
||||
path; only this streaming random-episode path requires the invariant.
|
||||
|
||||
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
|
||||
exact slices of the resident episode with correct padding at episode boundaries.
|
||||
|
||||
Resume: ``state_dict()`` / ``load_state_dict()`` delegate to `datasets`. Samples sitting in
|
||||
the shuffle buffers at checkpoint time are skipped on resume (documented `datasets`
|
||||
behavior), so resume never repeats data but may drop up to roughly
|
||||
``episode_pool_size x episode_len + frame_shuffle_buffer_size`` frames — negligible at
|
||||
training scale. The contract is exact with ``num_workers=0``; with DataLoader workers use
|
||||
``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's
|
||||
dataset state through this same protocol.
|
||||
|
||||
Example:
|
||||
Basic usage:
|
||||
```python
|
||||
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Create a streaming dataset with delta timestamps
|
||||
delta_timestamps = {
|
||||
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||
}
|
||||
|
||||
dataset = StreamingLeRobotDataset(
|
||||
repo_id="your-dataset-repo-id",
|
||||
delta_timestamps=delta_timestamps,
|
||||
streaming=True,
|
||||
buffer_size=1000,
|
||||
delta_timestamps={"action": [0.0, 0.1, 0.2]},
|
||||
episode_pool_size=1024,
|
||||
)
|
||||
|
||||
# Iterate over the dataset
|
||||
for i, item in enumerate(dataset):
|
||||
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||
# item will contain stacked frames according to delta_timestamps
|
||||
if i >= 10:
|
||||
break
|
||||
for sample in dataset:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -246,12 +110,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
buffer_size: int = 1000,
|
||||
max_num_shards: int = 16,
|
||||
episode_pool_size: int | None = 1024,
|
||||
max_buffer_input_shards: int | None = None,
|
||||
frame_shuffle_buffer_size: int | None = None,
|
||||
buffer_size: int | None = None,
|
||||
max_num_shards: int | None = None,
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
rank: int | None = None,
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
validate_row_groups: bool = True,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -267,11 +139,40 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||
episode_pool_size (int, optional): Whole episodes each consumer keeps open to shuffle
|
||||
across — the randomness knob. Larger mixes more episodes per batch (closer to
|
||||
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
|
||||
Defaults to 1024.
|
||||
max_buffer_input_shards (int | None, optional): Number of shards (== episodes, after
|
||||
``reshard()``) the episode-pool ``shuffle`` reads from concurrently — i.e. the count
|
||||
of concurrently-live random episodes feeding the pool from a global shard permutation.
|
||||
Set ``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
|
||||
Defaults to ``episode_pool_size``.
|
||||
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
|
||||
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
|
||||
which matches the pool's mixing radius.
|
||||
buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``.
|
||||
max_num_shards (int | None, optional): Deprecated; `datasets` handles shard-to-worker
|
||||
assignment natively.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
rng (np.random.Generator | None, optional): Deprecated; ignored.
|
||||
shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order.
|
||||
rank (int | None, optional): This process' rank for distributed training. Each rank streams
|
||||
a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from
|
||||
Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
|
||||
world_size (int | None, optional): Total number of distributed processes. When omitted,
|
||||
resolved from Accelerate or ``WORLD_SIZE``, defaulting to 1. For an even per-rank split,
|
||||
``num_shards % world_size == 0`` should hold (warned otherwise).
|
||||
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
|
||||
When omitted, sized to the episode pool's working set, capped at 128.
|
||||
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
|
||||
trees (e.g. ``hf://buckets/<owner>/<name>``). When set, parquet and video bytes are read
|
||||
from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
validate_row_groups (bool, optional): When True (default), verify at init that the dataset
|
||||
stores one Parquet row group per episode (sampling data-file footers) and that
|
||||
``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear
|
||||
``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging);
|
||||
the divisibility check then downgrades to a warning.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -284,15 +185,36 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.seed = seed
|
||||
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||
if rng is not None:
|
||||
logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.")
|
||||
if buffer_size is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; "
|
||||
"use `episode_pool_size` (whole episodes, not frames)."
|
||||
)
|
||||
if max_num_shards is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `max_num_shards` is deprecated and ignored; "
|
||||
"`datasets` assigns shards to DataLoader workers natively."
|
||||
)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 1024
|
||||
self.max_buffer_input_shards = (
|
||||
max(1, max_buffer_input_shards) if max_buffer_input_shards else self.episode_pool_size
|
||||
)
|
||||
self.validate_row_groups = validate_row_groups
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
self._epoch = 0
|
||||
self._in_flight_epoch = 0
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -314,15 +236,50 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
if self.data_files_root is not None:
|
||||
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
"parquet",
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files=f"{self.data_files_root}/data/*/*.parquet",
|
||||
)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Drop any parquet columns not declared in the dataset's feature contract. Some revisions / sources
|
||||
# (e.g. an unversioned bucket holding `main`) carry extra, possibly variable-length annotation
|
||||
# columns such as `language_events`; left in, they leak into the sample and break default DataLoader
|
||||
# collation across frames of differing length. On a clean revision this is a no-op.
|
||||
known_columns = set(self.meta.features)
|
||||
extra_columns = [c for c in (self.hf_dataset.column_names or []) if c not in known_columns]
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
|
||||
# Reshard Parquet per row group so 1 shard == 1 row group == 1 episode (the LeRobot writer
|
||||
# emits one row group per episode). This lets the episode-pool shuffle admit uniformly-random
|
||||
# episodes from a global shard permutation, independent of how many episodes are packed per file.
|
||||
if self.streaming:
|
||||
self.hf_dataset = self.hf_dataset.reshard()
|
||||
self.num_shards = self.hf_dataset.num_shards
|
||||
|
||||
if self.validate_row_groups and self.streaming:
|
||||
self._validate_row_groups_per_episode()
|
||||
|
||||
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
|
||||
self.frame_shuffle_buffer_size = (
|
||||
frame_shuffle_buffer_size
|
||||
if frame_shuffle_buffer_size is not None
|
||||
else min(self.episode_pool_size * avg_episode_len, _MAX_DEFAULT_FRAME_BUFFER)
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
self._pipeline = self._build_pipeline()
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
@@ -337,96 +294,270 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
return self.meta.fps
|
||||
|
||||
@staticmethod
|
||||
def _iter_random_indices(
|
||||
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||
) -> Iterator[int]:
|
||||
while True:
|
||||
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
|
||||
"""Resolve (rank, world_size) for distributed streaming.
|
||||
|
||||
@staticmethod
|
||||
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
import os
|
||||
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
try:
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if PartialState._shared_state: # only read it if already initialized; never initialize here
|
||||
state = PartialState()
|
||||
return state.process_index, state.num_processes
|
||||
except Exception:
|
||||
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
|
||||
|
||||
env_rank = os.environ.get("RANK")
|
||||
env_world = os.environ.get("WORLD_SIZE")
|
||||
if env_rank is not None and env_world is not None:
|
||||
return int(env_rank), int(env_world)
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _resolve_data_root(self) -> str:
|
||||
"""fsspec root that holds the bulk ``data/`` parquet tree (revision-qualified for the Hub)."""
|
||||
if self.data_files_root is not None:
|
||||
return self.data_files_root
|
||||
if self.streaming and not self.streaming_from_local:
|
||||
return f"hf://datasets/{self.repo_id}@{self.revision}"
|
||||
return str(self.root)
|
||||
|
||||
def _episode_files(self) -> dict[tuple[int, int], list[int]]:
|
||||
"""Map each data file ``(chunk_index, file_index)`` to the episode indices it stores."""
|
||||
file_to_eps: dict[tuple[int, int], list[int]] = {}
|
||||
for ep in range(self.meta.total_episodes):
|
||||
row = self.meta.episodes[ep]
|
||||
key = (int(row["data/chunk_index"]), int(row["data/file_index"]))
|
||||
file_to_eps.setdefault(key, []).append(ep)
|
||||
return file_to_eps
|
||||
|
||||
def _validate_row_groups_per_episode(self, sample_files: int = 32) -> None:
|
||||
"""Verify the dataset stores ONE ROW GROUP PER EPISODE so each episode is an independently
|
||||
addressable shard after ``reshard()``. Cheap (footer-only) and sampled.
|
||||
|
||||
Raises:
|
||||
ValueError: if a sampled data file collapses several episodes into fewer row groups, or
|
||||
the whole dataset is one row group per file while holding many more episodes than files.
|
||||
"""
|
||||
import fsspec
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
file_to_eps = self._episode_files()
|
||||
num_data_files = len(file_to_eps)
|
||||
|
||||
# Whole-dataset extreme: reshard() could not split beyond file granularity (one row group per
|
||||
# file) yet there are many more episodes than files -> collapsed.
|
||||
if self.num_shards <= num_data_files and self.meta.total_episodes > self.num_shards:
|
||||
raise ValueError(
|
||||
f"{self.repo_id}: after reshard() the stream still has only {self.num_shards} shard(s) "
|
||||
f"for {self.meta.total_episodes} episodes across {num_data_files} data file(s) — i.e. one "
|
||||
"row group per file. StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
|
||||
"GROUP PER EPISODE so each episode is an independently addressable shard after reshard(). "
|
||||
"Re-emit through the LeRobot writer (one write_table per episode) or fix the aggregate / "
|
||||
"annotate / push_to_hub writer that collapsed the row groups, then re-upload. Recorded "
|
||||
"datasets already satisfy this. Pass validate_row_groups=False to bypass (random-episode "
|
||||
"quality will degrade)."
|
||||
)
|
||||
|
||||
data_root = self._resolve_data_root()
|
||||
rng = np.random.default_rng(self.seed)
|
||||
keys = list(file_to_eps)
|
||||
chosen = rng.choice(len(keys), size=min(sample_files, len(keys)), replace=False)
|
||||
for i in chosen:
|
||||
chunk_idx, file_idx = keys[int(i)]
|
||||
n_ep = len(file_to_eps[(chunk_idx, file_idx)])
|
||||
rel = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path = f"{data_root}/{rel}"
|
||||
with fsspec.open(path, "rb") as f:
|
||||
pf = pq.ParquetFile(f)
|
||||
n_rg = pf.num_row_groups
|
||||
num_rows = pf.metadata.num_rows
|
||||
if n_rg < n_ep:
|
||||
raise ValueError(
|
||||
f"{path}: stored as {n_rg} Parquet row group(s) ({num_rows} rows across "
|
||||
f"{n_ep} episodes). StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
|
||||
"GROUP PER EPISODE so each episode becomes an independently addressable shard after "
|
||||
"reshard(). This file was written by a bulk df.to_parquet / push_to_hub / aggregate "
|
||||
"path that collapses row groups. Re-emit through the LeRobot writer (one write_table "
|
||||
"per episode) or fix the aggregate/annotate writer, then re-upload. Recorded datasets "
|
||||
"already satisfy this. Pass validate_row_groups=False to bypass (quality will degrade)."
|
||||
)
|
||||
|
||||
def _build_pipeline(self) -> datasets.IterableDataset:
|
||||
"""Assemble the native tabular pipeline (everything except video decode)."""
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
if ds.num_shards % self.world_size != 0:
|
||||
msg = (
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}). "
|
||||
"After reshard() num_shards == the episode count, and split_dataset_by_node only "
|
||||
"assigns shards evenly when num_shards % world_size == 0; otherwise every rank "
|
||||
"streams (and pays for) the full dataset and keeps only 1/world_size of it. Pin "
|
||||
"world_size to a divisor of the episode count, or drop/pad episodes to a divisible "
|
||||
"count with the dataset tools. Set validate_row_groups=False to downgrade to a warning."
|
||||
)
|
||||
if self.validate_row_groups:
|
||||
raise ValueError(msg)
|
||||
logger.warning(msg)
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
ds = ds.batch(by_column="episode_index")
|
||||
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
|
||||
if self.shuffle:
|
||||
max_input_shards = max(1, min(self.max_buffer_input_shards, ds.num_shards))
|
||||
ds = ds.shuffle(
|
||||
seed=self.seed,
|
||||
buffer_size=self.episode_pool_size,
|
||||
max_buffer_input_shards=max_input_shards,
|
||||
)
|
||||
# A row-count-changing batched map must drop the input columns explicitly; the exploded
|
||||
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
|
||||
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(seed=self.seed + 1, buffer_size=max(2, self.frame_shuffle_buffer_size))
|
||||
return ds
|
||||
|
||||
def _tabular_window_keys(self) -> list[str]:
|
||||
if self.delta_indices is None:
|
||||
return []
|
||||
return [key for key in self.delta_indices if key not in self.meta.video_keys]
|
||||
|
||||
def _explode_episodes(self, episode_batch: dict[str, list[list]]) -> dict[str, list]:
|
||||
"""Episode batches -> per-frame rows, with exact tabular delta windows and pad masks.
|
||||
|
||||
Runs inside the `datasets` pipeline (plain Python values, no torch). For each windowed key
|
||||
the original per-frame value is replaced by its delta window (list of values, clamped to
|
||||
the episode bounds) plus a ``{key}_is_pad`` mask, mirroring the map-style dataset.
|
||||
"""
|
||||
window_keys = set(self._tabular_window_keys())
|
||||
out: dict[str, list] = {key: [] for key in episode_batch if key not in window_keys}
|
||||
for key in window_keys:
|
||||
out[key] = []
|
||||
out[f"{key}_is_pad"] = []
|
||||
|
||||
num_episodes = len(episode_batch["episode_index"])
|
||||
for e in range(num_episodes):
|
||||
length = len(episode_batch["episode_index"][e])
|
||||
for key, column in episode_batch.items():
|
||||
if key in window_keys:
|
||||
continue
|
||||
out[key].extend(column[e])
|
||||
for key in window_keys:
|
||||
episode_column = episode_batch[key][e]
|
||||
deltas = self.delta_indices[key]
|
||||
for t in range(length):
|
||||
window = []
|
||||
is_pad = []
|
||||
for delta in deltas:
|
||||
j = t + delta
|
||||
window.append(episode_column[min(max(j, 0), length - 1)])
|
||||
is_pad.append(not 0 <= j < length)
|
||||
out[key].append(window)
|
||||
out[f"{key}_is_pad"].append(is_pad)
|
||||
return out
|
||||
|
||||
def _make_video_decoder_cache(self) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128."""
|
||||
if self.video_decoder_cache_size is not None:
|
||||
return VideoDecoderCache(max_size=self.video_decoder_cache_size)
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if num_cameras == 0:
|
||||
return VideoDecoderCache()
|
||||
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||
|
||||
try:
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
break # random shard sampled, switch shard
|
||||
except (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
if delta_timestamps is None:
|
||||
return 1, 1
|
||||
|
||||
if not dynamic_bounds:
|
||||
# Fix the windows
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
|
||||
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
|
||||
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
|
||||
# stream position actually belongs to. Only advance when shuffling: after reshard() the
|
||||
# stream has one shard per episode, and set_epoch(n>0) re-permutes shard order even without
|
||||
# a shuffle op, so an unshuffled stream must pin epoch 0 to repeat the same order each pass.
|
||||
if self.shuffle:
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._epoch += 1
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
self._in_flight_epoch = 0
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
try:
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
yield self._finalize_sample(row)
|
||||
|
||||
return lookback, lookahead
|
||||
def _finalize_sample(self, row: dict) -> dict:
|
||||
"""Torch conversion + video decode (decode-on-exit) + transforms + task for one frame."""
|
||||
window_keys = self._tabular_window_keys()
|
||||
pad_masks = {f"{key}_is_pad": torch.BoolTensor(row.pop(f"{key}_is_pad")) for key in window_keys}
|
||||
item = item_to_torch(row)
|
||||
item.update(pad_masks)
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
if len(self.meta.video_keys) > 0:
|
||||
ep_idx = int(item["episode_index"])
|
||||
current_ts = float(item["timestamp"])
|
||||
# Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of
|
||||
# decoding against a neighbouring episode sharing the same video file.
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
0.0,
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
|
||||
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.meta.camera_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
item.update(video_frames)
|
||||
if self.delta_indices is not None:
|
||||
item.update(
|
||||
self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps)
|
||||
)
|
||||
|
||||
item["task"] = self.meta.tasks.iloc[int(item["task_index"])].name
|
||||
return item
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Set the epoch the next ``__iter__`` will use (reshuffles the native pipeline)."""
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Native `datasets` stream state. Exact contract with ``num_workers=0``; with DataLoader
|
||||
workers use ``torchdata.stateful_dataloader.StatefulDataLoader`` (it checkpoints each
|
||||
worker's copy through this protocol). Samples in the shuffle buffers are skipped on
|
||||
resume (never repeated), bounded by the pool + frame buffer sizes.
|
||||
"""
|
||||
return {"pipeline": self._pipeline.state_dict(), "epoch": self._in_flight_epoch}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
# Resume continues inside the recorded epoch: the next __iter__ replays that epoch's
|
||||
# shuffle order from the restored stream position, then advances normally.
|
||||
self._epoch = int(state_dict.get("epoch", 0))
|
||||
self._pipeline.load_state_dict(state_dict["pipeline"])
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (
|
||||
start_ts + torch.tensor(indices[key]) / self.fps
|
||||
).tolist() # NOTE: why not delta_timestamps directly?
|
||||
key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist()
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
@@ -463,65 +594,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
|
||||
# Apply delta querying logic if necessary
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
@@ -552,11 +624,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
rel_path = str(self.meta.get_video_file_path(ep_idx, video_key))
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
query_ts,
|
||||
shifted_query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
@@ -566,116 +647,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return item
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
current_item (dict): Current item from the iterator.
|
||||
ep_idx (int): Episode index.
|
||||
|
||||
Returns:
|
||||
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||
"""
|
||||
current_episode_idx = current_item["episode_index"]
|
||||
|
||||
# Prepare results
|
||||
query_result = {}
|
||||
padding = {}
|
||||
|
||||
for key, delta_indices in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
)
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
last_successful_frame = past_item[key]
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
last_successful_frame = future_item[key]
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
query_result[key] = torch.stack(target_frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
|
||||
return query_result, padding
|
||||
|
||||
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||
"""
|
||||
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||
|
||||
@@ -273,7 +273,11 @@ class VideoDecoderCache:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
|
||||
@@ -387,7 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
@@ -426,9 +426,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
|
||||
# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
|
||||
# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
|
||||
# data while decoding all of it). Batches are moved to the device manually in the loop below.
|
||||
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
|
||||
else:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -468,6 +475,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
|
||||
batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -288,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats/ints where map-style yields
|
||||
# 0-dim tensors (long-standing accepted difference). Compare by value.
|
||||
check = float(left) == float(right)
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
|
||||
import json, sys
|
||||
from accelerate import PartialState
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
|
||||
json.dump(payload, f)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
@@ -0,0 +1,430 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
|
||||
prefetching, deterministic fast-forward resume, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
|
||||
repo_id = f"{DUMMY_REPO_ID}-seeds"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
|
||||
|
||||
def order(seed):
|
||||
return _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=seed,
|
||||
episode_pool_size=4,
|
||||
max_num_shards=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert order(0) == order(0), "same seed must reproduce the same order"
|
||||
assert order(0) != order(1), "different seeds should give different orders"
|
||||
|
||||
|
||||
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
|
||||
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-epochs"
|
||||
total_frames = 120
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
epoch_0 = _stream_indices(ds)
|
||||
epoch_1 = _stream_indices(ds)
|
||||
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
|
||||
assert epoch_0 != epoch_1, "epoch did not reshuffle"
|
||||
|
||||
|
||||
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-mix"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
|
||||
)
|
||||
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
|
||||
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
|
||||
|
||||
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
|
||||
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=7,
|
||||
episode_pool_size=2,
|
||||
frame_shuffle_buffer_size=8,
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
consumed = [int(next(it)["index"]) for _ in range(30)]
|
||||
state = ds.state_dict()
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
rest = [int(frame["index"]) for frame in resumed_ds]
|
||||
|
||||
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
|
||||
# in-flight buffer contents are skipped on resume (documented datasets behavior):
|
||||
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
|
||||
covered = len(set(consumed) | set(rest))
|
||||
max_in_flight = 2 * 30 + 8
|
||||
assert covered >= total_frames - max_in_flight
|
||||
assert covered + len(consumed) >= total_frames - max_in_flight
|
||||
|
||||
|
||||
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
|
||||
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
|
||||
import datasets as hf_datasets
|
||||
|
||||
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
|
||||
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
|
||||
assert state is not None
|
||||
|
||||
|
||||
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
|
||||
|
||||
|
||||
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
|
||||
"""With one row group per episode (the writer's invariant), reshard() turns each episode into its
|
||||
own shard, so num_shards == total_episodes even when many episodes share a single data file."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-reshard"
|
||||
total_episodes = 3
|
||||
# Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way
|
||||
# num_shards can reach total_episodes is per-row-group resharding.
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=90,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
file_to_eps = ds._episode_files()
|
||||
assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file"
|
||||
for (chunk_idx, file_idx), eps in file_to_eps.items():
|
||||
rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps)
|
||||
|
||||
assert ds.num_shards == total_episodes
|
||||
|
||||
|
||||
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix:
|
||||
a single batch should already span most of the live episodes."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-frac"
|
||||
total_episodes = 8
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=240,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
episode_pool_size=total_episodes,
|
||||
max_buffer_input_shards=total_episodes,
|
||||
)
|
||||
assert ds.max_buffer_input_shards == total_episodes
|
||||
|
||||
batch = 32
|
||||
head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)}
|
||||
assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}"
|
||||
|
||||
|
||||
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
|
||||
"""A data file that collapses several episodes into a single row group (bulk df.to_parquet /
|
||||
push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
# Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse).
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
|
||||
StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
|
||||
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
|
||||
"""validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded)."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False
|
||||
)
|
||||
assert sorted(int(frame["index"]) for frame in ds) == list(range(90))
|
||||
|
||||
|
||||
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""When num_shards (== episodes after reshard) is not divisible by world_size, every rank would
|
||||
stream the whole dataset; the guard must raise instead of silently degrading."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-divis"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
with pytest.raises(ValueError, match="not divisible by world_size"):
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2
|
||||
)
|
||||
|
||||
# Bypassing the guard downgrades it to a warning (no raise).
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=3,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
validate_row_groups=False,
|
||||
)
|
||||
assert ds.num_shards == 3
|
||||
Vendored
+22
-3
@@ -17,6 +17,7 @@ from pathlib import Path
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -35,6 +36,24 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
|
||||
"""Write ``hf_dataset`` to ``path`` with one Parquet row group per episode.
|
||||
|
||||
Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an
|
||||
independently addressable shard after ``datasets.IterableDataset.reshard()``, which
|
||||
``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a
|
||||
single row group instead.
|
||||
"""
|
||||
table = hf_dataset.with_format("arrow")[:]
|
||||
episode_index = np.asarray(hf_dataset["episode_index"])
|
||||
boundaries = np.where(np.diff(episode_index) != 0)[0] + 1
|
||||
starts = [0, *boundaries.tolist()]
|
||||
ends = [*boundaries.tolist(), len(episode_index)]
|
||||
with pq.ParquetWriter(str(path), table.schema) as writer:
|
||||
for start, end in zip(starts, ends, strict=True):
|
||||
writer.write_table(table.slice(start, end - start))
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
hf_dataset: Dataset,
|
||||
local_dir: Path,
|
||||
@@ -67,7 +86,7 @@ def write_hf_dataset(
|
||||
# If the dataset is small enough, write it to a single file.
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset.to_parquet(path)
|
||||
_to_parquet_one_row_group_per_episode(hf_dataset, path)
|
||||
return
|
||||
|
||||
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
|
||||
@@ -114,8 +133,8 @@ def write_hf_dataset(
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the shard to a Parquet file.
|
||||
dataset_shard.to_parquet(path)
|
||||
# Write the shard to a Parquet file (one row group per episode).
|
||||
_to_parquet_one_row_group_per_episode(dataset_shard, path)
|
||||
|
||||
# Update chunk and file indices for the next iteration.
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
|
||||
@@ -1084,8 +1084,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1102,10 +1102,6 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -3078,7 +3074,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user