diff --git a/src/lerobot/rl/data_sources/data_mixer.py b/src/lerobot/rl/data_sources/data_mixer.py index 01c9055be..bdf667c6c 100644 --- a/src/lerobot/rl/data_sources/data_mixer.py +++ b/src/lerobot/rl/data_sources/data_mixer.py @@ -89,6 +89,26 @@ class OnlineOfflineMixer(DataMixer): async_prefetch: bool = True, queue_size: int = 2, ): - """Yield batches from online/offline mixed sampling.""" + """Yield batches by composing buffer async iterators.""" + + n_online = max(1, int(batch_size * self.online_ratio)) + + online_iter = self.online_buffer.get_iterator( + batch_size=n_online, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + + if self.offline_buffer is None: + yield from online_iter + return + + n_offline = batch_size - n_online + offline_iter = self.offline_buffer.get_iterator( + batch_size=n_offline, + async_prefetch=async_prefetch, + queue_size=queue_size, + ) + while True: - yield self.sample(batch_size) + yield concatenate_batch_transitions(next(online_iter), next(offline_iter))