perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

This commit is contained in:
Khalil Meftah
2026-04-18 16:02:28 +02:00
parent 72fb0faf62
commit 2487a6ee6d
+22 -2
View File
@@ -89,6 +89,26 @@ class OnlineOfflineMixer(DataMixer):
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches from online/offline mixed sampling."""
"""Yield batches by composing buffer async iterators."""
n_online = max(1, int(batch_size * self.online_ratio))
online_iter = self.online_buffer.get_iterator(
batch_size=n_online,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
if self.offline_buffer is None:
yield from online_iter
return
n_offline = batch_size - n_online
offline_iter = self.offline_buffer.get_iterator(
batch_size=n_offline,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
while True:
yield self.sample(batch_size)
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))