diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 64d0884ea..5eeef8a8f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -172,16 +172,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes # We set find_unused_parameters=True to handle models with conditional computation if accelerator is None: - from datetime import timedelta - - from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs + from accelerate.utils import DistributedDataParallelKwargs ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=int(os.environ.get("NCCL_TIMEOUT", 600)))) force_cpu = cfg.policy.device == "cpu" accelerator = Accelerator( step_scheduler_with_optimizer=False, - kwargs_handlers=[ddp_kwargs, init_kwargs], + kwargs_handlers=[ddp_kwargs], cpu=force_cpu, ) @@ -226,7 +223,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): chunk_size = cfg.policy.chunk_size hf = dataset.hf_dataset total_frames = len(hf) - max_samples = min(500_000, total_frames - chunk_size) + max_samples = min(100_000, total_frames - chunk_size) indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False) logging.info( f"use_delta_actions is enabled — computing delta action stats "