mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
perf(rl): use async iterators in OnlineOfflineMixer.get_iterator
This commit is contained in:
@@ -89,6 +89,26 @@ class OnlineOfflineMixer(DataMixer):
|
|||||||
async_prefetch: bool = True,
|
async_prefetch: bool = True,
|
||||||
queue_size: int = 2,
|
queue_size: int = 2,
|
||||||
):
|
):
|
||||||
"""Yield batches from online/offline mixed sampling."""
|
"""Yield batches by composing buffer async iterators."""
|
||||||
|
|
||||||
|
n_online = max(1, int(batch_size * self.online_ratio))
|
||||||
|
|
||||||
|
online_iter = self.online_buffer.get_iterator(
|
||||||
|
batch_size=n_online,
|
||||||
|
async_prefetch=async_prefetch,
|
||||||
|
queue_size=queue_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.offline_buffer is None:
|
||||||
|
yield from online_iter
|
||||||
|
return
|
||||||
|
|
||||||
|
n_offline = batch_size - n_online
|
||||||
|
offline_iter = self.offline_buffer.get_iterator(
|
||||||
|
batch_size=n_offline,
|
||||||
|
async_prefetch=async_prefetch,
|
||||||
|
queue_size=queue_size,
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
yield self.sample(batch_size)
|
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
|
||||||
|
|||||||
Reference in New Issue
Block a user