mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
feat(dataset): expose tolerance_s argument to training config (#2653)
This commit is contained in:
@@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
tolerance_s: float = 1e-4
|
||||||
save_checkpoint: bool = True
|
save_checkpoint: bool = True
|
||||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
save_freq: int = 20_000
|
save_freq: int = 20_000
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
video_backend=cfg.dataset.video_backend,
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = StreamingLeRobotDataset(
|
dataset = StreamingLeRobotDataset(
|
||||||
@@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
max_num_shards=cfg.num_workers,
|
max_num_shards=cfg.num_workers,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
|
|||||||
Reference in New Issue
Block a user