Enhance SACPolicy and learner server for improved grasp critic integration

- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
This commit is contained in:
AdilZouitine
2025-04-02 15:50:39 +00:00
committed by Adil Zouitine
parent 077d18b439
commit a54baceabb
3 changed files with 72 additions and 50 deletions
+9 -7
View File
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
optimizers["critic"].step()
# Grasp critic optimization (if available)
if "loss_grasp_critic" in critic_output:
loss_grasp_critic = critic_output["loss_grasp_critic"]
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
}
# Grasp critic optimization (if available)
if "loss_grasp_critic" in critic_output:
loss_grasp_critic = critic_output["loss_grasp_critic"]
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters(), lr=policy.critic_lr
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None