mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Enhance training and logging functionality with accelerator support
- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
This commit is contained in:
@@ -125,6 +125,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]"
|
|||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||||
|
accelerate = ["accelerate>=1.10.0"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import time
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
|||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
has_method,
|
has_method,
|
||||||
init_logging,
|
init_logging,
|
||||||
|
is_launched_with_accelerate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -64,6 +66,7 @@ def update_policy(
|
|||||||
lr_scheduler=None,
|
lr_scheduler=None,
|
||||||
use_amp: bool = False,
|
use_amp: bool = False,
|
||||||
lock=None,
|
lock=None,
|
||||||
|
accelerator: Callable | None = None,
|
||||||
) -> tuple[MetricsTracker, dict]:
|
) -> tuple[MetricsTracker, dict]:
|
||||||
"""
|
"""
|
||||||
Performs a single training step to update the policy's weights.
|
Performs a single training step to update the policy's weights.
|
||||||
@@ -81,6 +84,7 @@ def update_policy(
|
|||||||
lr_scheduler: An optional learning rate scheduler.
|
lr_scheduler: An optional learning rate scheduler.
|
||||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||||
lock: An optional lock for thread-safe optimizer updates.
|
lock: An optional lock for thread-safe optimizer updates.
|
||||||
|
accelerator: An optional accelerator, for multi-gpu training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing:
|
A tuple containing:
|
||||||
@@ -90,26 +94,36 @@ def update_policy(
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
device = get_device_from_parameters(policy)
|
device = get_device_from_parameters(policy)
|
||||||
policy.train()
|
policy.train()
|
||||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
grad_scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
if accelerator:
|
||||||
grad_scaler.unscale_(optimizer)
|
accelerator.backward(loss)
|
||||||
|
accelerator.unscale_gradients(optimizer=optimizer)
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
policy.parameters(),
|
||||||
|
grad_clip_norm,
|
||||||
|
error_if_nonfinite=False,
|
||||||
|
)
|
||||||
|
optimizer.step()
|
||||||
|
else:
|
||||||
|
grad_scaler.scale(loss).backward()
|
||||||
|
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||||
|
grad_scaler.unscale_(optimizer)
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(),
|
policy.parameters(),
|
||||||
grad_clip_norm,
|
grad_clip_norm,
|
||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||||
with lock if lock is not None else nullcontext():
|
with lock if lock is not None else nullcontext():
|
||||||
grad_scaler.step(optimizer)
|
grad_scaler.step(optimizer)
|
||||||
# Updates the scale for next iteration.
|
# Updates the scale for next iteration.
|
||||||
grad_scaler.update()
|
grad_scaler.update()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
@@ -117,9 +131,13 @@ def update_policy(
|
|||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
if has_method(policy, "update"):
|
if accelerator:
|
||||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||||
policy.update()
|
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||||
|
else:
|
||||||
|
if has_method(policy, "update"):
|
||||||
|
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||||
|
policy.update()
|
||||||
|
|
||||||
train_metrics.loss = loss.item()
|
train_metrics.loss = loss.item()
|
||||||
train_metrics.grad_norm = grad_norm.item()
|
train_metrics.grad_norm = grad_norm.item()
|
||||||
@@ -129,7 +147,7 @@ def update_policy(
|
|||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train(cfg: TrainPipelineConfig):
|
def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||||
"""
|
"""
|
||||||
Main function to train a policy.
|
Main function to train a policy.
|
||||||
|
|
||||||
@@ -147,6 +165,10 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg.validate()
|
cfg.validate()
|
||||||
logging.info(pformat(cfg.to_dict()))
|
logging.info(pformat(cfg.to_dict()))
|
||||||
|
|
||||||
|
if accelerator and not accelerator.is_main_process:
|
||||||
|
# Disable logging on non-main processes.
|
||||||
|
cfg.wandb.enable = False
|
||||||
|
|
||||||
if cfg.wandb.enable and cfg.wandb.project:
|
if cfg.wandb.enable and cfg.wandb.project:
|
||||||
wandb_logger = WandBLogger(cfg)
|
wandb_logger = WandBLogger(cfg)
|
||||||
else:
|
else:
|
||||||
@@ -154,10 +176,10 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||||
|
|
||||||
if cfg.seed is not None:
|
if cfg.seed is not None:
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed, accelerator=accelerator)
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True, accelerator=accelerator)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
@@ -177,6 +199,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
@@ -221,14 +244,15 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
if not accelerator or accelerator.is_main_process:
|
||||||
if cfg.env is not None:
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||||
logging.info(f"{cfg.env.task=}")
|
if cfg.env is not None:
|
||||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
logging.info(f"{cfg.env.task=}")
|
||||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||||
logging.info(f"{dataset.num_episodes=}")
|
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
logging.info(f"{dataset.num_episodes=}")
|
||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||||
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||||
@@ -253,6 +277,10 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
prefetch_factor=2,
|
prefetch_factor=2,
|
||||||
)
|
)
|
||||||
|
if accelerator:
|
||||||
|
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
policy, optimizer, dataloader, lr_scheduler
|
||||||
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -266,10 +294,16 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
train_tracker = MetricsTracker(
|
train_tracker = MetricsTracker(
|
||||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
|
cfg.batch_size,
|
||||||
|
dataset.num_frames,
|
||||||
|
dataset.num_episodes,
|
||||||
|
train_metrics,
|
||||||
|
initial_step=step,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Start offline training on a fixed dataset")
|
if not accelerator or accelerator.is_main_process:
|
||||||
|
logging.info("Start offline training on a fixed dataset")
|
||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
@@ -285,15 +319,26 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
grad_scaler=grad_scaler,
|
grad_scaler=grad_scaler,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
use_amp=cfg.policy.use_amp,
|
use_amp=cfg.policy.use_amp,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
# increment `step` here.
|
# increment `step` here.
|
||||||
step += 1
|
step += 1
|
||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
is_log_step = (
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
)
|
||||||
|
is_saving_step = (
|
||||||
|
step % cfg.save_freq == 0
|
||||||
|
or step == cfg.steps
|
||||||
|
and (not accelerator or accelerator.is_main_process)
|
||||||
|
)
|
||||||
|
is_eval_step = (
|
||||||
|
cfg.eval_freq > 0
|
||||||
|
and step % cfg.eval_freq == 0
|
||||||
|
and (not accelerator or accelerator.is_main_process)
|
||||||
|
)
|
||||||
|
|
||||||
if is_log_step:
|
if is_log_step:
|
||||||
logging.info(train_tracker)
|
logging.info(train_tracker)
|
||||||
@@ -308,22 +353,31 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
step=step,
|
||||||
|
cfg=cfg,
|
||||||
|
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=lr_scheduler,
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_logger.log_policy(checkpoint_dir)
|
wandb_logger.log_policy(checkpoint_dir)
|
||||||
|
|
||||||
|
if accelerator:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_eval_step:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with (
|
with (
|
||||||
torch.no_grad(),
|
torch.no_grad(),
|
||||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type)
|
||||||
|
if cfg.policy.use_amp and not accelerator
|
||||||
|
else nullcontext(),
|
||||||
):
|
):
|
||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=policy,
|
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
n_episodes=cfg.eval.n_episodes,
|
n_episodes=cfg.eval.n_episodes,
|
||||||
@@ -346,7 +400,12 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||||
}
|
}
|
||||||
eval_tracker = MetricsTracker(
|
eval_tracker = MetricsTracker(
|
||||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
cfg.batch_size,
|
||||||
|
dataset.num_frames,
|
||||||
|
dataset.num_episodes,
|
||||||
|
eval_metrics,
|
||||||
|
initial_step=step,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
eval_tracker.eval_s = aggregated.pop("eval_s")
|
eval_tracker.eval_s = aggregated.pop("eval_s")
|
||||||
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
||||||
@@ -358,7 +417,9 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
logging.info("End of training")
|
|
||||||
|
if not accelerator or accelerator.is_main_process:
|
||||||
|
logging.info("End of training")
|
||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
policy.push_model_to_hub(cfg)
|
policy.push_model_to_hub(cfg)
|
||||||
@@ -372,4 +433,12 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
if is_launched_with_accelerate():
|
||||||
|
import accelerate
|
||||||
|
|
||||||
|
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
||||||
|
# adjusting the lr_scheduler steps based on the num_processes
|
||||||
|
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
|
||||||
|
train(accelerator=accelerator)
|
||||||
|
else:
|
||||||
|
train()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.utils.utils import format_big_number
|
from lerobot.utils.utils import format_big_number
|
||||||
@@ -84,6 +85,7 @@ class MetricsTracker:
|
|||||||
"samples",
|
"samples",
|
||||||
"episodes",
|
"episodes",
|
||||||
"epochs",
|
"epochs",
|
||||||
|
"accelerator",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -93,6 +95,7 @@ class MetricsTracker:
|
|||||||
num_episodes: int,
|
num_episodes: int,
|
||||||
metrics: dict[str, AverageMeter],
|
metrics: dict[str, AverageMeter],
|
||||||
initial_step: int = 0,
|
initial_step: int = 0,
|
||||||
|
accelerator: Callable | None = None,
|
||||||
):
|
):
|
||||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
@@ -106,6 +109,7 @@ class MetricsTracker:
|
|||||||
self.samples = self.steps * self._batch_size
|
self.samples = self.steps * self._batch_size
|
||||||
self.episodes = self.samples / self._avg_samples_per_ep
|
self.episodes = self.samples / self._avg_samples_per_ep
|
||||||
self.epochs = self.samples / self._num_frames
|
self.epochs = self.samples / self._num_frames
|
||||||
|
self.accelerator = accelerator
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||||
if name in self.__dict__:
|
if name in self.__dict__:
|
||||||
@@ -128,7 +132,7 @@ class MetricsTracker:
|
|||||||
Updates metrics that depend on 'step' for one step.
|
Updates metrics that depend on 'step' for one step.
|
||||||
"""
|
"""
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
self.samples += self._batch_size
|
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||||
self.episodes = self.samples / self._avg_samples_per_ep
|
self.episodes = self.samples / self._avg_samples_per_ep
|
||||||
self.epochs = self.samples / self._num_frames
|
self.epochs = self.samples / self._num_frames
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import random
|
import random
|
||||||
from collections.abc import Generator
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
|||||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed) -> None:
|
def set_seed(seed, accelerator: Callable | None = None) -> None:
|
||||||
"""Set seed for reproducibility."""
|
"""Set seed for reproducibility."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
if accelerator:
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
|
||||||
|
set_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import select
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -49,13 +50,15 @@ def auto_select_torch_device() -> torch.device:
|
|||||||
|
|
||||||
|
|
||||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
def get_safe_torch_device(
|
||||||
|
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||||
|
) -> torch.device:
|
||||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||||
try_device = str(try_device)
|
try_device = str(try_device)
|
||||||
match try_device:
|
match try_device:
|
||||||
case "cuda":
|
case "cuda":
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
device = torch.device("cuda")
|
device = accelerator.device if accelerator else torch.device("cuda")
|
||||||
case "mps":
|
case "mps":
|
||||||
assert torch.backends.mps.is_available()
|
assert torch.backends.mps.is_available()
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
@@ -109,6 +112,7 @@ def init_logging(
|
|||||||
display_pid: bool = False,
|
display_pid: bool = False,
|
||||||
console_level: str = "INFO",
|
console_level: str = "INFO",
|
||||||
file_level: str = "DEBUG",
|
file_level: str = "DEBUG",
|
||||||
|
accelerator: Callable | None = None,
|
||||||
):
|
):
|
||||||
def custom_format(record: logging.LogRecord) -> str:
|
def custom_format(record: logging.LogRecord) -> str:
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
@@ -137,6 +141,10 @@ def init_logging(
|
|||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
console_handler.setLevel(console_level.upper())
|
console_handler.setLevel(console_level.upper())
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
if accelerator is not None and not accelerator.is_main_process:
|
||||||
|
# Disable duplicate logging on non-main processes
|
||||||
|
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||||
|
logging.getLogger().setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Additionally write logs to file
|
# Additionally write logs to file
|
||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
@@ -158,6 +166,10 @@ def format_big_number(num, precision=0):
|
|||||||
return num
|
return num
|
||||||
|
|
||||||
|
|
||||||
|
def is_launched_with_accelerate() -> bool:
|
||||||
|
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||||
|
|
||||||
|
|
||||||
def say(text: str, blocking: bool = False):
|
def say(text: str, blocking: bool = False):
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user