mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix pre commit
This commit is contained in:
@@ -57,6 +57,7 @@ accelerate launch \
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Key accelerate parameters:**
|
**Key accelerate parameters:**
|
||||||
|
|
||||||
- `--multi_gpu`: Enable multi-GPU training
|
- `--multi_gpu`: Enable multi-GPU training
|
||||||
- `--num_processes=2`: Number of GPUs to use
|
- `--num_processes=2`: Number of GPUs to use
|
||||||
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)
|
||||||
@@ -97,4 +98,3 @@ For faster training, you can enable mixed precision (fp16 or bf16). This is conf
|
|||||||
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
|
||||||
|
|
||||||
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate).
|
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate).
|
||||||
|
|
||||||
|
|||||||
@@ -15,10 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
@@ -163,7 +163,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||||
"""
|
"""
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|
||||||
if accelerator and not accelerator.is_main_process:
|
if accelerator and not accelerator.is_main_process:
|
||||||
# Disable logging on non-main processes.
|
# Disable logging on non-main processes.
|
||||||
cfg.wandb.enable = False
|
cfg.wandb.enable = False
|
||||||
@@ -311,7 +311,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
|||||||
|
|
||||||
if not accelerator or accelerator.is_main_process:
|
if not accelerator or accelerator.is_main_process:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
|
|
||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|||||||
Reference in New Issue
Block a user