diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index ab34d43c7..11bdb480a 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -99,12 +99,13 @@ class TrainPipelineConfig(HubMixin): batch_size: int = 8 prefetch_factor: int = 4 persistent_workers: bool = True - # Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of materialized + # Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based # index shuffling. The data order becomes a pure function of (seed, epoch), which makes # distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size, # and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather - # than a true uniform permutation. Not compatible with dataset.streaming. - deterministic_sampler: bool = False + # than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle + # order. Ignored when dataset.streaming is enabled. + deterministic_sampler: bool = True steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 @@ -139,12 +140,6 @@ class TrainPipelineConfig(HubMixin): return self.policy # type: ignore[return-value] def validate(self) -> None: - if self.deterministic_sampler and self.dataset.streaming: - raise ValueError( - "deterministic_sampler requires a map-style dataset and is not compatible with " - "dataset.streaming=true." - ) - # HACK: We parse again the cli args here to get the pretrained paths if there was some. policy_path = parser.get_path_arg("policy") reward_model_path = parser.get_path_arg("reward_model") diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 405928d45..186431c54 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -387,7 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if cfg.deterministic_sampler: + if cfg.deterministic_sampler and not cfg.dataset.streaming: # Data order is a pure function of (seed, epoch): nothing to synchronize across ranks, # O(1) memory in dataset size, and a resumed run continues at the exact sample where the # checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).