mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
reduce memory load and move to video folder
This commit is contained in:
@@ -1216,18 +1216,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
if use_async_encoding:
|
||||
# Submit encoding tasks to background workers
|
||||
# Submit encoding tasks to background workers (parallel encoding)
|
||||
temp_paths = []
|
||||
for video_key in self.meta.video_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
temp_paths.append(temp_path)
|
||||
self.video_encoder.submit(
|
||||
imgs_dir=img_dir,
|
||||
video_path=temp_path,
|
||||
fps=self.fps,
|
||||
episode_index=episode_index,
|
||||
video_key=video_key,
|
||||
callback_data={"root": self.root, "meta": self.meta},
|
||||
)
|
||||
# Wait for encoding to complete, then move to proper location
|
||||
self.video_encoder.wait_until_done()
|
||||
for video_key, video_path in zip(self.meta.video_keys, temp_paths):
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
|
||||
else:
|
||||
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
|
||||
for video_key, video_path in zip(self.meta.video_keys, video_paths):
|
||||
|
||||
@@ -406,24 +406,36 @@ def encode_video_frames(
|
||||
|
||||
def _video_encode_worker(task_queue: multiprocessing.JoinableQueue, result_queue: multiprocessing.Queue):
|
||||
"""Worker process that encodes videos from a queue."""
|
||||
import gc
|
||||
worker_pid = os.getpid()
|
||||
logging.info(f"[VideoEncoder Worker {worker_pid}] Started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get()
|
||||
if item is None:
|
||||
task_queue.task_done()
|
||||
logging.info(f"[VideoEncoder Worker {worker_pid}] Shutting down")
|
||||
break
|
||||
|
||||
imgs_dir, video_path, fps, episode_index, video_key, callback_data = item
|
||||
try:
|
||||
encode_video_frames(imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(imgs_dir)
|
||||
result_queue.put(("success", episode_index, video_key, video_path, callback_data))
|
||||
result_queue.put(("success", episode_index, video_key, str(video_path), callback_data))
|
||||
logging.info(f"[VideoEncoder Worker {worker_pid}] Done ep {episode_index} {video_key}")
|
||||
except Exception as e:
|
||||
logging.error(f"[VideoEncoder Worker {worker_pid}] Error: {e}")
|
||||
result_queue.put(("error", episode_index, video_key, str(e), callback_data))
|
||||
|
||||
task_queue.task_done()
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
logging.error(f"Video encode worker error: {e}")
|
||||
logging.error(f"[VideoEncoder Worker {worker_pid}] Fatal error: {e}")
|
||||
|
||||
|
||||
# Use spawn context on Linux to avoid fork issues with threading
|
||||
_mp_ctx = multiprocessing.get_context("spawn")
|
||||
|
||||
|
||||
class AsyncVideoEncoder:
|
||||
@@ -435,32 +447,36 @@ class AsyncVideoEncoder:
|
||||
|
||||
Args:
|
||||
num_workers: Number of encoding worker processes. Defaults to half of available CPUs.
|
||||
max_queue_size: Maximum number of pending encoding tasks. Defaults to 100.
|
||||
max_queue_size: Maximum number of pending encoding tasks. Defaults to 10.
|
||||
"""
|
||||
|
||||
def __init__(self, num_workers: int | None = None, max_queue_size: int = 100):
|
||||
def __init__(self, num_workers: int | None = None, max_queue_size: int = 10):
|
||||
if num_workers is None:
|
||||
num_workers = max(1, os.cpu_count() // 2)
|
||||
num_workers = max(2, os.cpu_count() // 2)
|
||||
|
||||
self.num_workers = num_workers
|
||||
self._stopped = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self.task_queue = multiprocessing.JoinableQueue(maxsize=max_queue_size)
|
||||
self.result_queue = multiprocessing.Queue()
|
||||
self.task_queue = _mp_ctx.JoinableQueue(maxsize=max_queue_size)
|
||||
self.result_queue = _mp_ctx.Queue()
|
||||
self.pending_tasks: dict[tuple[int, str], Any] = {}
|
||||
|
||||
self.workers = []
|
||||
for _ in range(num_workers):
|
||||
p = multiprocessing.Process(target=_video_encode_worker, args=(self.task_queue, self.result_queue))
|
||||
p.daemon = True
|
||||
for i in range(num_workers):
|
||||
p = _mp_ctx.Process(
|
||||
target=_video_encode_worker,
|
||||
args=(self.task_queue, self.result_queue),
|
||||
name=f"VideoEncoder-{i}"
|
||||
)
|
||||
p.start()
|
||||
self.workers.append(p)
|
||||
logging.info(f"Started video encoder worker {i} with PID {p.pid}")
|
||||
|
||||
self._result_thread = threading.Thread(target=self._process_results, daemon=True)
|
||||
self._result_thread.start()
|
||||
|
||||
logging.info(f"Started AsyncVideoEncoder with {num_workers} workers")
|
||||
logging.info(f"AsyncVideoEncoder started with {num_workers} workers (PIDs: {[w.pid for w in self.workers]})")
|
||||
|
||||
def _process_results(self):
|
||||
"""Background thread to process completed encoding results."""
|
||||
@@ -479,7 +495,7 @@ class AsyncVideoEncoder:
|
||||
if status == "error":
|
||||
logging.error(f"Video encoding failed for ep {episode_index}, {video_key}: {data}")
|
||||
else:
|
||||
logging.debug(f"Video encoded: ep {episode_index}, {video_key}")
|
||||
logging.info(f"Video encoded: ep {episode_index}, {video_key} -> {data}")
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
@@ -506,6 +522,7 @@ class AsyncVideoEncoder:
|
||||
"video_path": video_path,
|
||||
}
|
||||
|
||||
logging.info(f"Queuing video encode: ep {episode_index} {video_key}, queue size: {self.task_queue.qsize()}")
|
||||
self.task_queue.put((imgs_dir, video_path, fps, episode_index, video_key, callback_data))
|
||||
|
||||
@property
|
||||
@@ -515,8 +532,16 @@ class AsyncVideoEncoder:
|
||||
return len(self.pending_tasks)
|
||||
|
||||
def wait_until_done(self, timeout: float | None = None):
|
||||
"""Wait for all pending tasks to complete."""
|
||||
"""Wait for all pending tasks to complete and flush memory."""
|
||||
logging.info(f"Waiting for {self.pending_count} pending video encodes...")
|
||||
self.task_queue.join()
|
||||
logging.info("All video encodes complete")
|
||||
|
||||
def flush(self):
|
||||
"""Wait for current tasks and clear memory. Call between episodes."""
|
||||
self.wait_until_done()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
def stop(self, wait: bool = True):
|
||||
"""Stop all workers and clean up resources."""
|
||||
@@ -524,16 +549,18 @@ class AsyncVideoEncoder:
|
||||
return
|
||||
|
||||
self._stopped = True
|
||||
logging.info("Stopping AsyncVideoEncoder...")
|
||||
|
||||
for _ in self.workers:
|
||||
self.task_queue.put(None)
|
||||
|
||||
if wait:
|
||||
for w in self.workers:
|
||||
w.join(timeout=5.0)
|
||||
w.join(timeout=10.0)
|
||||
|
||||
for w in self.workers:
|
||||
if w.is_alive():
|
||||
logging.warning(f"Force terminating worker {w.pid}")
|
||||
w.terminate()
|
||||
|
||||
self.task_queue.close()
|
||||
|
||||
Reference in New Issue
Block a user