Async encoding

This commit is contained in:
Pepijn
2026-02-03 08:50:34 +01:00
parent 2598dbc31a
commit c028ae3a44
3 changed files with 220 additions and 7 deletions
@@ -102,6 +102,9 @@ class RaCRTCDatasetConfig:
num_image_writer_threads_per_camera: int = 4 num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1 video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict) rename_map: dict[str, str] = field(default_factory=dict)
# Async video encoding: encode videos in background without blocking recording
# Set to number of worker processes (0 = half of CPUs, None = disabled)
async_video_encoder_workers: int | None = 0
@dataclass @dataclass
@@ -722,6 +725,9 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
num_processes=cfg.dataset.num_image_writer_processes, num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
) )
if cfg.dataset.async_video_encoder_workers is not None:
workers = None if cfg.dataset.async_video_encoder_workers == 0 else cfg.dataset.async_video_encoder_workers
dataset.start_video_encoder(num_workers=workers)
else: else:
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
cfg.dataset.repo_id, cfg.dataset.repo_id,
@@ -734,6 +740,7 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
video_encoder_workers=cfg.dataset.async_video_encoder_workers,
) )
# Load policy # Load policy
@@ -846,7 +853,8 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
dataset.clear_episode_buffer() dataset.clear_episode_buffer()
continue continue
dataset.save_episode() use_async = cfg.dataset.async_video_encoder_workers is not None
dataset.save_episode(async_video_encoding=use_async)
recorded += 1 recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
+70 -6
View File
@@ -69,6 +69,7 @@ from lerobot.datasets.utils import (
write_tasks, write_tasks,
) )
from lerobot.datasets.video_utils import ( from lerobot.datasets.video_utils import (
AsyncVideoEncoder,
VideoFrame, VideoFrame,
concatenate_video_files, concatenate_video_files,
decode_video_frames, decode_video_frames,
@@ -693,6 +694,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Unused attributes # Unused attributes
self.image_writer = None self.image_writer = None
self.video_encoder = None
self.episode_buffer = None self.episode_buffer = None
self.writer = None self.writer = None
self.latest_episode = None self.latest_episode = None
@@ -1067,9 +1069,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def finalize(self): def finalize(self):
""" """
Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. Close the parquet writers and stop async workers. This function needs to be called after data
collection/conversion, else footer metadata won't be written to the parquet files.
The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo))
""" """
self._wait_video_encoder()
self.stop_video_encoder()
self._close_writer() self._close_writer()
self.meta._close_writer() self.meta._close_writer()
@@ -1151,6 +1156,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self, self,
episode_data: dict | None = None, episode_data: dict | None = None,
parallel_encoding: bool = True, parallel_encoding: bool = True,
async_video_encoding: bool = False,
) -> None: ) -> None:
""" """
This will save to disk the current episode in self.episode_buffer. This will save to disk the current episode in self.episode_buffer.
@@ -1158,6 +1164,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Video encoding is handled automatically based on batch_encoding_size: Video encoding is handled automatically based on batch_encoding_size:
- If batch_encoding_size == 1: Videos are encoded immediately after each episode - If batch_encoding_size == 1: Videos are encoded immediately after each episode
- If batch_encoding_size > 1: Videos are encoded in batches. - If batch_encoding_size > 1: Videos are encoded in batches.
- If async_video_encoding=True: Encoding runs in background (requires start_video_encoder())
Args: Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
@@ -1165,6 +1172,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
None. None.
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor. parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already. Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
async_video_encoding (bool, optional): If True and video_encoder is started, encode videos
asynchronously in background processes without blocking. Defaults to False.
""" """
episode_buffer = episode_data if episode_data is not None else self.episode_buffer episode_buffer = episode_data if episode_data is not None else self.episode_buffer
@@ -1199,11 +1208,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_metadata = self._save_episode_data(episode_buffer) ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0 has_video_keys = len(self.meta.video_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1 use_batched_encoding = self.batch_encoding_size > 1
use_async_encoding = (
async_video_encoding
and hasattr(self, "video_encoder")
and self.video_encoder is not None
)
if has_video_keys and not use_batched_encoding: if has_video_keys and not use_batched_encoding:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index) if use_async_encoding:
for video_key, video_path in zip(self.meta.video_keys, video_paths): # Submit encoding tasks to background workers
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path)) for video_key in self.meta.video_keys:
img_dir = self._get_image_file_dir(episode_index, video_key)
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
self.video_encoder.submit(
imgs_dir=img_dir,
video_path=temp_path,
fps=self.fps,
episode_index=episode_index,
video_key=video_key,
callback_data={"root": self.root, "meta": self.meta},
)
else:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
for video_key, video_path in zip(self.meta.video_keys, video_paths):
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
# `meta.save_episode` need to be executed after encoding the videos # `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
@@ -1219,7 +1247,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not episode_data: if not episode_data:
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding) # Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) delete_images = len(self.meta.image_keys) > 0 and not use_async_encoding
self.clear_episode_buffer(delete_images=delete_images)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
""" """
@@ -1491,6 +1520,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer.wait_until_done() self.image_writer.wait_until_done()
def start_video_encoder(self, num_workers: int | None = None) -> None:
    """
    Attach a fresh AsyncVideoEncoder to this dataset for background video encoding.

    If an encoder is already attached, a warning is logged and the old encoder is stopped
    before the new one is created.

    Args:
        num_workers: Number of encoding worker processes, forwarded to AsyncVideoEncoder.
            None lets the encoder pick its own default.
    """
    previous_encoder = getattr(self, "video_encoder", None)
    if previous_encoder is not None:
        logging.warning("Replacing existing AsyncVideoEncoder")
        previous_encoder.stop()

    self.video_encoder = AsyncVideoEncoder(num_workers=num_workers)
def stop_video_encoder(self, wait: bool = True) -> None:
    """
    Stop the attached async video encoder, if any, and detach it from the dataset.

    Args:
        wait: Forwarded to the encoder's ``stop()``.
    """
    encoder = getattr(self, "video_encoder", None)
    if encoder is None:
        # Nothing attached (or the attribute was never set): nothing to stop.
        return

    encoder.stop(wait=wait)
    self.video_encoder = None
def _wait_video_encoder(self) -> None:
    """Block until the attached async video encoder, if any, reports its pending work done."""
    encoder = getattr(self, "video_encoder", None)
    if encoder is None:
        return

    encoder.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
""" """
Use ffmpeg to convert frames stored as png into mp4 videos. Use ffmpeg to convert frames stored as png into mp4 videos.
@@ -1529,8 +1582,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0, image_writer_threads: int = 0,
video_backend: str | None = None, video_backend: str | None = None,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
video_encoder_workers: int | None = None,
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data.""" """Create a LeRobot Dataset from scratch in order to record data.
Args:
video_encoder_workers: If set, starts async video encoder with this many workers.
Set to 0 for half of available CPUs. For on-the-fly encoding during recording.
"""
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create( obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id, repo_id=repo_id,
@@ -1545,11 +1604,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.revision = None obj.revision = None
obj.tolerance_s = tolerance_s obj.tolerance_s = tolerance_s
obj.image_writer = None obj.image_writer = None
obj.video_encoder = None
obj.batch_encoding_size = batch_encoding_size obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0 obj.episodes_since_last_encoding = 0
if image_writer_processes or image_writer_threads: if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads) obj.start_image_writer(image_writer_processes, image_writer_threads)
if video_encoder_workers is not None:
workers = None if video_encoder_workers == 0 else video_encoder_workers
obj.start_video_encoder(num_workers=workers)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer() obj.episode_buffer = obj.create_episode_buffer()
+141
View File
@@ -16,8 +16,12 @@
import glob import glob
import importlib import importlib
import logging import logging
import multiprocessing
import os
import queue
import shutil import shutil
import tempfile import tempfile
import threading
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -400,6 +404,143 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.") raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def _video_encode_worker(task_queue: multiprocessing.JoinableQueue, result_queue: multiprocessing.Queue):
    """Worker-process loop that encodes queued image directories into videos.

    Each task is a tuple ``(imgs_dir, video_path, fps, episode_index, video_key, callback_data)``;
    a ``None`` item is the shutdown sentinel. For every well-formed task exactly one tuple is put
    on ``result_queue``:

    - ``("success", episode_index, video_key, video_path, callback_data)`` after the video was
      encoded and ``imgs_dir`` was removed, or
    - ``("error", episode_index, video_key, error_message, callback_data)`` if encoding or the
      image cleanup raised.

    Every item taken from ``task_queue`` is acknowledged with ``task_done()`` exactly once, even
    when the item is malformed or reporting the result fails. Otherwise ``task_queue.join()``
    would wait forever on a task that has already been consumed.

    Args:
        task_queue: Queue of tasks (or ``None`` sentinels) to consume.
        result_queue: Queue on which the outcome of each task is reported.
    """
    while True:
        try:
            item = task_queue.get()
        except (EOFError, OSError) as e:
            # The queue itself is unusable (typically its pipe was closed). Retrying would
            # only spin and flood the log, so let the worker exit.
            logging.error(f"Video encode worker error: {e}")
            break
        except Exception as e:
            logging.error(f"Video encode worker error: {e}")
            continue

        try:
            if item is None:
                # Shutdown sentinel; the `finally` below still acknowledges it.
                break

            imgs_dir, video_path, fps, episode_index, video_key, callback_data = item
            try:
                encode_video_frames(imgs_dir, video_path, fps, overwrite=True)
                shutil.rmtree(imgs_dir)
                result_queue.put(("success", episode_index, video_key, video_path, callback_data))
            except Exception as e:
                result_queue.put(("error", episode_index, video_key, str(e), callback_data))
        except Exception as e:
            # Malformed task tuple, or the error report itself could not be queued.
            logging.error(f"Video encode worker error: {e}")
        finally:
            task_queue.task_done()
class AsyncVideoEncoder:
"""
Async video encoder that processes video encoding in background processes.
This enables on-the-fly video encoding during data collection without blocking
the main recording loop. Uses a configurable number of worker processes.
Args:
num_workers: Number of encoding worker processes. Defaults to half of available CPUs.
max_queue_size: Maximum number of pending encoding tasks. Defaults to 100.
"""
def __init__(self, num_workers: int | None = None, max_queue_size: int = 100):
if num_workers is None:
num_workers = max(1, os.cpu_count() // 2)
self.num_workers = num_workers
self._stopped = False
self._lock = threading.Lock()
self.task_queue = multiprocessing.JoinableQueue(maxsize=max_queue_size)
self.result_queue = multiprocessing.Queue()
self.pending_tasks: dict[tuple[int, str], Any] = {}
self.workers = []
for _ in range(num_workers):
p = multiprocessing.Process(target=_video_encode_worker, args=(self.task_queue, self.result_queue))
p.daemon = True
p.start()
self.workers.append(p)
self._result_thread = threading.Thread(target=self._process_results, daemon=True)
self._result_thread.start()
logging.info(f"Started AsyncVideoEncoder with {num_workers} workers")
def _process_results(self):
"""Background thread to process completed encoding results."""
while not self._stopped:
try:
result = self.result_queue.get(timeout=0.1)
status, episode_index, video_key, data, callback_data = result
with self._lock:
key = (episode_index, video_key)
if key in self.pending_tasks:
task = self.pending_tasks.pop(key)
if task.get("callback"):
task["callback"](status, episode_index, video_key, data, callback_data)
if status == "error":
logging.error(f"Video encoding failed for ep {episode_index}, {video_key}: {data}")
else:
logging.debug(f"Video encoded: ep {episode_index}, {video_key}")
except queue.Empty:
continue
except Exception as e:
if not self._stopped:
logging.error(f"Result processing error: {e}")
def submit(
self,
imgs_dir: Path,
video_path: Path,
fps: int,
episode_index: int,
video_key: str,
callback: callable = None,
callback_data: Any = None,
):
"""Submit a video encoding task."""
if self._stopped:
raise RuntimeError("AsyncVideoEncoder has been stopped")
with self._lock:
self.pending_tasks[(episode_index, video_key)] = {
"callback": callback,
"video_path": video_path,
}
self.task_queue.put((imgs_dir, video_path, fps, episode_index, video_key, callback_data))
@property
def pending_count(self) -> int:
"""Number of pending encoding tasks."""
with self._lock:
return len(self.pending_tasks)
def wait_until_done(self, timeout: float | None = None):
"""Wait for all pending tasks to complete."""
self.task_queue.join()
def stop(self, wait: bool = True):
"""Stop all workers and clean up resources."""
if self._stopped:
return
self._stopped = True
for _ in self.workers:
self.task_queue.put(None)
if wait:
for w in self.workers:
w.join(timeout=5.0)
for w in self.workers:
if w.is_alive():
w.terminate()
self.task_queue.close()
self.result_queue.close()
logging.info("AsyncVideoEncoder stopped")
def concatenate_video_files( def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
): ):