mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
feat(train): add accelerate for multi gpu training (#2154)
* Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.utils import format_big_number
|
||||
@@ -84,6 +85,7 @@ class MetricsTracker:
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
"accelerator",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -93,6 +95,7 @@ class MetricsTracker:
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
@@ -106,6 +109,7 @@ class MetricsTracker:
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
self.accelerator = accelerator
|
||||
|
||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
if name in self.__dict__:
|
||||
@@ -128,7 +132,7 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
def set_seed(seed, accelerator: Callable | None = None) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if accelerator:
|
||||
from accelerate.utils import set_seed as _accelerate_set_seed
|
||||
|
||||
_accelerate_set_seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
|
||||
+33
-18
@@ -27,6 +27,7 @@ from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
|
||||
@@ -110,36 +111,50 @@ def init_logging(
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
accelerator: Accelerator | None = None,
|
||||
):
|
||||
"""Initialize logging configuration for LeRobot.
|
||||
|
||||
In multi-GPU training, only the main process logs to console to avoid duplicate output.
|
||||
Non-main processes have console logging suppressed but can still log to file.
|
||||
|
||||
Args:
|
||||
log_file: Optional file path to write logs to
|
||||
display_pid: Include process ID in log messages (useful for debugging multi-process)
|
||||
console_level: Logging level for console output
|
||||
file_level: Logging level for file output
|
||||
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
||||
"""
|
||||
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||
return message
|
||||
pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
|
||||
return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||
|
||||
formatter = logging.Formatter()
|
||||
formatter.format = custom_format
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
# Remove unused default handlers
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
# Clear any existing handlers
|
||||
logger.handlers.clear()
|
||||
|
||||
# Write logs to console
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(console_level.upper())
|
||||
logger.addHandler(console_handler)
|
||||
# Determine if this is a non-main process in distributed training
|
||||
is_main_process = accelerator.is_main_process if accelerator is not None else True
|
||||
|
||||
# Console logging (main process only)
|
||||
if is_main_process:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(console_level.upper())
|
||||
logger.addHandler(console_handler)
|
||||
else:
|
||||
# Suppress console output for non-main processes
|
||||
logger.addHandler(logging.NullHandler())
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
# Additionally write logs to file
|
||||
if log_file is not None:
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
Reference in New Issue
Block a user