diff --git a/src/lerobot/datasets/pyav_utils.py b/src/lerobot/datasets/pyav_utils.py
index 1fbbe5f89..4745058ca 100644
--- a/src/lerobot/datasets/pyav_utils.py
+++ b/src/lerobot/datasets/pyav_utils.py
@@ -42,6 +42,12 @@ def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_valu
     dst[:, :width] = src
 
 
+@functools.cache
+def get_pix_fmt_channels(pix_fmt: str) -> int:
+    """Return the number of components (channels) for *pix_fmt*."""
+    return len(av.VideoFormat(pix_fmt).components)
+
+
 @functools.cache
 def get_codec(vcodec: str) -> av.codec.Codec | None:
     """PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -153,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
         )
 
 
+def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
+    """Ensure *pix_fmt* can carry at least *channels* components."""
+    pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
+    if pix_fmt_channels < channels:
+        raise ValueError(
+            f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
+            f"but the source data has {channels} channel(s)."
+        )
+
+
 def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
     """Validate merged encoder options (typed) against the codec's published AVOptions."""
     supported_options = _get_codec_options_by_name(vcodec)
@@ -167,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
         _check_option_value(vcodec, key, value, supported_options[key])
 
 
-def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
+def check_video_encoder_parameters_pyav(
+    vcodec: str,
+    pix_fmt: str,
+    codec_options: dict[str, Any],
+    channels: int | None = None,
+) -> None:
     """Verify *config* is compatible with the bundled FFmpeg build.
 
     Checks pixel format, abstract tuning-field compatibility, and each merged encoder option from
     :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options` against PyAV (including numeric
     ``extra_options`` present in that dict).
+    When *channels* is given, additionally verify that *pix_fmt* carries at least that many components.
     No-op when ``config.vcodec`` isn't in the local FFmpeg build.
 
     Raises:
@@ -182,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
     if not options:
         raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
     _check_pixel_format(vcodec, pix_fmt)
+    if channels is not None:
+        _check_pix_fmt_channels(pix_fmt, channels)
     _check_codec_options(vcodec, codec_options)
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 1c68de52d..ab741ba0a 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -45,6 +45,7 @@ from lerobot.configs import (
 from lerobot.utils.import_utils import get_safe_default_video_backend
 
 from .depth_utils import DEPTH_PIX_FMT, quantize_depth
+from .pyav_utils import get_pix_fmt_channels
 
 logger = logging.getLogger(__name__)
 
@@ -966,8 +967,7 @@ def get_video_info(
         # Calculate fps from r_frame_rate
         video_info["video.fps"] = int(video_stream.base_rate)
 
-        pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
-        video_info["video.channels"] = pixel_channels
+        video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
 
     # Reset logging level
     av.logging.restore_default_callback()
@@ -988,17 +988,6 @@ def get_video_info(
     return video_info
 
 
-def get_video_pixel_channels(pix_fmt: str) -> int:
-    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
-        return 1
-    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
-        return 4
-    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
-        return 3
-    else:
-        raise ValueError("Unknown format")
-
-
 def get_video_duration_in_s(video_path: Path | str) -> float:
     """
     Get the duration of a video file in seconds using PyAV.
diff --git a/tests/datasets/test_depth.py b/tests/datasets/test_depth.py
index 6823857e4..9a6c1f942 100644
--- a/tests/datasets/test_depth.py
+++ b/tests/datasets/test_depth.py
@@ -25,7 +25,7 @@ from lerobot.datasets.image_writer import (
     save_kwargs_for_path,
     write_image,
 )
-from lerobot.datasets.video_utils import get_video_pixel_channels
+from lerobot.datasets.pyav_utils import get_pix_fmt_channels
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -195,9 +195,9 @@ class TestHwToDatasetFeaturesDepth:
 class TestVideoInfoDepthFlag:
     """Misc depth-related constants and helpers in video_utils / configs."""
 
-    def test_get_video_pixel_channels_gray(self):
-        assert get_video_pixel_channels("gray12le") == 1
-        assert get_video_pixel_channels("gray8") == 1
+    def test_get_pix_fmt_channels_gray(self):
+        assert get_pix_fmt_channels("gray12le") == 1
+        assert get_pix_fmt_channels("gray8") == 1
 
     def test_ffv1_in_valid_codecs(self):
         assert "ffv1" in VALID_VIDEO_CODECS