Mirror of https://github.com/huggingface/lerobot.git
Added a caching function in the learner_server and modeling_sac to limit the number of forward passes through the pretrained encoder when it is frozen.
Added the tensordict dependencies. Updated the versions of torch and torchvision.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Committed by: AdilZouitine
Parent: d48161da1b
Commit: ff223c106d
@@ -169,6 +169,25 @@ def initialize_replay_buffer(
     )
 
 
+def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+        return None, None
+
+    with torch.no_grad():
+        observation_features = (
+            policy.actor.encoder(observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+        next_observation_features = (
+            policy.actor.encoder(next_observations)
+            if policy.actor.encoder is not None
+            else None
+        )
+
+    return observation_features, next_observation_features
+
+
 def start_learner_threads(
     cfg: DictConfig,
     device: str,
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
         if len(replay_buffer) < cfg.training.online_step_before_learning:
             continue
 
-        # logging.info(f"Size of replay buffer: {len(replay_buffer)}")
-        # logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
-
         time_for_one_optimization_step = time.time()
         for _ in range(cfg.policy.utd_ratio - 1):
             batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@ def add_actor_information_and_train(
                 batch_offline = offline_replay_buffer.sample(batch_size)
                 batch = concatenate_batch_transitions(batch, batch_offline)
 
+
             actions = batch["action"]
             rewards = batch["reward"]
             observations = batch["state"]
@@ -365,6 +382,7 @@ def add_actor_information_and_train(
                 observations=observations, actions=actions, next_state=next_observations
             )
 
+            observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
                     observations=observations,
@@ -372,6 +390,8 @@ def add_actor_information_and_train(
                     rewards=rewards,
                     next_observations=next_observations,
                     done=done,
+                    observation_features=observation_features,
+                    next_observation_features=next_observation_features,
                 )
             optimizers["critic"].zero_grad()
             loss_critic.backward()
@@ -395,6 +415,7 @@ def add_actor_information_and_train(
             observations=observations, actions=actions, next_state=next_observations
         )
 
+        observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -402,6 +423,8 @@ def add_actor_information_and_train(
                 rewards=rewards,
                 next_observations=next_observations,
                 done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
             )
         optimizers["critic"].zero_grad()
         loss_critic.backward()
@@ -413,7 +436,8 @@ def add_actor_information_and_train(
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations)
+                    loss_actor = policy.compute_loss_actor(observations=observations,
+                                                           observation_features=observation_features)
 
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -422,7 +446,8 @@ def add_actor_information_and_train(
                 training_infos["loss_actor"] = loss_actor.item()
 
                 loss_temperature = policy.compute_loss_temperature(
-                    observations=observations
+                    observations=observations,
+                    observation_features=observation_features
                 )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
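The hunks above cover only the learner_server side of the change: get_observation_features runs the frozen vision encoder once per sampled batch, and the resulting features are threaded through the critic, actor, and temperature losses. The matching modeling_sac changes are not shown in this diff, so the following is a minimal, self-contained sketch of how a loss that accepts precomputed features can skip the encoder. Encoder, critic_head, and the placeholder loss below are illustrative assumptions, not lerobot's actual API.

# Minimal sketch of the caching pattern, assuming a frozen encoder and a loss
# that takes optional precomputed features. Not lerobot's actual modeling_sac code.
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Stand-in for a frozen pretrained vision encoder (hypothetical)."""

    def __init__(self, obs_dim: int = 32, feat_dim: int = 16):
        super().__init__()
        self.net = nn.Linear(obs_dim, feat_dim)
        for p in self.parameters():
            p.requires_grad = False  # frozen: gradients never flow through it

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


def compute_loss_critic(
    encoder: Encoder,
    critic_head: nn.Module,
    observations: torch.Tensor,
    observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
    # Re-encode only when no cached features were passed in. Features computed
    # once under torch.no_grad() can be reused by the critic, actor, and
    # temperature losses instead of running redundant encoder forward passes.
    features = (
        observation_features if observation_features is not None else encoder(observations)
    )
    return critic_head(features).pow(2).mean()  # placeholder loss for the sketch


if __name__ == "__main__":
    encoder = Encoder()
    critic_head = nn.Linear(16, 1)
    obs = torch.randn(8, 32)

    # One shared forward pass, mirroring get_observation_features().
    with torch.no_grad():
        cached = encoder(obs)

    loss = compute_loss_critic(encoder, critic_head, obs, observation_features=cached)
    loss.backward()  # gradients reach critic_head only; the frozen encoder is untouched
    print(loss.item())

Because the encoder is frozen, caching changes only how often it runs, not the gradients: backward() still updates the critic head through the cached features, while the encoder forward pass happens once per batch instead of once per loss.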