Async encoding

This commit is contained in:
Pepijn
2026-02-03 08:50:34 +01:00
parent 2598dbc31a
commit c028ae3a44
3 changed files with 220 additions and 7 deletions
@@ -102,6 +102,9 @@ class RaCRTCDatasetConfig:
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
# Async video encoding: encode videos in background without blocking recording
# Set to number of worker processes (0 = half of CPUs, None = disabled)
async_video_encoder_workers: int | None = 0
@dataclass
@@ -722,6 +725,9 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
)
if cfg.dataset.async_video_encoder_workers is not None:
workers = None if cfg.dataset.async_video_encoder_workers == 0 else cfg.dataset.async_video_encoder_workers
dataset.start_video_encoder(num_workers=workers)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
@@ -734,6 +740,7 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
video_encoder_workers=cfg.dataset.async_video_encoder_workers,
)
# Load policy
@@ -846,7 +853,8 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
dataset.clear_episode_buffer()
continue
dataset.save_episode()
use_async = cfg.dataset.async_video_encoder_workers is not None
dataset.save_episode(async_video_encoding=use_async)
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
+67 -3
View File
@@ -69,6 +69,7 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import (
AsyncVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
@@ -693,6 +694,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Unused attributes
self.image_writer = None
self.video_encoder = None
self.episode_buffer = None
self.writer = None
self.latest_episode = None
@@ -1067,9 +1069,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def finalize(self):
"""
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files.
Close the parquet writers and stop async workers. This function needs to be called after data
collection/conversion, else footer metadata won't be written to the parquet files.
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
"""
self._wait_video_encoder()
self.stop_video_encoder()
self._close_writer()
self.meta._close_writer()
@@ -1151,6 +1156,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self,
episode_data: dict | None = None,
parallel_encoding: bool = True,
async_video_encoding: bool = False,
) -> None:
"""
This will save to disk the current episode in self.episode_buffer.
@@ -1158,6 +1164,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Video encoding is handled automatically based on batch_encoding_size:
- If batch_encoding_size == 1: Videos are encoded immediately after each episode
- If batch_encoding_size > 1: Videos are encoded in batches.
- If async_video_encoding=True: Encoding runs in background (requires start_video_encoder())
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
@@ -1165,6 +1172,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
None.
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
async_video_encoding (bool, optional): If True and video_encoder is started, encode videos
asynchronously in background processes without blocking. Defaults to False.
"""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
@@ -1199,8 +1208,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1
use_async_encoding = (
async_video_encoding
and hasattr(self, "video_encoder")
and self.video_encoder is not None
)
if has_video_keys and not use_batched_encoding:
if use_async_encoding:
# Submit encoding tasks to background workers
for video_key in self.meta.video_keys:
img_dir = self._get_image_file_dir(episode_index, video_key)
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
self.video_encoder.submit(
imgs_dir=img_dir,
video_path=temp_path,
fps=self.fps,
episode_index=episode_index,
video_key=video_key,
callback_data={"root": self.root, "meta": self.meta},
)
else:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
for video_key, video_path in zip(self.meta.video_keys, video_paths):
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
@@ -1219,7 +1247,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not episode_data:
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
delete_images = len(self.meta.image_keys) > 0 and not use_async_encoding
self.clear_episode_buffer(delete_images=delete_images)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""
@@ -1491,6 +1520,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
def start_video_encoder(self, num_workers: int | None = None) -> None:
"""
Start async video encoder for on-the-fly video encoding during recording.
Args:
num_workers: Number of encoding worker processes. Defaults to half of available CPUs.
"""
if hasattr(self, "video_encoder") and self.video_encoder is not None:
logging.warning("Replacing existing AsyncVideoEncoder")
self.video_encoder.stop()
self.video_encoder = AsyncVideoEncoder(num_workers=num_workers)
def stop_video_encoder(self, wait: bool = True) -> None:
"""Stop async video encoder and wait for pending tasks."""
if hasattr(self, "video_encoder") and self.video_encoder is not None:
self.video_encoder.stop(wait=wait)
self.video_encoder = None
def _wait_video_encoder(self) -> None:
"""Wait for async video encoder to finish pending tasks."""
if hasattr(self, "video_encoder") and self.video_encoder is not None:
self.video_encoder.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
@@ -1529,8 +1582,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
video_encoder_workers: int | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
"""Create a LeRobot Dataset from scratch in order to record data.
Args:
video_encoder_workers: If set, starts async video encoder with this many workers.
Set to 0 for half of available CPUs. For on-the-fly encoding during recording.
"""
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -1545,12 +1604,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
obj.video_encoder = None
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
if video_encoder_workers is not None:
workers = None if video_encoder_workers == 0 else video_encoder_workers
obj.start_video_encoder(num_workers=workers)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
+141
View File
@@ -16,8 +16,12 @@
import glob
import importlib
import logging
import multiprocessing
import os
import queue
import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from pathlib import Path
@@ -400,6 +404,143 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def _video_encode_worker(task_queue: multiprocessing.JoinableQueue, result_queue: multiprocessing.Queue):
"""Worker process that encodes videos from a queue."""
while True:
try:
item = task_queue.get()
if item is None:
task_queue.task_done()
break
imgs_dir, video_path, fps, episode_index, video_key, callback_data = item
try:
encode_video_frames(imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(imgs_dir)
result_queue.put(("success", episode_index, video_key, video_path, callback_data))
except Exception as e:
result_queue.put(("error", episode_index, video_key, str(e), callback_data))
task_queue.task_done()
except Exception as e:
logging.error(f"Video encode worker error: {e}")
class AsyncVideoEncoder:
"""
Async video encoder that processes video encoding in background processes.
This enables on-the-fly video encoding during data collection without blocking
the main recording loop. Uses a configurable number of worker processes.
Args:
num_workers: Number of encoding worker processes. Defaults to half of available CPUs.
max_queue_size: Maximum number of pending encoding tasks. Defaults to 100.
"""
def __init__(self, num_workers: int | None = None, max_queue_size: int = 100):
if num_workers is None:
num_workers = max(1, os.cpu_count() // 2)
self.num_workers = num_workers
self._stopped = False
self._lock = threading.Lock()
self.task_queue = multiprocessing.JoinableQueue(maxsize=max_queue_size)
self.result_queue = multiprocessing.Queue()
self.pending_tasks: dict[tuple[int, str], Any] = {}
self.workers = []
for _ in range(num_workers):
p = multiprocessing.Process(target=_video_encode_worker, args=(self.task_queue, self.result_queue))
p.daemon = True
p.start()
self.workers.append(p)
self._result_thread = threading.Thread(target=self._process_results, daemon=True)
self._result_thread.start()
logging.info(f"Started AsyncVideoEncoder with {num_workers} workers")
def _process_results(self):
"""Background thread to process completed encoding results."""
while not self._stopped:
try:
result = self.result_queue.get(timeout=0.1)
status, episode_index, video_key, data, callback_data = result
with self._lock:
key = (episode_index, video_key)
if key in self.pending_tasks:
task = self.pending_tasks.pop(key)
if task.get("callback"):
task["callback"](status, episode_index, video_key, data, callback_data)
if status == "error":
logging.error(f"Video encoding failed for ep {episode_index}, {video_key}: {data}")
else:
logging.debug(f"Video encoded: ep {episode_index}, {video_key}")
except queue.Empty:
continue
except Exception as e:
if not self._stopped:
logging.error(f"Result processing error: {e}")
def submit(
self,
imgs_dir: Path,
video_path: Path,
fps: int,
episode_index: int,
video_key: str,
callback: callable = None,
callback_data: Any = None,
):
"""Submit a video encoding task."""
if self._stopped:
raise RuntimeError("AsyncVideoEncoder has been stopped")
with self._lock:
self.pending_tasks[(episode_index, video_key)] = {
"callback": callback,
"video_path": video_path,
}
self.task_queue.put((imgs_dir, video_path, fps, episode_index, video_key, callback_data))
@property
def pending_count(self) -> int:
"""Number of pending encoding tasks."""
with self._lock:
return len(self.pending_tasks)
def wait_until_done(self, timeout: float | None = None):
"""Wait for all pending tasks to complete."""
self.task_queue.join()
def stop(self, wait: bool = True):
"""Stop all workers and clean up resources."""
if self._stopped:
return
self._stopped = True
for _ in self.workers:
self.task_queue.put(None)
if wait:
for w in self.workers:
w.join(timeout=5.0)
for w in self.workers:
if w.is_alive():
w.terminate()
self.task_queue.close()
self.result_queue.close()
logging.info("AsyncVideoEncoder stopped")
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):