diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index d17915c36..13a8d6525 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin): steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 + tolerance_s: float = 1e-4 save_checkpoint: bool = True # Checkpoint is saved every `save_freq` training iterations and after the last training step. save_freq: int = 20_000 diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index f3ceb2b0c..31e939809 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, + tolerance_s=cfg.tolerance_s, ) else: dataset = StreamingLeRobotDataset( @@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=cfg.dataset.revision, max_num_shards=cfg.num_workers, + tolerance_s=cfg.tolerance_s, ) else: raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")