mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
Fix linter issue
This commit is contained in:
@@ -53,84 +53,6 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts,
|
||||
lr=cfg.training.lr,
|
||||
weight_decay=cfg.training.weight_decay,
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{
|
||||
"params": policy.critic_ensemble.parameters(),
|
||||
"lr": policy.config.critic_lr,
|
||||
},
|
||||
{
|
||||
"params": policy.temperature.parameters(),
|
||||
"lr": policy.config.temperature_lr,
|
||||
},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import (
|
||||
VQBeTOptimizer,
|
||||
VQBeTScheduler,
|
||||
)
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
elif cfg.policy.name == "hilserl_classifier":
|
||||
optimizer = torch.optim.AdamW(policy.parameters(), cfg.policy.learning_rate)
|
||||
lr_scheduler = None
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
|
||||
Reference in New Issue
Block a user