mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models. - Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`. - Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability. - Removed redundant code and improved documentation for clarity.
This commit is contained in:
@@ -382,15 +382,20 @@ def add_actor_information_and_train(
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
loss_critic = policy.forward(forward_batch, model="critic")
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
@@ -422,15 +427,20 @@ def add_actor_information_and_train(
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
loss_critic = policy.forward(forward_batch, model="critic")
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
@@ -447,10 +457,8 @@ def add_actor_information_and_train(
|
||||
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
# Use the forward method for actor loss
|
||||
loss_actor = policy.forward(forward_batch, model="actor")
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
@@ -465,11 +473,8 @@ def add_actor_information_and_train(
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
# Temperature optimization using forward method
|
||||
loss_temperature = policy.forward(forward_batch, model="temperature")
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user