This commit is contained in:
Pepijn
2026-02-21 18:48:46 +01:00
parent 40f4386e4a
commit ab4dce6fed
+3 -6
View File
@@ -172,16 +172,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from datetime import timedelta
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=int(os.environ.get("NCCL_TIMEOUT", 600))))
force_cpu = cfg.policy.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs, init_kwargs],
kwargs_handlers=[ddp_kwargs],
cpu=force_cpu,
)
@@ -226,7 +223,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
chunk_size = cfg.policy.chunk_size
hf = dataset.hf_dataset
total_frames = len(hf)
max_samples = min(500_000, total_frames - chunk_size)
max_samples = min(100_000, total_frames - chunk_size)
indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False)
logging.info(
f"use_delta_actions is enabled — computing delta action stats "