chore(format): fixing formatting issues

This commit is contained in:
CarolinePascal
2026-04-24 17:19:19 +02:00
parent 7040a106a2
commit c7dc56d8b5
6 changed files with 42 additions and 48 deletions
+1 -1
View File
@@ -416,7 +416,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
concatenate_video_files( concatenate_video_files(
[dst_path, src_path], [dst_path, src_path],
dst_path, dst_path,
compatibilty_check=True, compatibility_check=True,
) )
# Update duration of this destination file # Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration dst_file_durations[dst_key] = current_dst_duration + src_duration
+1 -3
View File
@@ -502,9 +502,7 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded) # Update video info (only needed when first episode is encoded)
if episode_index == 0: if episode_index == 0:
self._meta.update_video_info( self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config)
video_key, camera_encoder_config=self._camera_encoder_config
)
write_info(self._meta.info, self._meta.root) write_info(self._meta.info, self._meta.root)
metadata = { metadata = {
+1 -1
View File
@@ -36,8 +36,8 @@ from .utils import (
) )
from .video_utils import ( from .video_utils import (
StreamingVideoEncoder, StreamingVideoEncoder,
get_safe_default_video_backend,
VideoEncoderConfig, VideoEncoderConfig,
get_safe_default_video_backend,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+6 -18
View File
@@ -95,9 +95,7 @@ def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
return available return available
def _is_field_supported( def _is_field_supported(field_name: str, vcodec: str, options: dict[str, av.option.Option]) -> bool:
field_name: str, vcodec: str, options: dict[str, av.option.Option]
) -> bool:
"""Whether tuning option *field_name* is meaningful for *vcodec*.""" """Whether tuning option *field_name* is meaningful for *vcodec*."""
# GOP is a stream-level option (AVStream.gop_size) not stored in private options. # GOP is a stream-level option (AVStream.gop_size) not stored in private options.
# Every video codec accepts it. # Every video codec accepts it.
@@ -118,20 +116,14 @@ def _is_field_supported(
return field_name in options return field_name in options
def _check_numeric_range( def _check_numeric_range(label: str, num: float, opt: av.option.Option, vcodec: str) -> None:
label: str, num: float, opt: av.option.Option, vcodec: str
) -> None:
"""Raise if *num* lies outside *opt*'s numeric range (no-op if range is degenerate).""" """Raise if *num* lies outside *opt*'s numeric range (no-op if range is degenerate)."""
lo, hi = float(opt.min), float(opt.max) lo, hi = float(opt.min), float(opt.max)
if lo < hi and not (lo <= num <= hi): if lo < hi and not (lo <= num <= hi):
raise ValueError( raise ValueError(f"{label}={num} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]")
f"{label}={num} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
)
def _validate_option_value( def _validate_option_value(vcodec: str, field_name: str, value: Any, opt: av.option.Option) -> None:
vcodec: str, field_name: str, value: Any, opt: av.option.Option
) -> None:
"""Range-check numeric *value* and choice-check string *value* against *opt*. """Range-check numeric *value* and choice-check string *value* against *opt*.
Type mismatches fall through to FFmpeg's own validation at encode time. Type mismatches fall through to FFmpeg's own validation at encode time.
@@ -154,9 +146,7 @@ def _validate_option_value(
return return
def _validate_extra_option( def _validate_extra_option(vcodec: str, key: str, value: Any, opt: av.option.Option) -> None:
vcodec: str, key: str, value: Any, opt: av.option.Option
) -> None:
"""Validate an ``extra_options`` entry: enforce numeric range/type only. """Validate an ``extra_options`` entry: enforce numeric range/type only.
Non-numeric options are passed through (FFmpeg accepts many ad-hoc strings). Non-numeric options are passed through (FFmpeg accepts many ad-hoc strings).
@@ -208,9 +198,7 @@ def _check_tuning_fields(
# Enforce a positive integer value. # Enforce a positive integer value.
if field_name == "g": if field_name == "g":
if isinstance(value, bool) or not isinstance(value, int) or value < 1: if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError( raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
f"g={value!r} must be a positive integer for codec {vcodec!r}"
)
continue continue
# Value shape is only cross-checkable when the field maps directly # Value shape is only cross-checkable when the field maps directly
# to a private option: ``preset`` is literally ``"preset"``; # to a private option: ``preset`` is literally ``"preset"``;
+18 -13
View File
@@ -100,7 +100,6 @@ class VideoEncoderConfig:
self.validate() self.validate()
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]: def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Detect available encoders based on the video backend.""" """Detect available encoders based on the video backend."""
if self.video_backend == "pyav": if self.video_backend == "pyav":
@@ -108,13 +107,11 @@ class VideoEncoderConfig:
else: else:
return [] return []
def validate(self) -> None: def validate(self) -> None:
"""Validate the video encoder config.""" """Validate the video encoder config."""
if self.video_backend == "pyav": if self.video_backend == "pyav":
check_video_encoder_config_pyav(self) check_video_encoder_config_pyav(self)
def resolve_vcodec(self) -> None: def resolve_vcodec(self) -> None:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1. """Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.
@@ -123,9 +120,7 @@ class VideoEncoderConfig:
a host missing the requested encoder. a host missing the requested encoder.
""" """
if self.vcodec not in VALID_VIDEO_CODECS: if self.vcodec not in VALID_VIDEO_CODECS:
raise ValueError( raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}"
)
if self.vcodec == "auto": if self.vcodec == "auto":
available = self.detect_available_encoders(HW_ENCODERS) available = self.detect_available_encoders(HW_ENCODERS)
for encoder in HW_ENCODERS: for encoder in HW_ENCODERS:
@@ -142,7 +137,6 @@ class VideoEncoderConfig:
return return
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}") raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
def get_codec_options(self, encoder_threads: int | None = None) -> dict[str, str]: def get_codec_options(self, encoder_threads: int | None = None) -> dict[str, str]:
"""Translate the tuning fields to codec-specific FFmpeg options. """Translate the tuning fields to codec-specific FFmpeg options.
@@ -563,7 +557,10 @@ def encode_video_frames(
def concatenate_video_files( def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True, compatibilty_check: bool = False input_video_paths: list[Path | str],
output_video_path: Path,
overwrite: bool = True,
compatibility_check: bool = False,
): ):
""" """
Concatenate multiple video files into a single video file using pyav. Concatenate multiple video files into a single video file using pyav.
@@ -576,7 +573,7 @@ def concatenate_video_files(
input_video_paths: Ordered list of input video file paths to concatenate. input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file. output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True. overwrite: Whether to overwrite the output video file if it already exists. Default is True.
compatibilty_check: Whether to check if the input videos are compatible. Default is False. compatibility_check: Whether to check if the input videos are compatible. Default is False.
Note: Note:
- Creates a temporary directory for intermediate files that is cleaned up after use. - Creates a temporary directory for intermediate files that is cleaned up after use.
@@ -596,12 +593,20 @@ def concatenate_video_files(
raise FileNotFoundError("No input video paths provided.") raise FileNotFoundError("No input video paths provided.")
# This check may be skipped at recording time as videos are encoded with the same encoder config. # This check may be skipped at recording time as videos are encoded with the same encoder config.
if compatibilty_check: if compatibility_check:
reference_video_info = get_video_info(input_video_paths[0]) reference_video_info = get_video_info(input_video_paths[0])
for input_path in input_video_paths[1:]: for input_path in input_video_paths[1:]:
video_info = get_video_info(input_path) video_info = get_video_info(input_path)
if video_info["video.height"] != reference_video_info["video.height"] or video_info["video.width"] != reference_video_info["video.width"] or video_info["video.fps"] != reference_video_info["video.fps"] or video_info["video.codec"] != reference_video_info["video.codec"] or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]: if (
raise ValueError(f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}.") video_info["video.height"] != reference_video_info["video.height"]
or video_info["video.width"] != reference_video_info["video.width"]
or video_info["video.fps"] != reference_video_info["video.fps"]
or video_info["video.codec"] != reference_video_info["video.codec"]
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
):
raise ValueError(
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
)
# Create a temporary .ffconcat file to list the input video paths # Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file: with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
@@ -1058,7 +1063,7 @@ def get_video_info(
Args: Args:
video_path: Path to the encoded video file to probe. video_path: Path to the encoded video file to probe.
camera_encoder_config: If provided, record the exact encoder settings used to encode this camera_encoder_config: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence encoder fields are only written for keys video. Stream-derived values take precedence encoder fields are only written for keys
not already populated from the video file itself. not already populated from the video file itself.
""" """
+15 -12
View File
@@ -28,7 +28,7 @@ import av # noqa: E402
from lerobot.datasets.image_writer import write_image from lerobot.datasets.image_writer import write_image
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pyav_utils import detect_available_encoders_pyav, get_codec from lerobot.datasets.pyav_utils import get_codec
from lerobot.datasets.utils import INFO_PATH from lerobot.datasets.utils import INFO_PATH
from lerobot.datasets.video_utils import ( from lerobot.datasets.video_utils import (
VALID_VIDEO_CODECS, VALID_VIDEO_CODECS,
@@ -38,12 +38,11 @@ from lerobot.datasets.video_utils import (
get_video_info, get_video_info,
) )
# Per-codec skip markers — validation tests only fire when the codec is available # Per-codec skip markers — validation tests only fire when the codec is available
def _require_encoder(vcodec: str) -> pytest.MarkDecorator: def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
"""Skip the test if ``vcodec`` is not available in the local FFmpeg build.""" """Skip the test if ``vcodec`` is not available in the local FFmpeg build."""
return pytest.mark.skipif( return pytest.mark.skipif(get_codec(vcodec) is None, reason=f"{vcodec!r} not in local FFmpeg build")
get_codec(vcodec) is None, reason=f"{vcodec!r} not in local FFmpeg build"
)
require_libsvtav1 = _require_encoder("libsvtav1") require_libsvtav1 = _require_encoder("libsvtav1")
@@ -306,7 +305,9 @@ def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width:
write_image(arr, imgs_dir / f"frame-{i:06d}.png") write_image(arr, imgs_dir / f"frame-{i:06d}.png")
def _encode_video(path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None) -> Path: def _encode_video(
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
) -> Path:
imgs_dir = path.parent / f"imgs_{path.stem}" imgs_dir = path.parent / f"imgs_{path.stem}"
_write_frames(imgs_dir, num_frames=num_frames) _write_frames(imgs_dir, num_frames=num_frames)
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder_config=cfg, overwrite=True) encode_video_frames(imgs_dir, path, fps=fps, camera_encoder_config=cfg, overwrite=True)
@@ -321,11 +322,13 @@ def _read_feature_info(dataset: LeRobotDataset) -> dict:
def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None: def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
shape = dataset.meta.features[VIDEO_KEY]["shape"] shape = dataset.meta.features[VIDEO_KEY]["shape"]
for _ in range(num_frames): for _ in range(num_frames):
dataset.add_frame({ dataset.add_frame(
VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8), {
"action": np.zeros(2, dtype=np.float32), VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8),
"task": "test", "action": np.zeros(2, dtype=np.float32),
}) "task": "test",
}
)
class TestGetVideoInfo: class TestGetVideoInfo:
@@ -480,7 +483,7 @@ class TestConcatenateVideoFiles:
concatenate_video_files( concatenate_video_files(
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_h264.mp4"], [ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_h264.mp4"],
tmp_path / "out.mp4", tmp_path / "out.mp4",
compatibilty_check=True, compatibility_check=True,
) )
def test_compatibility_check_raises_on_different_resolution(self, tmp_path): def test_compatibility_check_raises_on_different_resolution(self, tmp_path):
@@ -488,7 +491,7 @@ class TestConcatenateVideoFiles:
concatenate_video_files( concatenate_video_files(
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_32x48.mp4"], [ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_32x48.mp4"],
tmp_path / "out.mp4", tmp_path / "out.mp4",
compatibilty_check=True, compatibility_check=True,
) )