feat(train): env-gated multi-node dataloader/DDP knobs

- LEROBOT_DATALOADER_MP_CONTEXT: choose dataloader worker start method
  (forkserver/spawn) to avoid fork() ENOMEM on multi-node EFA clusters.
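- LEROBOT_DDP_STATIC_GRAPH / LEROBOT_DDP_FIND_UNUSED: opt into static_graph
  (which also forces find_unused_parameters off) to restore DDP backward/comm
  overlap when the used-param set is stable, or set LEROBOT_DDP_FIND_UNUSED=0
  to disable find_unused_parameters on its own.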
- LEROBOT_DEBUG_NO_GRAD_SYNC: diagnostic-only no_sync to isolate compute
  vs comms in per-step time.

All default to prior behavior when unset.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-29 14:14:24 +00:00
parent 57e4b638c3
commit ec5df4db7a
+27 -2
View File
@@ -118,6 +118,14 @@ def update_policy(
if sample_weighter is not None:
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
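# Diagnostic-only: skip DDP gradient all-reduce to isolate compute vs comms
# in the per-step time. Enabled only when LEROBOT_DEBUG_NO_GRAD_SYNC is
# exactly "1". Gradients are never synchronized across ranks under this flag,
# so training is incorrect; use it only for timing probes.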
sync_ctx = (
accelerator.no_sync(policy)
if os.environ.get("LEROBOT_DEBUG_NO_GRAD_SYNC") == "1"
else nullcontext()
)
# Let accelerator handle mixed precision
with accelerator.autocast():
if sample_weights is not None:
@@ -143,7 +151,8 @@ def update_policy(
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
accelerator.backward(loss)
with sync_ctx:
accelerator.backward(loss)
# Clip gradients if specified
if grad_clip_norm > 0:
@@ -365,7 +374,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
from accelerate.utils import InitProcessGroupKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
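# find_unused_parameters=True is needed for conditional computation but
# breaks DDP's gradient/backward overlap and bucket coalescing; that cost is
# small intra-node (NVLink) but very high across nodes (EFA). When the
# set of used params is stable, static_graph=True keeps unused-param
# support AND restores overlap. Env-gated; defaults preserve old behavior
# (find_unused_parameters=True, static_graph=False). Setting
# LEROBOT_DDP_STATIC_GRAPH=1 takes precedence: it forces
# find_unused_parameters=False regardless of LEROBOT_DDP_FIND_UNUSED.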
ddp_find_unused = os.environ.get("LEROBOT_DDP_FIND_UNUSED", "1") == "1"
ddp_static_graph = os.environ.get("LEROBOT_DDP_STATIC_GRAPH", "0") == "1"
ddp_kwargs = DistributedDataParallelKwargs(
find_unused_parameters=ddp_find_unused and not ddp_static_graph,
static_graph=ddp_static_graph,
)
# Bump the c10d store-get / barrier timeout so the rank-0-only
# ``make_dataset`` block below doesn't trigger a barrier crash on
# large datasets. Default is 10 min (``store->get`` 600 s); a
@@ -671,6 +690,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# declares language columns; otherwise stay on PyTorch's default
# collate so non-language training runs are unaffected.
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
# On multi-node EFA clusters, forking workers from a multi-GB rank process can
# fail with OSError(ENOMEM) because fork() reserve-charges the parent's full
# virtual footprint. Allow opting into "forkserver"/"spawn" so workers come
# from a clean process instead. Unset or empty => None, i.e. the platform's
# default start method (typically "fork" on Linux), so behavior is unchanged.
mp_context = os.environ.get("LEROBOT_DATALOADER_MP_CONTEXT") or None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
@@ -682,6 +706,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
collate_fn=collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
multiprocessing_context=mp_context if cfg.num_workers > 0 else None,
)
# Prepare everything with accelerator