mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Merge (No verify)
This commit is contained in:
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
set_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
@@ -133,18 +133,17 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
@@ -219,7 +218,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -250,7 +249,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
|
||||
Reference in New Issue
Block a user