Refactor SACPolicy and learner_server for improved clarity and functionality

- Updated the `forward` method in `SACPolicy` to compute the loss for the actor, critic, or temperature model, selected via its `model` argument.
- In `learner_server`, replaced direct calls to the `compute_loss_*` methods with calls to the policy's unified `forward` method.
- Consolidated the loss inputs (actions, rewards, states, next states, done flags, and observation features) into a single `forward_batch` dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
This commit is contained in:
AdilZouitine
2025-03-28 16:40:45 +00:00
parent 8b02e81bb5
commit b3ad63cf6e
3 changed files with 96 additions and 42 deletions
+32 -27
View File
@@ -382,15 +382,20 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -422,15 +427,20 @@ def add_actor_information_and_train(
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -447,10 +457,8 @@ def add_actor_information_and_train(
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
# Use the forward method for actor loss
loss_actor = policy.forward(forward_batch, model="actor")
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -465,11 +473,8 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
# Temperature optimization using forward method
loss_temperature = policy.forward(forward_batch, model="temperature")
optimizers["temperature"].zero_grad()
loss_temperature.backward()