mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
894fc6bfb5
The custom episode pool becomes a pure `datasets` pipeline:
split_dataset_by_node -> batch(by_column="episode_index")
-> shuffle(buffer=episode_pool_size) # episode pool
-> map(explode + exact delta windows) # episode -> frames
-> shuffle(buffer=frame_shuffle_buffer_size) # frame interleave
and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.
Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.
Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade-off accepted with this rewrite: there is no video prefetch-on-admit,
so remote decode pays per-frame range reads at yield time; use a colocated
bucket (data_files_root) at large scale.
The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.
Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
193 lines
8.6 KiB
Python
193 lines
8.6 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Distributed, resumable streaming training on a large HF-hosted dataset.
|
|
|
|
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
|
|
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
|
|
streaming features of :class:`StreamingLeRobotDataset`:
|
|
|
|
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
|
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
|
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
|
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
|
|
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
|
|
|
Launch with Accelerate (single node, N GPUs):
|
|
|
|
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
|
|
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
|
|
|
|
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
|
|
|
|
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
|
|
"""
|
|
|
|
import argparse
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from accelerate import Accelerator
|
|
from torch.utils.data import DataLoader
|
|
|
|
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
|
from lerobot.utils.constants import ACTION
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Declare this example's command-line options and parse them from ``sys.argv``.

    Returns:
        The parsed namespace. Every option has a default, so the script can be launched
        without any arguments.
    """
    # (flag, add_argument keyword arguments), listed in the order they appear in ``--help``.
    option_specs = (
        ("--repo_id", {"type": str, "default": "lerobot/droid_1.0.1"}),
        (
            "--root",
            {
                "type": str,
                "default": None,
                "help": "Local/prewarmed dataset root (else stream from Hub).",
            },
        ),
        ("--output_dir", {"type": str, "default": "outputs/train/streaming_multinode"}),
        ("--steps", {"type": int, "default": 1000}),
        ("--batch_size", {"type": int, "default": 64, "help": "Per-process batch size."}),
        ("--num_workers", {"type": int, "default": 8}),
        (
            "--episode_pool_size",
            {
                "type": int,
                "default": 64,
                "help": "Whole episodes open per consumer (randomness knob).",
            },
        ),
        ("--video_decoder_cache_size", {"type": int, "default": None}),
        (
            "--n_action_steps",
            {"type": int, "default": 16, "help": "Action-chunk length (delta horizon)."},
        ),
        ("--save_freq", {"type": int, "default": 200}),
        ("--log_freq", {"type": int, "default": 20}),
        (
            "--resume_from",
            {"type": str, "default": None, "help": "Checkpoint dir to resume from."},
        ),
        (
            "--dummy",
            {"action": "store_true", "help": "Skip the model; measure dataloading only."},
        ),
    )

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
|
def make_dataloader(
    args: argparse.Namespace, meta: LeRobotDatasetMetadata
) -> tuple[DataLoader, StreamingLeRobotDataset]:
    """Create the streaming dataset and the batch loader wrapped around it.

    Args:
        args: Parsed CLI options. Reads ``repo_id``, ``root``, ``n_action_steps``,
            ``episode_pool_size``, ``video_decoder_cache_size``, ``batch_size`` and
            ``num_workers``.
        meta: Dataset metadata; only ``meta.fps`` is used, to turn action step indices into
            time offsets.

    Returns:
        ``(loader, dataset)``. The loader is torchdata's ``StatefulDataLoader`` when that
        package can be imported, otherwise a plain ``torch.utils.data.DataLoader``.
    """
    # Each sample supervises a chunk of actions at offsets 0, 1/fps, ..., (n_action_steps-1)/fps.
    action_offsets = [step_idx / meta.fps for step_idx in range(args.n_action_steps)]

    # No rank / world_size is passed here: per the module docstring the dataset resolves them
    # from the Accelerate state on its own.
    dataset = StreamingLeRobotDataset(
        args.repo_id,
        root=args.root,
        delta_timestamps={ACTION: action_offsets},
        episode_pool_size=args.episode_pool_size,
        video_decoder_cache_size=args.video_decoder_cache_size,
        tolerance_s=1e-3,
    )

    # StatefulDataLoader checkpoints every worker's dataset state through the dataset's own
    # state_dict protocol, so resume also works with num_workers > 0. Without torchdata we use
    # the plain DataLoader, in which case resume requires num_workers=0.
    try:
        from torchdata.stateful_dataloader import StatefulDataLoader as loader_cls
    except ImportError:
        loader_cls = DataLoader

    uses_workers = args.num_workers > 0
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": True,
        "drop_last": True,
        # prefetch_factor may only be set when worker processes exist.
        "prefetch_factor": 2 if uses_workers else None,
    }
    return loader_cls(dataset, **loader_kwargs), dataset
|
|
|
|
|
|
def main() -> None:
    """Run the streaming training loop (or, with ``--dummy``, a pure dataloading benchmark).

    Steps, in order:

    1. Parse the CLI, create the ``Accelerator`` and the output directory (main process only).
    2. Build the per-rank streaming loader via ``make_dataloader``.
    3. Unless ``--dummy`` is set, build an ACT policy and an Adam optimizer and hand both to
       ``accelerator.prepare``. The dataloader is deliberately not prepared.
    4. If ``--resume_from`` is given, restore this rank's data-stream state from
       ``dataset_state_rank{process_index}.pt``. Only the stream position is restored; the
       step counter restarts at 0 and no model/optimizer state is loaded here.
    5. Iterate the loader (restarting it whenever it is exhausted) until ``--steps`` steps have
       run, logging throughput every ``--log_freq`` steps and checkpointing every
       ``--save_freq`` steps. A non-positive frequency disables the corresponding action.

    Raises:
        RuntimeError: If a complete pass over the loader yields no batch at all. Retrying
            could never make progress, so the loop fails loudly instead of spinning forever.
    """
    args = parse_args()
    accelerator = Accelerator()
    output_dir = Path(args.output_dir)
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)

    meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
    loader, dataset = make_dataloader(args, meta)

    if args.dummy:
        model = optimizer = None
    else:
        # Imported lazily so that --dummy runs never load the policy code.
        from lerobot.policies.act import ACTConfig, ACTPolicy
        from lerobot.utils.feature_utils import dataset_to_policy_features

        features = dataset_to_policy_features(meta.features)
        output_features = {k: ft for k, ft in features.items() if k == ACTION}
        input_features = {k: ft for k, ft in features.items() if k not in output_features}
        # NOTE(review): the policy config keeps its default chunk settings while the dataset
        # supervises --n_action_steps actions per sample; confirm the two lengths agree.
        cfg = ACTConfig(input_features=input_features, output_features=output_features)
        model = ACTPolicy(cfg)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        # Do NOT prepare the dataloader: the dataset is already rank-disjoint via
        # split_dataset_by_node, and accelerate's IterableDatasetShard would keep only every
        # world_size-th batch of it (silently training on 1/N of the data while decoding all
        # of it). Batches are moved to the device manually in the loop.
        model, optimizer = accelerator.prepare(model, optimizer)

    # Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
    # the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
    can_checkpoint_loader = hasattr(loader, "state_dict")
    will_touch_state = args.resume_from is not None or args.save_freq > 0
    if will_touch_state and not can_checkpoint_loader and args.num_workers > 0:
        # The fallback path reads/writes the main-process dataset object only, so make the
        # documented limitation visible instead of silently producing an inexact resume.
        accelerator.print(
            "WARNING: torchdata's StatefulDataLoader is not available and num_workers > 0; "
            "the saved/restored dataset stream state is only exact with num_workers=0."
        )

    if args.resume_from is not None:
        state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
        # weights_only=False unpickles arbitrary objects: only resume from checkpoints you trust.
        state = torch.load(state_path, weights_only=False)  # plain dict of stream offsets  # nosec B614
        if can_checkpoint_loader:
            loader.load_state_dict(state)
        else:
            dataset.load_state_dict(state)
        accelerator.print(f"Resumed dataset stream from {state_path}")

    step = 0
    frames_seen = 0
    window_start = time.perf_counter()
    # ``step < args.steps`` (rather than a flag checked after the first step) means
    # ``--steps 0`` runs no step at all.
    while step < args.steps:
        batches_this_pass = 0
        for batch in loader:
            batches_this_pass += 1
            if model is not None:
                batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
                loss, _ = model.forward(batch)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            step += 1
            # Exact because the loader drops the last incomplete batch.
            frames_seen += args.batch_size
            if args.log_freq > 0 and step % args.log_freq == 0:
                elapsed = time.perf_counter() - window_start
                fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
                total_fps = fps_per_proc * accelerator.num_processes
                accelerator.print(
                    f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
                    + ("" if model is None else f" | loss {loss.item():.3f}")
                )
                window_start = time.perf_counter()

            if args.save_freq > 0 and step % args.save_freq == 0:
                ckpt = output_dir / f"checkpoint-{step}"
                if accelerator.is_main_process:
                    ckpt.mkdir(parents=True, exist_ok=True)
                # Barrier so no rank writes into the directory before it exists.
                accelerator.wait_for_everyone()
                # Every rank saves its own stream state: shard positions differ per rank.
                state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
                torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
                if model is not None and accelerator.is_main_process:
                    accelerator.unwrap_model(model).save_pretrained(ckpt)

            if step >= args.steps:
                break

        if batches_this_pass == 0:
            # With drop_last=True this happens when fewer than batch_size frames are available
            # to this rank; restarting the loader would loop forever without progress.
            raise RuntimeError(
                "The dataloader produced no batches in a full pass; this rank's stream likely "
                f"holds fewer than batch_size={args.batch_size} frames (drop_last=True)."
            )

    accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
|
|
|
|
|
|
# Script entry point: start the run only when executed directly (e.g. by ``accelerate launch``),
# never as a side effect of importing this module.
if __name__ == "__main__":
    main()
|