mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 11:47:17 +00:00
fix(train): seed sampler generator and gate dataset download per node
- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. - Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
|
||||
Reference in New Issue
Block a user