feat(train): enable deterministic_sampler by default

Deterministic data order (sample-exact resume, no cross-rank RNG sync,
O(1) sampler memory) is now the default for map-style training; set
deterministic_sampler=false to restore the legacy RNG-based shuffle.
Streaming datasets ignore the flag (the sampler path only applies to
map-style datasets), replacing the previous hard validation error so
streaming configs keep working with the new default.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 11:45:36 +02:00
parent 32b0d7d1ef
commit b2d5d4ccfc
2 changed files with 5 additions and 10 deletions
+4 -9
View File
@@ -99,12 +99,13 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of materialized
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based
# index shuffling. The data order becomes a pure function of (seed, epoch), which makes
# distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size,
# and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather
# than a true uniform permutation. Not compatible with dataset.streaming.
deterministic_sampler: bool = False
# than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle
# order. Ignored when dataset.streaming is enabled.
deterministic_sampler: bool = True
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
@@ -139,12 +140,6 @@ class TrainPipelineConfig(HubMixin):
return self.policy # type: ignore[return-value]
def validate(self) -> None:
if self.deterministic_sampler and self.dataset.streaming:
raise ValueError(
"deterministic_sampler requires a map-style dataset and is not compatible with "
"dataset.streaming=true."
)
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
reward_model_path = parser.get_path_arg("reward_model")
+1 -1
View File
@@ -387,7 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if cfg.deterministic_sampler:
if cfg.deterministic_sampler and not cfg.dataset.streaming:
# Data order is a pure function of (seed, epoch): nothing to synchronize across ranks,
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).