Files
lerobot/examples/scaling/train_streaming_multinode.py
T
Pepijn 894fc6bfb5 refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives
The custom episode pool becomes a pure `datasets` pipeline:

  split_dataset_by_node -> batch(by_column="episode_index")
    -> shuffle(buffer=episode_pool_size)            # episode pool
    -> map(explode + exact delta windows)           # episode -> frames
    -> shuffle(buffer=frame_shuffle_buffer_size)    # frame interleave

and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.

Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.

Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade-off accepted with this rewrite: there is no video
prefetch-on-admit, so remote decode pays per-frame range reads at yield
time - use a colocated bucket (data_files_root) at large scale.

The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.

Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 21:03:09 +02:00

193 lines
8.6 KiB
Python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed, resumable streaming training on a large HF-hosted dataset.
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
"""
import argparse
import time
from pathlib import Path
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, default="lerobot/droid_1.0.1")
parser.add_argument(
"--root", type=str, default=None, help="Local/prewarmed dataset root (else stream from Hub)."
)
parser.add_argument("--output_dir", type=str, default="outputs/train/streaming_multinode")
parser.add_argument("--steps", type=int, default=1000)
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
"--episode_pool_size",
type=int,
default=64,
help="Whole episodes open per consumer (randomness knob).",
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
parser.add_argument("--save_freq", type=int, default=200)
parser.add_argument("--log_freq", type=int, default=20)
parser.add_argument("--resume_from", type=str, default=None, help="Checkpoint dir to resume from.")
parser.add_argument("--dummy", action="store_true", help="Skip the model; measure dataloading only.")
return parser.parse_args()
def make_dataloader(
    args: argparse.Namespace, meta: LeRobotDatasetMetadata
) -> tuple[DataLoader, StreamingLeRobotDataset]:
    """Build the streaming dataset and the loader that batches it.

    Args:
        args: Parsed CLI options; reads ``repo_id``, ``root``, ``n_action_steps``,
            ``episode_pool_size``, ``video_decoder_cache_size``, ``batch_size`` and
            ``num_workers``.
        meta: Dataset metadata; only ``meta.fps`` is used, to convert frame offsets
            into timestamps.

    Returns:
        ``(loader, dataset)``. The loader is torchdata's ``StatefulDataLoader`` when
        that package can be imported, otherwise ``torch.utils.data.DataLoader``.
    """
    # Ask for ``n_action_steps`` action frames per sample: offsets 0, 1/fps, 2/fps, ...
    delta_timestamps = {ACTION: [t / meta.fps for t in range(args.n_action_steps)]}
    # No rank / world_size is passed here; per the module docstring the dataset resolves
    # them from the Accelerate state.
    dataset = StreamingLeRobotDataset(
        args.repo_id,
        root=args.root,
        delta_timestamps=delta_timestamps,
        episode_pool_size=args.episode_pool_size,
        video_decoder_cache_size=args.video_decoder_cache_size,
        tolerance_s=1e-3,
    )
    # torchdata's StatefulDataLoader checkpoints each worker's dataset state through the
    # dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back
    # to the plain DataLoader (resume then requires num_workers=0).
    try:
        from torchdata.stateful_dataloader import StatefulDataLoader
    except ImportError:
        loader_cls = DataLoader
        if args.num_workers > 0:
            # Do not degrade silently: checkpoints written in this configuration
            # cannot reproduce the workers' stream positions.
            import warnings

            warnings.warn(
                "torchdata is not installed; falling back to torch.utils.data.DataLoader. "
                "With num_workers > 0 the saved stream state does not capture worker "
                "progress, so resume will not be exact. Install torchdata or use "
                "--num_workers=0.",
                stacklevel=2,
            )
    else:
        loader_cls = StatefulDataLoader
    loader = loader_cls(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        # prefetch_factor may only be set when worker processes are used.
        prefetch_factor=2 if args.num_workers > 0 else None,
    )
    return loader, dataset
def main() -> None:
    """Run the streaming training loop (or a pure dataloading benchmark with ``--dummy``).

    Builds the dataset and loader, optionally an ACT policy with an Adam optimizer,
    optionally restores the per-rank data-stream state from ``--resume_from``, then
    iterates until ``--steps`` optimizer steps have run. Throughput is logged every
    ``--log_freq`` steps and a checkpoint is written every ``--save_freq`` steps.

    Raises:
        RuntimeError: If the loader yields no batch in two consecutive passes, which
            would otherwise make the outer loop spin forever.
    """
    args = parse_args()
    accelerator = Accelerator()
    output_dir = Path(args.output_dir)
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
    meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
    loader, dataset = make_dataloader(args, meta)
    if args.dummy:
        model = optimizer = None
    else:
        from lerobot.policies.act import ACTConfig, ACTPolicy
        from lerobot.utils.feature_utils import dataset_to_policy_features

        features = dataset_to_policy_features(meta.features)
        # The action feature is the only output; every other feature is an input.
        output_features = {k: ft for k, ft in features.items() if k == ACTION}
        input_features = {k: ft for k, ft in features.items() if k not in output_features}
        cfg = ACTConfig(input_features=input_features, output_features=output_features)
        model = ACTPolicy(cfg)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        # Do NOT prepare the dataloader: the dataset is already rank-disjoint via
        # split_dataset_by_node, and accelerate's IterableDatasetShard would keep only every
        # world_size-th batch of it (silently training on 1/N of the data while decoding all
        # of it). Batches are moved to the device manually in the loop.
        model, optimizer = accelerator.prepare(model, optimizer)

    # Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
    # the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
    # NOTE(review): only the data stream is restored here. Model weights, optimizer state and
    # the step counter all start fresh, so checkpoint numbering restarts at 0 after a resume.
    can_checkpoint_loader = hasattr(loader, "state_dict")
    if args.resume_from is not None:
        state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
        # weights_only=False unpickles arbitrary objects: only resume from trusted checkpoints.
        state = torch.load(state_path, weights_only=False)  # plain dict of stream offsets  # nosec B614
        if can_checkpoint_loader:
            loader.load_state_dict(state)
        else:
            dataset.load_state_dict(state)
        accelerator.print(f"Resumed dataset stream from {state_path}")

    step = 0
    frames_seen = 0
    window_start = time.perf_counter()
    empty_passes = 0
    done = False
    while not done:
        saw_batch = False
        for batch in loader:
            saw_batch = True
            if model is not None:
                batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
                # Call the module itself (not .forward) so hooks installed on the
                # prepared/wrapped model are executed.
                loss, _ = model(batch)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            step += 1
            frames_seen += args.batch_size
            if step % args.log_freq == 0:
                elapsed = time.perf_counter() - window_start
                fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
                total_fps = fps_per_proc * accelerator.num_processes
                accelerator.print(
                    f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
                    + ("" if model is None else f" | loss {loss.item():.3f}")
                )
                window_start = time.perf_counter()
            if step % args.save_freq == 0:
                ckpt = output_dir / f"checkpoint-{step}"
                # Every rank creates the directory (exist_ok makes the race harmless) so
                # ranks on nodes without a shared filesystem can still write their state.
                ckpt.mkdir(parents=True, exist_ok=True)
                accelerator.wait_for_everyone()
                # Every rank saves its own stream state: shard positions differ per rank.
                state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
                torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
                if model is not None and accelerator.is_main_process:
                    accelerator.unwrap_model(model).save_pretrained(ckpt)
            if step >= args.steps:
                done = True
                break
        if saw_batch:
            empty_passes = 0
        else:
            # One empty pass is tolerated (a resumed stream may sit exactly at an epoch
            # boundary); two in a row means the loader will never produce a batch.
            empty_passes += 1
            if empty_passes >= 2:
                raise RuntimeError(
                    "The data loader yielded no batches in two consecutive passes; "
                    "check that the dataset is non-empty and that batch_size is not larger "
                    "than the per-rank stream (drop_last=True discards partial batches)."
                )
    accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
# Script entry point: run the training loop only when executed directly, not on import.
if __name__ == "__main__":
    main()