mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
feat(dataset): expose tolerance_s argument to training config (#2653)
This commit is contained in:
@@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
tolerance_s: float = 1e-4
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
|
||||
@@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
@@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
Reference in New Issue
Block a user