mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
1050c2fb6c
Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.
- Prefetch-on-admit: when streaming from a remote source, each pooled
episode's video files download to a local cache in the background
(refcounted, since v3 packs several episodes per file; deleted on
eviction), so decode-on-exit reads local bytes instead of paying
network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
decorrelated, runs reproducible, epochs reshuffle automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
{batches_consumed, batch_size}; each worker re-derives its own skip
from the DataLoader's round-robin batch assignment and replays
tabular-only (no decode). Exact within an epoch, works with
num_workers > 0, and the same state file serves every rank. Replaces
the per-shard HF state_dict approach, which lived in worker processes
and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
shard); runtime warnings are emitted when the shard count is not
divisible by the world size (the `datasets` library then degrades to
read-everything splitting) and when workers are left without shards.
- episode_pool_size replaces buffer_size (deprecated, ignored with a
warning); decoder cache sized to the pool working set, capped at 128.
Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
180 lines
7.9 KiB
Python
180 lines
7.9 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Distributed, resumable streaming training on a large HF-hosted dataset.
|
|
|
|
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
|
|
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
|
|
streaming features of :class:`StreamingLeRobotDataset`:
|
|
|
|
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
|
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
|
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
|
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
|
|
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
|
|
|
Launch with Accelerate (single node, N GPUs):
|
|
|
|
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
|
|
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
|
|
|
|
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
|
|
|
|
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
|
|
"""
|
|
|
|
import argparse
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from accelerate import Accelerator
|
|
from torch.utils.data import DataLoader
|
|
|
|
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
|
from lerobot.utils.constants import ACTION
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this example script.

    Returns:
        The populated ``argparse.Namespace``. Options are registered in a fixed
        order so that ``--help`` lists them the same way every time.
    """
    # (flag, add_argument keyword arguments), in the order they are shown by --help.
    option_table = (
        ("--repo_id", {"type": str, "default": "lerobot/droid_1.0.1"}),
        (
            "--root",
            {"type": str, "default": None, "help": "Local/prewarmed dataset root (else stream from Hub)."},
        ),
        ("--output_dir", {"type": str, "default": "outputs/train/streaming_multinode"}),
        ("--steps", {"type": int, "default": 1000}),
        ("--batch_size", {"type": int, "default": 64, "help": "Per-process batch size."}),
        ("--num_workers", {"type": int, "default": 8}),
        (
            "--episode_pool_size",
            {
                "type": int,
                "default": 64,
                "help": "Whole episodes open per consumer (randomness knob).",
            },
        ),
        ("--video_decoder_cache_size", {"type": int, "default": None}),
        (
            "--n_action_steps",
            {"type": int, "default": 16, "help": "Action-chunk length (delta horizon)."},
        ),
        ("--save_freq", {"type": int, "default": 200}),
        ("--log_freq", {"type": int, "default": 20}),
        ("--resume_from", {"type": str, "default": None, "help": "Checkpoint dir to resume from."}),
        ("--dummy", {"action": "store_true", "help": "Skip the model; measure dataloading only."}),
    )

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
def make_dataloader(
    args: argparse.Namespace, meta: LeRobotDatasetMetadata
) -> tuple[DataLoader, StreamingLeRobotDataset]:
    """Build the streaming dataset and the DataLoader that iterates it.

    Args:
        args: Parsed CLI options; reads ``repo_id``, ``root``, ``n_action_steps``,
            ``episode_pool_size``, ``video_decoder_cache_size``, ``batch_size``
            and ``num_workers``.
        meta: Dataset metadata; only ``meta.fps`` is used here.

    Returns:
        ``(loader, dataset)``. The dataset is returned as well so the caller can
        call ``load_state_dict`` on it when resuming.
    """
    # Action-chunk supervision: offsets 0, 1/fps, ..., (n_action_steps - 1)/fps seconds.
    action_offsets = [frame_idx / meta.fps for frame_idx in range(args.n_action_steps)]

    # rank / world_size are resolved automatically from the Accelerate state inside the dataset.
    dataset_options = {
        "root": args.root,
        "delta_timestamps": {ACTION: action_offsets},
        "episode_pool_size": args.episode_pool_size,
        "video_decoder_cache_size": args.video_decoder_cache_size,
        "tolerance_s": 1e-3,
    }
    streaming_dataset = StreamingLeRobotDataset(args.repo_id, **dataset_options)

    # prefetch_factor is only passed as a number when worker processes are used.
    worker_count = args.num_workers
    if worker_count > 0:
        prefetch = 2
    else:
        prefetch = None

    data_loader = DataLoader(
        streaming_dataset,
        batch_size=args.batch_size,
        num_workers=worker_count,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=prefetch,
    )
    return data_loader, streaming_dataset
|
|
|
|
|
|
def main() -> None:
    """Run the streaming training loop (or a dataloading-only benchmark with ``--dummy``).

    Steps performed:

    1. Parse CLI options, create the ``Accelerator`` and the output directory
       (main process only).
    2. Build the streaming dataset / DataLoader.
    3. Unless ``--dummy`` is set, build an ACT policy and an Adam optimizer and
       hand both to ``accelerator.prepare``.
    4. If ``--resume_from`` is given, load ``dataset_state.pt`` from that
       directory and pass it to ``dataset.load_state_dict``.
    5. Iterate the loader (restarting it when exhausted) until ``--steps`` batches
       have been processed in this run, logging throughput every ``--log_freq``
       steps and writing a checkpoint every ``--save_freq`` steps.

    The checkpointed ``batches_consumed`` counter and the checkpoint directory
    name are *global*: they include the batches already consumed before a
    resume. ``--steps`` and the logged ``step`` still count batches of the
    current run only.

    NOTE: ``--resume_from`` restores only the dataset stream position; this
    script does not reload model or optimizer state.
    """
    args = parse_args()
    accelerator = Accelerator()
    output_dir = Path(args.output_dir)
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)

    meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
    loader, dataset = make_dataloader(args, meta)

    if args.dummy:
        model = optimizer = None
    else:
        from lerobot.policies.act import ACTConfig, ACTPolicy
        from lerobot.utils.feature_utils import dataset_to_policy_features

        features = dataset_to_policy_features(meta.features)
        output_features = {k: ft for k, ft in features.items() if k == ACTION}
        input_features = {k: ft for k, ft in features.items() if k not in output_features}
        cfg = ACTConfig(input_features=input_features, output_features=output_features)
        model = ACTPolicy(cfg)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        # Do NOT prepare the dataloader: the dataset is already rank-disjoint via
        # split_dataset_by_node, and accelerate's IterableDatasetShard would keep only every
        # world_size-th batch of it (silently training on 1/N of the data while decoding all
        # of it). Batches are moved to the device manually in the loop.
        model, optimizer = accelerator.prepare(model, optimizer)

    # Resume: deterministic fast-forward. Every consumer's order is a pure function of
    # (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
    # worker re-derives its own skip. Same file works for every rank.
    resumed_batches = 0
    if args.resume_from is not None:
        state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
        dataset.load_state_dict(state)
        # Remember the offset so checkpoints written by this run record the true
        # stream position instead of restarting the count at zero.
        resumed_batches = int(state["batches_consumed"])
        accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")

    step = 0
    frames_seen = 0
    window_start = time.perf_counter()
    done = False
    while not done:
        for batch in loader:
            if model is not None:
                batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
                # Call the module itself (not .forward) so registered module hooks run.
                loss, _ = model(batch)
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            step += 1
            frames_seen += args.batch_size
            if step % args.log_freq == 0:
                elapsed = time.perf_counter() - window_start
                fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
                total_fps = fps_per_proc * accelerator.num_processes
                accelerator.print(
                    f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
                    + ("" if model is None else f" | loss {loss.item():.3f}")
                )
                window_start = time.perf_counter()

            if step % args.save_freq == 0 and accelerator.is_main_process:
                # Global position in the stream: batches from earlier runs plus this run.
                batches_consumed = resumed_batches + step
                ckpt = output_dir / f"checkpoint-{batches_consumed}"
                ckpt.mkdir(parents=True, exist_ok=True)
                # Save the consumed-batch counters so a restart fast-forwards to this position.
                torch.save(
                    {"batches_consumed": batches_consumed, "batch_size": args.batch_size},
                    ckpt / "dataset_state.pt",
                )
                if model is not None:
                    accelerator.unwrap_model(model).save_pretrained(ckpt)

            if step >= args.steps:
                done = True
                break

    accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
|
|
|
|
|
|
# Script entry point: run the training loop only when executed directly, not on import.
if __name__ == "__main__":
    main()
|