From 7128bb17693b4abaffded79b3ac5fadf041d6f87 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 18 May 2026 15:45:04 +0200 Subject: [PATCH] fix(annotate): decode keyframes via PyAV directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pyav fallback routed through lerobot's decode_video_frames(backend= "pyav"), which uses torchvision.io.VideoReader — removed in torchvision 0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26) both torchcodec and that path fail, leaving interjection/vqa prompts without visual context. Add _decode_frames_av: a self-contained PyAV decoder that picks the nearest frame per timestamp. It is the always-available tail of the decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../annotations/steerable_pipeline/frames.py | 56 ++++++++++++++++--- tests/annotations/test_frames.py | 56 +++++++++++++++++++ 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py index b2045fd7d..e918e94cf 100644 --- a/src/lerobot/annotations/steerable_pipeline/frames.py +++ b/src/lerobot/annotations/steerable_pipeline/frames.py @@ -303,19 +303,21 @@ class VideoFrameProvider: shifted = [from_timestamp + ts for ts in timestamps] video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key) - # When no backend is pinned, try the platform default first and fall - # back to ``pyav`` if it raises — torchcodec is broken in some - # containers (e.g. vllm-openai), where pyav decodes the same file fine. + # Build the decoder chain. torchcodec is fast but unusable in some + # containers (vllm-openai: "Operation not permitted"); lerobot's + # ``pyav`` backend routes through ``torchvision.io.VideoReader``, + # removed in torchvision 0.23+. ``_decode_frames_av`` talks to the + # ``av`` package directly and is the always-available fallback. if self.video_backend: - backends: list[str | None] = [self.video_backend] + chain = [self.video_backend] else: - backends = [None] - if get_safe_default_codec() != "pyav": - backends.append("pyav") + chain = (["torchcodec"] if get_safe_default_codec() == "torchcodec" else []) + ["pyav"] exc: Exception | None = None - for backend in backends: + for backend in chain: try: + if backend in ("pyav", "av"): + return _decode_frames_av(video_path, shifted) # Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp. decoded = decode_video_frames( video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True @@ -339,7 +341,7 @@ class VideoFrameProvider: episode_index, camera_key, video_path, - backends, + chain, exc, exc_info=exc, ) @@ -359,6 +361,42 @@ def make_frame_provider( return provider +def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]: + """Decode the frames nearest to ``timestamps`` using PyAV directly. + + lerobot's ``decode_video_frames(backend="pyav")`` routes through + ``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper + talks to the ``av`` package directly so keyframe extraction keeps working + on modern torch/torchvision stacks and in containers where torchcodec + cannot decode. Returns one ``(C, H, W)`` uint8 tensor per timestamp. + """ + import av # noqa: PLC0415 + + first_ts = min(timestamps) + last_ts = max(timestamps) + loaded_frames: list[torch.Tensor] = [] + loaded_ts: list[float] = [] + with av.open(str(video_path)) as container: + stream = container.streams.video[0] + # Seek to the keyframe at or before the first requested timestamp. + offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0 + container.seek(offset, stream=stream, backward=True, any_frame=False) + for idx, frame in enumerate(container.decode(stream)): + ts = frame.time + if ts is None: + ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx) + loaded_ts.append(ts) + loaded_frames.append( + torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous() + ) + if ts >= last_ts: + break + if not loaded_frames: + raise RuntimeError(f"PyAV decoded no frames from {video_path}") + ts_tensor = torch.tensor(loaded_ts) + return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps] + + def _frame_to_pil(frame: Any) -> Any: """Materialise a decoded frame as a ``PIL.Image`` for the VLM message. diff --git a/tests/annotations/test_frames.py b/tests/annotations/test_frames.py index c0ed96ab3..61b941579 100644 --- a/tests/annotations/test_frames.py +++ b/tests/annotations/test_frames.py @@ -29,12 +29,18 @@ LeRobot dataset on disk. from __future__ import annotations +import shutil +import subprocess +from pathlib import Path + import pytest +import torch pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402 VideoFrameProvider, + _decode_frames_av, ) @@ -64,3 +70,53 @@ def test_videoframeprovider_has_a_lock_for_concurrent_use(): ) assert lock_field is not None assert lock_field.default_factory is threading.Lock + + +@pytest.fixture +def sample_video(tmp_path: Path) -> Path: + """A 3 s 10 fps test-pattern mp4, written with ffmpeg.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + out = tmp_path / "sample.mp4" + subprocess.run( + [ + "ffmpeg", "-y", "-f", "lavfi", + "-i", "testsrc=duration=3:size=160x120:rate=10", + "-pix_fmt", "yuv420p", str(out), + ], + check=True, + capture_output=True, + ) + return out + + +def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None: + """``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision. + + This is the always-available fallback: torchcodec is unusable in some + containers and lerobot's ``pyav`` backend routes through the removed + ``torchvision.io.VideoReader``. + """ + timestamps = [0.0, 1.0, 2.5] + frames = _decode_frames_av(sample_video, timestamps) + + assert len(frames) == len(timestamps) + for frame in frames: + assert isinstance(frame, torch.Tensor) + assert frame.dtype == torch.uint8 + assert frame.shape == (3, 120, 160) + + +def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None: + """Repeated and out-of-order timestamps each resolve to the nearest frame.""" + frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0]) + + assert len(frames) == 3 + assert torch.equal(frames[0], frames[2]) + assert not torch.equal(frames[0], frames[1]) + + +def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None: + """A missing video surfaces as an exception the caller can fall back on.""" + with pytest.raises(Exception): # noqa: B017, PT011 + _decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])