mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
perf(rl): use async iterators in OnlineOfflineMixer.get_iterator
This commit is contained in:
@@ -89,6 +89,26 @@ class OnlineOfflineMixer(DataMixer):
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches from online/offline mixed sampling."""
|
||||
"""Yield batches by composing buffer async iterators."""
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
|
||||
online_iter = self.online_buffer.get_iterator(
|
||||
batch_size=n_online,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
if self.offline_buffer is None:
|
||||
yield from online_iter
|
||||
return
|
||||
|
||||
n_offline = batch_size - n_online
|
||||
offline_iter = self.offline_buffer.get_iterator(
|
||||
batch_size=n_offline,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
|
||||
|
||||
Reference in New Issue
Block a user