mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
feat(dataset): add streaming video encoding + HW encoder support (#2974)
* feat(dataset): init stream encoding * feat(dataset): use threads to fix frame pickle latency * refactor(dataset): remove HW encoded related changes * add lp (#2977) * feat(dataset): add Hw encoding + log drop frames (#2978) * chore(docs): add streaming video encoding guide * fix(dataset): style docs + testing * chore(docs): simplify sttreaming video encoding guide * chore(dataset): add commands + streaming encoding default false + print note if false + queue default is now 30 * chore(docs): add verification note advice * chore(dataset): adjusting defaults & docs for streaming encoding * docs(scripts): improve docstrings * test(dataset): polish streaming encoding tests * chore(dataset): move FYI log related to streaming * chore(dataset): add arg vcodec to suggestions * refactor(dataset): better handling for auto and available vcodec * chore(dataset): change log level * docs(dataset): add note related to training performance vcodec * docs(dataset): add more notes to streaming encoding --------- Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com> Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -68,6 +68,7 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
@@ -75,11 +76,11 @@ from lerobot.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -545,12 +546,19 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -570,6 +578,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -683,12 +694,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -700,7 +716,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = vcodec
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -708,6 +725,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
self._streaming_encoder = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -749,6 +767,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Initialize streaming encoder for resumed recording
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
self._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=self.meta.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
writer = getattr(self, "writer", None)
|
||||
@@ -1104,6 +1135,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.close()
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
@@ -1158,6 +1191,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
|
||||
|
||||
# Start streaming encoder on first frame of episode (once, before iterating keys)
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self.meta.video_keys),
|
||||
temp_dir=self.root,
|
||||
)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key not in self.features:
|
||||
@@ -1165,7 +1205,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.feed_frame(key, frame[key])
|
||||
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
@@ -1226,13 +1269,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_streaming = self._streaming_encoder is not None and has_video_keys
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
if use_streaming:
|
||||
# Compute stats for non-video features only (video stats come from encoder)
|
||||
non_video_buffer = {
|
||||
k: v
|
||||
for k, v in episode_buffer.items()
|
||||
if self.features.get(k, {}).get("dtype") not in ("video",)
|
||||
}
|
||||
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
|
||||
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
|
||||
else:
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
|
||||
if use_streaming:
|
||||
# Finish streaming encoding and collect results
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
elif has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1246,6 +1314,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
@@ -1514,6 +1583,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
# Cancel streaming encoder if active
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.cancel_episode()
|
||||
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1561,7 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1578,10 +1653,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1599,6 +1676,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
@@ -1620,6 +1698,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
|
||||
# Initialize streaming encoder
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
obj._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
else:
|
||||
obj._streaming_encoder = None
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
@@ -13,25 +13,106 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
@@ -309,14 +390,13 @@ def encode_video_frames(
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
# Check encoder availability
|
||||
if vcodec not in ["h264", "hevc", "libsvtav1"]:
|
||||
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -347,21 +427,22 @@ def encode_video_frames(
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = {}
|
||||
|
||||
if g is not None:
|
||||
video_options["g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if vcodec == "libsvtav1":
|
||||
video_options["preset"] = str(preset) if preset is not None else "12"
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -480,6 +561,348 @@ def concatenate_video_files(
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
class _CameraEncoderThread(threading.Thread):
|
||||
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
||||
|
||||
One instance is created per camera per episode. Frames are received as numpy arrays
|
||||
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
||||
encoding), and written to disk. Stats are computed incrementally using
|
||||
RunningQuantileStats and returned via result_queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
while True:
|
||||
try:
|
||||
frame_data = self.frame_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
continue
|
||||
|
||||
if frame_data is None:
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
if output_stream is not None:
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
if container is not None:
|
||||
container.close()
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
self.result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
class StreamingVideoEncoder:
|
||||
"""Manages per-camera encoder threads for real-time video encoding during recording.
|
||||
|
||||
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
|
||||
this class streams frames directly to encoder threads, eliminating the
|
||||
PNG round-trip and making save_episode() near-instant.
|
||||
|
||||
Uses threading instead of multiprocessing to avoid the overhead of pickling large
|
||||
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
|
||||
so encoding runs in parallel with the main recording loop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
self._threads: dict[str, _CameraEncoderThread] = {}
|
||||
self._stop_events: dict[str, threading.Event] = {}
|
||||
self._video_paths: dict[str, Path] = {}
|
||||
self._dropped_frames: dict[str, int] = {}
|
||||
self._episode_active = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
self._frame_queues[video_key] = frame_queue
|
||||
self._result_queues[video_key] = result_queue
|
||||
self._threads[video_key] = encoder_thread
|
||||
self._stop_events[video_key] = stop_event
|
||||
self._video_paths[video_key] = video_path
|
||||
|
||||
self._episode_active = True
|
||||
|
||||
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
|
||||
"""Feed a frame to the encoder for a specific camera.
|
||||
|
||||
A copy of the image is made before enqueueing to prevent race conditions
|
||||
with camera drivers that may reuse buffers. If the encoder queue is full
|
||||
(encoder can't keep up), the frame is dropped with a warning instead of
|
||||
crashing the recording session.
|
||||
|
||||
Args:
|
||||
video_key: The video feature key
|
||||
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the encoder thread has crashed
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode. Call start_episode() first.")
|
||||
|
||||
thread = self._threads[video_key]
|
||||
if not thread.is_alive():
|
||||
# Check for error
|
||||
try:
|
||||
status, msg = self._result_queues[video_key].get_nowait()
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
|
||||
|
||||
try:
|
||||
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
|
||||
except queue.Full:
|
||||
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
|
||||
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
|
||||
"""Finish encoding the current episode.
|
||||
|
||||
Sends sentinel values, waits for encoder threads to complete,
|
||||
and collects results.
|
||||
|
||||
Returns:
|
||||
Dict mapping video_key to (mp4_path, stats_dict_or_None)
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode to finish.")
|
||||
|
||||
results = {}
|
||||
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
self._frame_queues[video_key].put(None)
|
||||
|
||||
# Wait for all threads and collect results
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
continue
|
||||
|
||||
try:
|
||||
status, data = self._result_queues[video_key].get(timeout=5)
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
return results
|
||||
|
||||
def cancel_episode(self) -> None:
|
||||
"""Cancel the current episode, stopping encoder threads and cleaning up."""
|
||||
if not self._episode_active:
|
||||
return
|
||||
|
||||
# Signal all threads to stop
|
||||
for video_key in self._stop_events:
|
||||
self._stop_events[video_key].set()
|
||||
|
||||
# Wait for threads to finish
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=5)
|
||||
|
||||
# Clean up temp MP4 files
|
||||
video_path = self._video_paths.get(video_key)
|
||||
if video_path is not None and video_path.exists():
|
||||
shutil.rmtree(str(video_path.parent), ignore_errors=True)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the encoder, canceling any in-progress episode."""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up queues and thread tracking dicts."""
|
||||
for q in self._frame_queues.values():
|
||||
with contextlib.suppress(Exception):
|
||||
while not q.empty():
|
||||
q.get_nowait()
|
||||
self._frame_queues.clear()
|
||||
self._result_queues.clear()
|
||||
self._threads.clear()
|
||||
self._stop_events.clear()
|
||||
self._video_paths.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
|
||||
@@ -514,7 +937,7 @@ with warnings.catch_warnings():
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
@@ -546,7 +969,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
video_info = {}
|
||||
@@ -632,8 +1055,15 @@ class VideoEncodingManager:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if self.dataset.episodes_since_last_encoding > 0:
|
||||
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
|
||||
|
||||
if streaming_encoder is not None:
|
||||
# Handle streaming encoder cleanup
|
||||
if exc_type is not None:
|
||||
streaming_encoder.cancel_episode()
|
||||
streaming_encoder.close()
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
@@ -650,8 +1080,8 @@ class VideoEncodingManager:
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
# Clean up episode images if recording was interrupted (only for non-streaming mode)
|
||||
if exc_type is not None and streaming_encoder is None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
for key in self.dataset.meta.video_keys:
|
||||
img_dir = self.dataset._get_image_file_path(
|
||||
@@ -665,14 +1095,12 @@ class VideoEncodingManager:
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
# Check for any remaining PNG files
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Only remove the images directory if no PNG files remain
|
||||
if img_dir.exists():
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -26,8 +26,10 @@ lerobot-record \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--display_data=true
|
||||
# <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# --dataset.vcodec=h264 \
|
||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||
# --teleop.type=so100_leader \
|
||||
@@ -58,7 +60,10 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
|
||||
--dataset.num_episodes=25 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -179,9 +184,19 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
|
||||
# Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -452,6 +467,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
@@ -474,6 +492,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
@@ -497,6 +518,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
|
||||
Reference in New Issue
Block a user