mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
feat(train): enable deterministic_sampler by default
Deterministic data order (sample-exact resume, no cross-rank RNG sync, O(1) sampler memory) is now the default for map-style training; set deterministic_sampler=false to restore the legacy RNG-based shuffle. Streaming datasets now ignore the flag, because the sampler path only applies to map-style datasets; this replaces the previous hard validation error, so streaming configs keep working with the new default.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -99,12 +99,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
batch_size: int = 8
|
||||
prefetch_factor: int = 4
|
||||
persistent_workers: bool = True
|
||||
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of materialized
|
||||
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based
|
||||
# index shuffling. The data order becomes a pure function of (seed, epoch), which makes
|
||||
# distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size,
|
||||
# and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather
|
||||
# than a true uniform permutation. Not compatible with dataset.streaming.
|
||||
deterministic_sampler: bool = False
|
||||
# than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle
|
||||
# order. Ignored when dataset.streaming is enabled.
|
||||
deterministic_sampler: bool = True
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
@@ -139,12 +140,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.deterministic_sampler and self.dataset.streaming:
|
||||
raise ValueError(
|
||||
"deterministic_sampler requires a map-style dataset and is not compatible with "
|
||||
"dataset.streaming=true."
|
||||
)
|
||||
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
|
||||
@@ -387,7 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if cfg.deterministic_sampler:
|
||||
if cfg.deterministic_sampler and not cfg.dataset.streaming:
|
||||
# Data order is a pure function of (seed, epoch): nothing to synchronize across ranks,
|
||||
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
|
||||
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
|
||||
|
||||
Reference in New Issue
Block a user