mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
feat(train): add cudnn_deterministic option for reproducible training (#3102)
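Add a `cudnn_deterministic` flag to `TrainPipelineConfig` (default: False). When True, training sets `torch.backends.cudnn.deterministic = True` and disables cuDNN benchmark mode (`torch.backends.cudnn.benchmark = False`), which removes the run-to-run non-determinism introduced by cuDNN algorithm selection at the cost of ~10-20% training speed. Note that this covers cuDNN only: other CUDA operations may still be non-deterministic unless further measures (e.g. `torch.use_deterministic_algorithms`) are taken. When False (the default), the existing `benchmark = True` behaviour is preserved.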
This commit is contained in:
@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||||
|
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||||
|
cudnn_deterministic: bool = False
|
||||||
# Number of workers for the dataloader.
|
# Number of workers for the dataloader.
|
||||||
num_workers: int = 4
|
num_workers: int = 4
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
|||||||
@@ -209,7 +209,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
# Use accelerator's device
|
# Use accelerator's device
|
||||||
device = accelerator.device
|
device = accelerator.device
|
||||||
torch.backends.cudnn.benchmark = True
|
if cfg.cudnn_deterministic:
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
else:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||||
|
|||||||
Reference in New Issue
Block a user