diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6c28f7541..9253202b3 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1216,18 +1216,23 @@ class LeRobotDataset(torch.utils.data.Dataset): if has_video_keys and not use_batched_encoding: if use_async_encoding: - # Submit encoding tasks to background workers + # Submit encoding tasks to background workers (parallel encoding) + temp_paths = [] for video_key in self.meta.video_keys: img_dir = self._get_image_file_dir(episode_index, video_key) temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4" + temp_paths.append(temp_path) self.video_encoder.submit( imgs_dir=img_dir, video_path=temp_path, fps=self.fps, episode_index=episode_index, video_key=video_key, - callback_data={"root": self.root, "meta": self.meta}, ) + # Wait for encoding to complete, then move to proper location + self.video_encoder.wait_until_done() + for video_key, video_path in zip(self.meta.video_keys, temp_paths): + ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path)) else: video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index) for video_key, video_path in zip(self.meta.video_keys, video_paths): diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 8442ecfbe..77fe9e1f3 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -406,24 +406,36 @@ def encode_video_frames( def _video_encode_worker(task_queue: multiprocessing.JoinableQueue, result_queue: multiprocessing.Queue): """Worker process that encodes videos from a queue.""" + import gc + worker_pid = os.getpid() + logging.info(f"[VideoEncoder Worker {worker_pid}] Started") + while True: try: item = task_queue.get() if item is None: task_queue.task_done() + logging.info(f"[VideoEncoder Worker {worker_pid}] Shutting down") break imgs_dir, video_path, fps, episode_index, video_key, callback_data = item try: encode_video_frames(imgs_dir, video_path, fps, overwrite=True) shutil.rmtree(imgs_dir) - result_queue.put(("success", episode_index, video_key, video_path, callback_data)) + result_queue.put(("success", episode_index, video_key, str(video_path), callback_data)) + logging.info(f"[VideoEncoder Worker {worker_pid}] Done ep {episode_index} {video_key}") except Exception as e: + logging.error(f"[VideoEncoder Worker {worker_pid}] Error: {e}") result_queue.put(("error", episode_index, video_key, str(e), callback_data)) task_queue.task_done() + gc.collect() except Exception as e: - logging.error(f"Video encode worker error: {e}") + logging.error(f"[VideoEncoder Worker {worker_pid}] Fatal error: {e}") + + +# Use spawn context on Linux to avoid fork issues with threading +_mp_ctx = multiprocessing.get_context("spawn") class AsyncVideoEncoder: @@ -435,32 +447,36 @@ class AsyncVideoEncoder: Args: num_workers: Number of encoding worker processes. Defaults to half of available CPUs. - max_queue_size: Maximum number of pending encoding tasks. Defaults to 100. + max_queue_size: Maximum number of pending encoding tasks. Defaults to 10. """ - def __init__(self, num_workers: int | None = None, max_queue_size: int = 100): + def __init__(self, num_workers: int | None = None, max_queue_size: int = 10): if num_workers is None: - num_workers = max(1, os.cpu_count() // 2) + num_workers = max(2, os.cpu_count() // 2) self.num_workers = num_workers self._stopped = False self._lock = threading.Lock() - self.task_queue = multiprocessing.JoinableQueue(maxsize=max_queue_size) - self.result_queue = multiprocessing.Queue() + self.task_queue = _mp_ctx.JoinableQueue(maxsize=max_queue_size) + self.result_queue = _mp_ctx.Queue() self.pending_tasks: dict[tuple[int, str], Any] = {} self.workers = [] - for _ in range(num_workers): - p = multiprocessing.Process(target=_video_encode_worker, args=(self.task_queue, self.result_queue)) - p.daemon = True + for i in range(num_workers): + p = _mp_ctx.Process( + target=_video_encode_worker, + args=(self.task_queue, self.result_queue), + name=f"VideoEncoder-{i}" + ) p.start() self.workers.append(p) + logging.info(f"Started video encoder worker {i} with PID {p.pid}") self._result_thread = threading.Thread(target=self._process_results, daemon=True) self._result_thread.start() - logging.info(f"Started AsyncVideoEncoder with {num_workers} workers") + logging.info(f"AsyncVideoEncoder started with {num_workers} workers (PIDs: {[w.pid for w in self.workers]})") def _process_results(self): """Background thread to process completed encoding results.""" @@ -479,7 +495,7 @@ class AsyncVideoEncoder: if status == "error": logging.error(f"Video encoding failed for ep {episode_index}, {video_key}: {data}") else: - logging.debug(f"Video encoded: ep {episode_index}, {video_key}") + logging.info(f"Video encoded: ep {episode_index}, {video_key} -> {data}") except queue.Empty: continue except Exception as e: @@ -506,6 +522,7 @@ class AsyncVideoEncoder: "video_path": video_path, } + logging.info(f"Queuing video encode: ep {episode_index} {video_key}, queue size: {self.task_queue.qsize()}") self.task_queue.put((imgs_dir, video_path, fps, episode_index, video_key, callback_data)) @property @@ -515,8 +532,16 @@ class AsyncVideoEncoder: return len(self.pending_tasks) def wait_until_done(self, timeout: float | None = None): - """Wait for all pending tasks to complete.""" + """Wait for all pending tasks to complete and flush memory.""" + logging.info(f"Waiting for {self.pending_count} pending video encodes...") self.task_queue.join() + logging.info("All video encodes complete") + + def flush(self): + """Wait for current tasks and clear memory. Call between episodes.""" + self.wait_until_done() + import gc + gc.collect() def stop(self, wait: bool = True): """Stop all workers and clean up resources.""" @@ -524,16 +549,18 @@ class AsyncVideoEncoder: return self._stopped = True + logging.info("Stopping AsyncVideoEncoder...") for _ in self.workers: self.task_queue.put(None) if wait: for w in self.workers: - w.join(timeout=5.0) + w.join(timeout=10.0) for w in self.workers: if w.is_alive(): + logging.warning(f"Force terminating worker {w.pid}") w.terminate() self.task_queue.close()