From 47aee1fdbec00c2b88dcf4a5eacbe43b77a13d6a Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 29 Aug 2025 01:06:46 +0200 Subject: [PATCH] revert back `video_utils.py` to using pyav while keeping concat_video_files function --- src/lerobot/datasets/video_utils.py | 208 ++++++++++++++++------------ tests/fixtures/files.py | 10 -- 2 files changed, 116 insertions(+), 102 deletions(-) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index a96bf89fa..7552ff808 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -13,18 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import importlib -import json import logging import shutil import subprocess import tempfile import warnings -from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar +import av import pyarrow as pa import torch import torchvision @@ -106,7 +106,7 @@ def decode_video_frames_torchvision( keyframes_only = False torchvision.set_video_backend(backend) if backend == "pyav": - keyframes_only = True # pyav doesnt support accuracte seek + keyframes_only = True # pyav doesn't support accurate seek # set a video stream reader # TODO(rcadene): also load audio stream at the same time @@ -159,7 +159,6 @@ def decode_video_frames_torchvision( ) # get closest frames to the query timestamps - # TODO(rcadene): remove torch.stack closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) closest_ts = loaded_ts[argmin_] @@ -257,51 +256,83 @@ def encode_video_frames( g: int | None = 2, crf: int | None = 30, fast_decode: int = 0, - log_level: str | None = "quiet", + log_level: int | None = av.logging.ERROR, overwrite: bool = False, ) -> None: """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" + # Check encoder availability + if vcodec not in ["h264", "hevc", "libsvtav1"]: + raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.") + video_path = Path(video_path) imgs_dir = Path(imgs_dir) - video_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg_args = OrderedDict( - [ - ("-f", "image2"), - ("-r", str(fps)), - ("-i", str(imgs_dir / "frame-%06d.png")), - ("-vcodec", vcodec), - ("-pix_fmt", pix_fmt), - ] + video_path.parent.mkdir(parents=True, exist_ok=overwrite) + + # Encoders/pixel formats incompatibility check + if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p": + logging.warning( + f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'" + ) + pix_fmt = "yuv420p" + + # Get input frames + template = "frame_" + ("[0-9]" * 6) + ".png" + input_list = sorted( + glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0]) ) + # Define video output frame size (assuming all input frames are the same size) + if len(input_list) == 0: + raise FileNotFoundError(f"No images found in {imgs_dir}.") + dummy_image = Image.open(input_list[0]) + width, height = dummy_image.size + + # Define video codec options + video_options = {} + if g is not None: - ffmpeg_args["-g"] = str(g) + video_options["g"] = str(g) if crf is not None: - ffmpeg_args["-crf"] = str(crf) + video_options["crf"] = str(crf) if fast_decode: - key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune" + key = "svtav1-params" if vcodec == "libsvtav1" else "tune" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" - ffmpeg_args[key] = value + video_options[key] = value + # Set logging level if log_level is not None: - ffmpeg_args["-loglevel"] = str(log_level) + # "While less efficient, it is generally preferable to modify logging with Python's logging" + logging.getLogger("libav").setLevel(log_level) - ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] - if overwrite: - ffmpeg_args.append("-y") + # Create and open output file (overwrite by default) + with av.open(str(video_path), "w") as output: + output_stream = output.add_stream(vcodec, fps, options=video_options) + output_stream.pix_fmt = pix_fmt + output_stream.width = width + output_stream.height = height - ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] - # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal - subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) + # Loop through input frames and encode them + for input_data in input_list: + input_image = Image.open(input_data).convert("RGB") + input_frame = av.VideoFrame.from_image(input_image) + packet = output_stream.encode(input_frame) + if packet: + output.mux(packet) + + # Flush the encoder + packet = output_stream.encode() + if packet: + output.mux(packet) + + # Reset logging level + if log_level is not None: + av.logging.restore_default_callback() if not video_path.exists(): - raise OSError( - f"Video encoding did not work. File not found: {video_path}. " - f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`" - ) + raise OSError(f"Video encoding did not work. File not found: {video_path}.") def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int): @@ -325,6 +356,9 @@ def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chu codec, resolution, and frame rate for proper concatenation. - Output path follows the DEFAULT_VIDEO_PATH pattern with video_key, chunk_idx, and file_idx parameters. + - This function uses subprocess to call ffmpeg directly because PyAV doesn't have + built-in support for video concatenation. The concat demuxer in ffmpeg handles + all the complex timestamp adjustments automatically. """ tmp_dir = Path(tempfile.mkdtemp(dir=root)) @@ -390,78 +424,68 @@ with warnings.catch_warnings(): def get_audio_info(video_path: Path | str) -> dict: - ffprobe_audio_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "a:0", - "-show_entries", - "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_audio_cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") + # Set logging level + logging.getLogger("libav").setLevel(av.logging.ERROR) - info = json.loads(result.stdout) - audio_stream_info = info["streams"][0] if info.get("streams") else None - if audio_stream_info is None: - return {"has_audio": False} + # Getting audio stream information + audio_info = {} + with av.open(str(video_path), "r") as audio_file: + try: + audio_stream = audio_file.streams.audio[0] + except IndexError: + # Reset logging level + av.logging.restore_default_callback() + return {"has_audio": False} - # Return the information, defaulting to None if no audio stream is present - return { - "has_audio": True, - "audio.channels": audio_stream_info.get("channels", None), - "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, - "audio.sample_rate": int(audio_stream_info["sample_rate"]) - if audio_stream_info.get("sample_rate") - else None, - "audio.bit_depth": audio_stream_info.get("bit_depth", None), - "audio.channel_layout": audio_stream_info.get("channel_layout", None), - } + audio_info["audio.channels"] = audio_stream.channels + audio_info["audio.codec"] = audio_stream.codec.canonical_name + # In an ideal loseless case : bit depth x sample rate x channels = bit rate. + # In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied. + audio_info["audio.bit_rate"] = audio_stream.bit_rate + audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second + # In an ideal loseless case : fixed number of bits per sample. + # In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate). + audio_info["audio.bit_depth"] = audio_stream.format.bits + audio_info["audio.channel_layout"] = audio_stream.layout.name + audio_info["has_audio"] = True + + # Reset logging level + av.logging.restore_default_callback() + + return audio_info def get_video_info(video_path: Path | str) -> dict: - ffprobe_video_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "v:0", - "-show_entries", - "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_video_cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") + # Set logging level + logging.getLogger("libav").setLevel(av.logging.ERROR) - info = json.loads(result.stdout) - video_stream_info = info["streams"][0] + # Getting video stream information + video_info = {} + with av.open(str(video_path), "r") as video_file: + try: + video_stream = video_file.streams.video[0] + except IndexError: + # Reset logging level + av.logging.restore_default_callback() + return {} - # Calculate fps from r_frame_rate - r_frame_rate = video_stream_info["r_frame_rate"] - num, denom = map(int, r_frame_rate.split("/")) - fps = num / denom + video_info["video.height"] = video_stream.height + video_info["video.width"] = video_stream.width + video_info["video.codec"] = video_stream.codec.canonical_name + video_info["video.pix_fmt"] = video_stream.pix_fmt + video_info["video.is_depth_map"] = False - pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"]) + # Calculate fps from r_frame_rate + video_info["video.fps"] = int(video_stream.base_rate) - video_info = { - "video.fps": fps, - "video.height": video_stream_info["height"], - "video.width": video_stream_info["width"], - "video.channels": pixel_channels, - "video.codec": video_stream_info["codec_name"], - "video.pix_fmt": video_stream_info["pix_fmt"], - "video.is_depth_map": False, - **get_audio_info(video_path), - } + pixel_channels = get_video_pixel_channels(video_stream.pix_fmt) + video_info["video.channels"] = pixel_channels + + # Reset logging level + av.logging.restore_default_callback() + + # Adding audio stream information + video_info.update(**get_audio_info(video_path)) return video_info diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index a3611d841..11f3fa94a 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -140,16 +140,6 @@ def create_stats(stats_factory): return _create_stats -# @pytest.fixture(scope="session") -# def create_episodes_stats(episodes_stats_factory): -# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None): -# if episodes_stats is None: -# episodes_stats = episodes_stats_factory() -# write_episodes_stats(episodes_stats, dir) - -# return _create_episodes_stats - - @pytest.fixture(scope="session") def create_tasks(tasks_factory): def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):