mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions. - Refactored the forward method to handle grasp critic model selection and loss computation more clearly. - Adjusted learner server to utilize optimized parameters for grasp critic during training. - Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
This commit is contained in:
committed by
Adil Zouitine
parent
077d18b439
commit
a54baceabb
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output:
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
|
||||
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
|
||||
}
|
||||
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output:
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
|
||||
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters(), lr=policy.critic_lr
|
||||
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
|
||||
Reference in New Issue
Block a user