Added a caching function in learner_server and modeling_sac to limit the number of forward passes through the pretrained encoder when it is frozen.

Added tensordict dependencies
Updated the versions of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
commit ff223c106d
parent d48161da1b
Author: Michel Aractingi
Date: 2025-02-21 10:13:43 +00:00
Committed by: AdilZouitine
8 changed files with 66 additions and 42 deletions
@@ -169,6 +169,25 @@ def initialize_replay_buffer(
     )
 
 
+def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+        return None, None
+
+    with torch.no_grad():
+        observation_features = (
+            policy.actor.encoder(observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+        next_observation_features = (
+            policy.actor.encoder(next_observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+
+    return observation_features, next_observation_features
+
+
 def start_learner_threads(
     cfg: DictConfig,
     device: str,
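get_observation_features short-circuits unless a pretrained vision encoder is configured and frozen; otherwise each batch is encoded once under torch.no_grad() and the resulting tensors are reused by every loss term below. A minimal, self-contained sketch of the same caching pattern with a stand-in frozen encoder (SACPolicy and its config are not reproduced here):

import torch
import torch.nn as nn

# Stand-in for a frozen pretrained vision encoder.
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
encoder.requires_grad_(False)  # frozen: no parameter needs gradients

observations = torch.randn(4, 8)
next_observations = torch.randn(4, 8)

# One forward pass per batch instead of one per loss term; the cached
# tensors can then be fed to the critic, actor, and temperature losses.
with torch.no_grad():
    observation_features = encoder(observations)
    next_observation_features = encoder(next_observations)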
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue
 
-        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
-        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
-
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@ def add_actor_information_and_train(
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)
 
+
             actions = batch["action"]
             rewards = batch["reward"]
             observations = batch["state"]
@@ -365,6 +382,7 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )
 
+            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
@@ -372,6 +390,8 @@ def add_actor_information_and_train(
                     rewards=rewards,
                     next_observations=next_observations,
                     done=done,
+                    observation_features=observation_features,
+                    next_observation_features=next_observation_features,
                 )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
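The matching changes in modeling_sac are not among the hunks shown here. A plausible shape for the consuming side, assuming each loss falls back to a fresh encoder pass only when no cached features are supplied (hypothetical helper, not the actual lerobot API):

def maybe_encode(encoder, observations, cached_features):
    # Hypothetical helper: reuse the precomputed features when the frozen
    # encoder has already run; otherwise encode as before (with gradients,
    # since an unfrozen encoder is trained through the losses).
    if cached_features is not None:
        return cached_features
    return encoder(observations) if encoder is not None else observations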
@@ -395,6 +415,7 @@ def add_actor_information_and_train(
             observations=observations, actions=actions, next_state=next_observations
         )
 
+        observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -402,6 +423,8 @@ def add_actor_information_and_train(
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
             )
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -413,7 +436,8 @@ def add_actor_information_and_train(
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations)
+                    loss_actor = policy.compute_loss_actor(observations=observations,
+                                                           observation_features=observation_features)
 
                     optimizers["actor"].zero_grad()
                     loss_actor.backward()
@@ -422,7 +446,8 @@ def add_actor_information_and_train(
                     training_infos["loss_actor"] = loss_actor.item()
 
                     loss_temperature = policy.compute_loss_temperature(
-                        observations=observations
+                        observations=observations,
+                        observation_features=observation_features
                     )
                     optimizers["temperature"].zero_grad()
                     loss_temperature.backward()
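Note that the actor and temperature losses reuse the observation_features cached for the critic update: all three losses in this block operate on the same sampled batch, so a single encoder pass per batch is enough, which is the saving the commit message describes.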