fix(annotate): decode keyframes via PyAV directly

The pyav fallback routed through lerobot's decode_video_frames(backend=
"pyav"), which uses torchvision.io.VideoReader — removed in torchvision
0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26)
both torchcodec and that path fail, leaving interjection/vqa prompts
without visual context.

Add _decode_frames_av: a self-contained PyAV decoder that picks the
nearest frame per timestamp. It is the always-available tail of the
decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:45:04 +02:00
parent 31e0c15e55
commit 7128bb1769
2 changed files with 103 additions and 9 deletions
+56
View File
@@ -29,12 +29,18 @@ LeRobot dataset on disk.
from __future__ import annotations
import shutil
import subprocess
from pathlib import Path
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
VideoFrameProvider,
_decode_frames_av,
)
@@ -64,3 +70,53 @@ def test_videoframeprovider_has_a_lock_for_concurrent_use():
)
assert lock_field is not None
assert lock_field.default_factory is threading.Lock
@pytest.fixture
def sample_video(tmp_path: Path) -> Path:
"""A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
if shutil.which("ffmpeg") is None:
pytest.skip("ffmpeg not available")
out = tmp_path / "sample.mp4"
subprocess.run(
[
"ffmpeg", "-y", "-f", "lavfi",
"-i", "testsrc=duration=3:size=160x120:rate=10",
"-pix_fmt", "yuv420p", str(out),
],
check=True,
capture_output=True,
)
return out
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
"""``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.
This is the always-available fallback: torchcodec is unusable in some
containers and lerobot's ``pyav`` backend routes through the removed
``torchvision.io.VideoReader``.
"""
timestamps = [0.0, 1.0, 2.5]
frames = _decode_frames_av(sample_video, timestamps)
assert len(frames) == len(timestamps)
for frame in frames:
assert isinstance(frame, torch.Tensor)
assert frame.dtype == torch.uint8
assert frame.shape == (3, 120, 160)
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
"""Repeated and out-of-order timestamps each resolve to the nearest frame."""
frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])
assert len(frames) == 3
assert torch.equal(frames[0], frames[2])
assert not torch.equal(frames[0], frames[1])
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
"""A missing video surfaces as an exception the caller can fall back on."""
with pytest.raises(Exception): # noqa: B017, PT011
_decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])