mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 15:17:05 +00:00
feat(train): env-gated multi-node dataloader/DDP knobs
- LEROBOT_DATALOADER_MP_CONTEXT: choose the dataloader worker start method (forkserver/spawn) to avoid fork() ENOMEM on multi-node EFA clusters.
- LEROBOT_DDP_STATIC_GRAPH / LEROBOT_DDP_FIND_UNUSED: opt into static_graph to restore DDP backward/comm overlap when the used-param set is stable.
- LEROBOT_DEBUG_NO_GRAD_SYNC: diagnostic-only no_sync to isolate compute vs. comms in per-step time.

All default to prior behavior when unset.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -118,6 +118,14 @@ def update_policy(
|
||||
if sample_weighter is not None:
|
||||
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
|
||||
|
||||
# Diagnostic-only: skip DDP gradient all-reduce to isolate compute vs comms
|
||||
# in the per-step time. Training is incorrect under this flag; use for probes.
|
||||
sync_ctx = (
|
||||
accelerator.no_sync(policy)
|
||||
if os.environ.get("LEROBOT_DEBUG_NO_GRAD_SYNC") == "1"
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
if sample_weights is not None:
|
||||
@@ -143,7 +151,8 @@ def update_policy(
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
accelerator.backward(loss)
|
||||
with sync_ctx:
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Clip gradients if specified
|
||||
if grad_clip_norm > 0:
|
||||
@@ -365,7 +374,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
from accelerate.utils import InitProcessGroupKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# find_unused_parameters=True is needed for conditional computation but
|
||||
# breaks DDP's gradient/backward overlap and bucket coalescing, which is
|
||||
# cheap intra-node (NVLink) but very costly across nodes (EFA). When the
|
||||
# set of used params is stable, static_graph=True keeps unused-param
|
||||
# support AND restores overlap. Env-gated; defaults preserve old behavior.
|
||||
ddp_find_unused = os.environ.get("LEROBOT_DDP_FIND_UNUSED", "1") == "1"
|
||||
ddp_static_graph = os.environ.get("LEROBOT_DDP_STATIC_GRAPH", "0") == "1"
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=ddp_find_unused and not ddp_static_graph,
|
||||
static_graph=ddp_static_graph,
|
||||
)
|
||||
# Bump the c10d store-get / barrier timeout so the rank-0-only
|
||||
# ``make_dataset`` block below doesn't trigger a barrier crash on
|
||||
# large datasets. Default is 10 min (``store->get`` 600 s); a
|
||||
@@ -671,6 +690,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# declares language columns; otherwise stay on PyTorch's default
|
||||
# collate so non-language training runs are unaffected.
|
||||
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
# On multi-node EFA clusters, forking workers from a multi-GB rank process can
|
||||
# fail with OSError(ENOMEM) because fork() reserve-charges the parent's full
|
||||
# virtual footprint. Allow opting into "forkserver"/"spawn" so workers come
|
||||
# from a clean process instead. Unset => None, i.e. the platform default start method ("fork" on Linux before Python 3.14); unchanged behavior.
|
||||
mp_context = os.environ.get("LEROBOT_DATALOADER_MP_CONTEXT") or None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
@@ -682,6 +706,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
multiprocessing_context=mp_context if cfg.num_workers > 0 else None,
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
|
||||
Reference in New Issue
Block a user