From 771b03c30df3b1ef62de3faeb95be63f52614e63 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 10 Oct 2025 13:35:26 +0200 Subject: [PATCH] fix pre-commit --- docs/source/multi_gpu_training.mdx | 2 +- src/lerobot/scripts/lerobot_train.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/multi_gpu_training.mdx b/docs/source/multi_gpu_training.mdx index a6cd540bc..1e26e8806 100644 --- a/docs/source/multi_gpu_training.mdx +++ b/docs/source/multi_gpu_training.mdx @@ -57,6 +57,7 @@ accelerate launch \ ``` **Key accelerate parameters:** + - `--multi_gpu`: Enable multi-GPU training - `--num_processes=2`: Number of GPUs to use - `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported) @@ -97,4 +98,3 @@ For faster training, you can enable mixed precision (fp16 or bf16). This is conf - When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility. For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). - diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 65e86cb46..1d1f0adc8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -15,10 +15,10 @@ # limitations under the License. import logging import time +from collections.abc import Callable from contextlib import nullcontext from pprint import pformat from typing import Any -from collections.abc import Callable import torch from termcolor import colored @@ -163,7 +163,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): cfg: A `TrainPipelineConfig` object containing all training configurations. """ cfg.validate() - + if accelerator and not accelerator.is_main_process: # Disable logging on non-main processes. 
cfg.wandb.enable = False @@ -311,7 +311,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): if not accelerator or accelerator.is_main_process: logging.info("Start offline training on a fixed dataset") - + for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter)