Enhance SAC configuration and replay buffer with asynchronous prefetching support

- Added an async_prefetch parameter to SACConfig so asynchronous batch prefetching can be toggled from the training config (a sketch of the assumed field follows this list).
- Implemented a get_iterator method on ReplayBuffer that yields training batches and can prefetch them asynchronously (a sketch of one possible implementation precedes the file diff below).
- Updated learner_server to consume the new iterators for online and offline sampling, so batch sampling overlaps with the optimization step instead of blocking it.
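The SACConfig side of the change is not reproduced in the diff below. As a rough sketch (the field name comes from the commit message; the default value, placement, and surrounding fields are assumptions), it likely amounts to one new dataclass field:

from dataclasses import dataclass


@dataclass
class SACConfig:
    # ... existing fields (online_steps, actor_learner_config, ...) omitted ...

    # Assumed shape of the new flag; the committed default is not visible in this view.
    async_prefetch: bool = False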
Author: AdilZouitine
Date: 2025-04-03 14:23:50 +00:00
Committed by: Adil Zouitine
Commit: 74c11c4a75
Parent: 2d932b710c
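The ReplayBuffer.get_iterator implementation lives in one of the other changed files and is not shown in the file diff below; only its call sites are. A minimal sketch of an iterator matching the signature the learner uses (batch_size, async_prefetch, queue_size), assuming a bounded queue fed by a daemon producer thread, could look like this; everything beyond that signature is an assumption, not the committed code:

import queue
import threading


class ReplayBuffer:
    def sample(self, batch_size):
        ...  # existing synchronous sampling, unchanged by this commit

    def get_iterator(self, batch_size, async_prefetch=True, queue_size=2):
        """Yield training batches forever, optionally prefetched in the background."""
        if async_prefetch:
            # Bounded queue: the producer stays at most queue_size batches ahead,
            # so prefetched samples don't grow stale as the buffer keeps filling.
            prefetch_queue = queue.Queue(maxsize=queue_size)

            def producer():
                while True:
                    # put() blocks once the queue is full, throttling the producer.
                    prefetch_queue.put(self.sample(batch_size=batch_size))

            # Daemon thread: sampling overlaps the optimization step and dies
            # with the process.
            threading.Thread(target=producer, daemon=True).start()
            while True:
                yield prefetch_queue.get()
        else:
            # Synchronous fallback: sample a fresh batch on every next() call.
            while True:
                yield self.sample(batch_size=batch_size)

A small queue_size (the learner below passes 2) bounds memory and keeps prefetched batches close to the current contents of the buffer.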
3 changed files with 132 additions and 482 deletions
+26 -8
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
     policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
     saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
+    async_prefetch = cfg.policy.async_prefetch

     # Initialize logging for multiprocessing
     if not use_threads(cfg):
@@ -326,6 +327,9 @@
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id

+    # Initialize iterators
+    online_iterator = None
+    offline_iterator = None

     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested
@@ -359,16 +363,29 @@
         if len(replay_buffer) < online_step_before_learning:
             continue

+        if online_iterator is None:
+            logging.debug("[LEARNER] Initializing online replay buffer iterator")
+            online_iterator = replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)

-            if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-                batch = concatenate_batch_transitions(
-                    left_batch_transitions=batch, right_batch_transition=batch_offline
-                )
+            if dataset_repo_id is not None:
+                batch_offline = next(offline_iterator)
+                batch = concatenate_batch_transitions(
+                    left_batch_transitions=batch, right_batch_transition=batch_offline
+                )

             actions = batch["action"]
             rewards = batch["reward"]
@@ -418,10 +435,11 @@
         # Update target networks
         policy.update_target_networks()

-        batch = replay_buffer.sample(batch_size=batch_size)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)

         if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )
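Taken together (and ignoring that the diff unrolls the final UTD step separately, after the target-network update), the learner's new sampling path boils down to the pattern below. This is a condensed paraphrase of the hunks above for readability, not additional committed code:

for _ in range(utd_ratio):
    batch = next(online_iterator)  # prefetched online transitions
    if dataset_repo_id is not None:
        # Mix in offline (dataset) transitions, as before, but via the iterator
        batch = concatenate_batch_transitions(
            left_batch_transitions=batch,
            right_batch_transition=next(offline_iterator),
        )
    # ... one optimization step on batch ...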