mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
feat(pix_fmt channels): use PyAV to get the number of channels of a pixel format
This commit is contained in:
@@ -42,6 +42,12 @@ def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_valu
|
|||||||
dst[:, :width] = src
|
dst[:, :width] = src
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||||
|
"""Return the number of components (channels) for *pix_fmt*."""
|
||||||
|
return len(av.VideoFormat(pix_fmt).components)
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||||
@@ -153,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||||
|
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||||
|
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||||
|
if pix_fmt_channels < channels:
|
||||||
|
raise ValueError(
|
||||||
|
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||||
|
f"but the source data has {channels} channel(s)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||||
supported_options = _get_codec_options_by_name(vcodec)
|
supported_options = _get_codec_options_by_name(vcodec)
|
||||||
@@ -167,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
|||||||
_check_option_value(vcodec, key, value, supported_options[key])
|
_check_option_value(vcodec, key, value, supported_options[key])
|
||||||
|
|
||||||
|
|
||||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
def check_video_encoder_parameters_pyav(
|
||||||
|
vcodec: str,
|
||||||
|
pix_fmt: str,
|
||||||
|
codec_options: dict[str, Any],
|
||||||
|
channels: int | None = None,
|
||||||
|
) -> None:
|
||||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||||
|
|
||||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||||
|
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@@ -182,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
|
|||||||
if not options:
|
if not options:
|
||||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||||
_check_pixel_format(vcodec, pix_fmt)
|
_check_pixel_format(vcodec, pix_fmt)
|
||||||
|
if channels is not None:
|
||||||
|
_check_pix_fmt_channels(pix_fmt, channels)
|
||||||
_check_codec_options(vcodec, codec_options)
|
_check_codec_options(vcodec, codec_options)
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from lerobot.configs import (
|
|||||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||||
|
|
||||||
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
|
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
|
||||||
|
from .pyav_utils import get_pix_fmt_channels
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -966,8 +967,7 @@ def get_video_info(
|
|||||||
# Calculate fps from r_frame_rate
|
# Calculate fps from r_frame_rate
|
||||||
video_info["video.fps"] = int(video_stream.base_rate)
|
video_info["video.fps"] = int(video_stream.base_rate)
|
||||||
|
|
||||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||||
video_info["video.channels"] = pixel_channels
|
|
||||||
|
|
||||||
# Reset logging level
|
# Reset logging level
|
||||||
av.logging.restore_default_callback()
|
av.logging.restore_default_callback()
|
||||||
@@ -988,17 +988,6 @@ def get_video_info(
|
|||||||
return video_info
|
return video_info
|
||||||
|
|
||||||
|
|
||||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
|
||||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
|
||||||
return 1
|
|
||||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
|
||||||
return 4
|
|
||||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
|
||||||
return 3
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown format")
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||||
"""
|
"""
|
||||||
Get the duration of a video file in seconds using PyAV.
|
Get the duration of a video file in seconds using PyAV.
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from lerobot.datasets.image_writer import (
|
|||||||
save_kwargs_for_path,
|
save_kwargs_for_path,
|
||||||
write_image,
|
write_image,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import get_video_pixel_channels
|
from lerobot.datasets.pyav_utils import get_pix_fmt_channels
|
||||||
from tests.fixtures.constants import (
|
from tests.fixtures.constants import (
|
||||||
DEFAULT_FPS,
|
DEFAULT_FPS,
|
||||||
DUMMY_CAMERA_FEATURES,
|
DUMMY_CAMERA_FEATURES,
|
||||||
@@ -195,9 +195,9 @@ class TestHwToDatasetFeaturesDepth:
|
|||||||
class TestVideoInfoDepthFlag:
|
class TestVideoInfoDepthFlag:
|
||||||
"""Misc depth-related constants and helpers in video_utils / configs."""
|
"""Misc depth-related constants and helpers in video_utils / configs."""
|
||||||
|
|
||||||
def test_get_video_pixel_channels_gray(self):
|
def test_get_pix_fmt_channels_gray(self):
|
||||||
assert get_video_pixel_channels("gray12le") == 1
|
assert get_pix_fmt_channels("gray12le") == 1
|
||||||
assert get_video_pixel_channels("gray8") == 1
|
assert get_pix_fmt_channels("gray8") == 1
|
||||||
|
|
||||||
def test_ffv1_in_valid_codecs(self):
|
def test_ffv1_in_valid_codecs(self):
|
||||||
assert "ffv1" in VALID_VIDEO_CODECS
|
assert "ffv1" in VALID_VIDEO_CODECS
|
||||||
|
|||||||
Reference in New Issue
Block a user