mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
Enhance SAC configuration and replay buffer with asynchronous prefetching support
- Added async_prefetch parameter to SACConfig for improved buffer management. - Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches. - Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency.
This commit is contained in:
committed by
Adil Zouitine
parent
2d932b710c
commit
74c11c4a75
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -326,6 +327,9 @@ def add_actor_information_and_train(
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
# Exit the training loop if shutdown is requested
|
||||
@@ -359,16 +363,29 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
if online_iterator is None:
|
||||
logging.debug("[LEARNER] Initializing online replay buffer iterator")
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Starting optimization loop")
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size=batch_size)
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
@@ -418,10 +435,11 @@ def add_actor_information_and_train(
|
||||
# Update target networks
|
||||
policy.update_target_networks()
|
||||
|
||||
batch = replay_buffer.sample(batch_size=batch_size)
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user