feat(pix_fmt channels): use PyAv to check get pixel formats number of channels

This commit is contained in:
CarolinePascal
2026-05-22 02:01:05 +02:00
parent 72a429764a
commit 7498f1cf61
3 changed files with 31 additions and 18 deletions
+25 -1
View File
@@ -42,6 +42,12 @@ def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_valu
dst[:, :width] = src
@functools.cache
def get_pix_fmt_channels(pix_fmt: str) -> int:
"""Return the number of components (channels) for *pix_fmt*."""
return len(av.VideoFormat(pix_fmt).components)
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -153,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
)
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
"""Ensure *pix_fmt* can carry at least *channels* components."""
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
if pix_fmt_channels < channels:
raise ValueError(
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
f"but the source data has {channels} channel(s)."
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
@@ -167,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
_check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
def check_video_encoder_parameters_pyav(
vcodec: str,
pix_fmt: str,
codec_options: dict[str, Any],
channels: int | None = None,
) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
@@ -182,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt)
if channels is not None:
_check_pix_fmt_channels(pix_fmt, channels)
_check_codec_options(vcodec, codec_options)
+2 -13
View File
@@ -45,6 +45,7 @@ from lerobot.configs import (
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
@@ -966,8 +967,7 @@ def get_video_info(
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
# Reset logging level
av.logging.restore_default_callback()
@@ -988,17 +988,6 @@ def get_video_info(
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.
+4 -4
View File
@@ -25,7 +25,7 @@ from lerobot.datasets.image_writer import (
save_kwargs_for_path,
write_image,
)
from lerobot.datasets.video_utils import get_video_pixel_channels
from lerobot.datasets.pyav_utils import get_pix_fmt_channels
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -195,9 +195,9 @@ class TestHwToDatasetFeaturesDepth:
class TestVideoInfoDepthFlag:
"""Misc depth-related constants and helpers in video_utils / configs."""
def test_get_video_pixel_channels_gray(self):
assert get_video_pixel_channels("gray12le") == 1
assert get_video_pixel_channels("gray8") == 1
def test_get_pix_fmt_channels_gray(self):
assert get_pix_fmt_channels("gray12le") == 1
assert get_pix_fmt_channels("gray8") == 1
def test_ffv1_in_valid_codecs(self):
assert "ffv1" in VALID_VIDEO_CODECS