fix pre commit

This commit is contained in:
Pepijn
2025-10-10 13:35:26 +02:00
parent d709acfc55
commit 771b03c30d
2 changed files with 4 additions and 4 deletions
+1 -1
View File
@@ -57,6 +57,7 @@ accelerate launch \
``` ```
**Key accelerate parameters:** **Key accelerate parameters:**
- `--multi_gpu`: Enable multi-GPU training - `--multi_gpu`: Enable multi-GPU training
- `--num_processes=2`: Number of GPUs to use - `--num_processes=2`: Number of GPUs to use
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported) - `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
@@ -97,4 +98,3 @@ For faster training, you can enable mixed precision (fp16 or bf16). This is conf
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility. - When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate).
+3 -3
View File
@@ -15,10 +15,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
import time import time
from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
from typing import Any from typing import Any
from collections.abc import Callable
import torch import torch
from termcolor import colored from termcolor import colored
@@ -163,7 +163,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations. cfg: A `TrainPipelineConfig` object containing all training configurations.
""" """
cfg.validate() cfg.validate()
if accelerator and not accelerator.is_main_process: if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes. # Disable logging on non-main processes.
cfg.wandb.enable = False cfg.wandb.enable = False
@@ -311,7 +311,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if not accelerator or accelerator.is_main_process: if not accelerator or accelerator.is_main_process:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
start_time = time.perf_counter() start_time = time.perf_counter()
batch = next(dl_iter) batch = next(dl_iter)