fix(annotate): decode keyframes via PyAV directly

The pyav fallback routed through lerobot's decode_video_frames(backend=
"pyav"), which uses torchvision.io.VideoReader — removed in torchvision
0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26)
both torchcodec and that path fail, leaving interjection/vqa prompts
without visual context.

Add _decode_frames_av: a self-contained PyAV decoder that picks the
nearest frame per timestamp. It is the always-available tail of the
decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:45:04 +02:00
parent 31e0c15e55
commit 7128bb1769
2 changed files with 103 additions and 9 deletions
@@ -303,19 +303,21 @@ class VideoFrameProvider:
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
# When no backend is pinned, try the platform default first and fall
# back to ``pyav`` if it raises — torchcodec is broken in some
# containers (e.g. vllm-openai), where pyav decodes the same file fine.
# Build the decoder chain. torchcodec is fast but unusable in some
# containers (vllm-openai: "Operation not permitted"); lerobot's
# ``pyav`` backend routes through ``torchvision.io.VideoReader``,
# removed in torchvision 0.23+. ``_decode_frames_av`` talks to the
# ``av`` package directly and is the always-available fallback.
if self.video_backend:
backends: list[str | None] = [self.video_backend]
chain = [self.video_backend]
else:
backends = [None]
if get_safe_default_codec() != "pyav":
backends.append("pyav")
chain = (["torchcodec"] if get_safe_default_codec() == "torchcodec" else []) + ["pyav"]
exc: Exception | None = None
for backend in backends:
for backend in chain:
try:
if backend in ("pyav", "av"):
return _decode_frames_av(video_path, shifted)
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
@@ -339,7 +341,7 @@ class VideoFrameProvider:
episode_index,
camera_key,
video_path,
backends,
chain,
exc,
exc_info=exc,
)
@@ -359,6 +361,42 @@ def make_frame_provider(
return provider
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
lerobot's ``decode_video_frames(backend="pyav")`` routes through
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
talks to the ``av`` package directly so keyframe extraction keeps working
on modern torch/torchvision stacks and in containers where torchcodec
cannot decode. Returns one ``(C, H, W)`` uint8 tensor per timestamp.
"""
import av # noqa: PLC0415
first_ts = min(timestamps)
last_ts = max(timestamps)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
with av.open(str(video_path)) as container:
stream = container.streams.video[0]
# Seek to the keyframe at or before the first requested timestamp.
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
container.seek(offset, stream=stream, backward=True, any_frame=False)
for idx, frame in enumerate(container.decode(stream)):
ts = frame.time
if ts is None:
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
loaded_ts.append(ts)
loaded_frames.append(
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
)
if ts >= last_ts:
break
if not loaded_frames:
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
ts_tensor = torch.tensor(loaded_ts)
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
def _frame_to_pil(frame: Any) -> Any:
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
+56
View File
@@ -29,12 +29,18 @@ LeRobot dataset on disk.
from __future__ import annotations
import shutil
import subprocess
from pathlib import Path
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
VideoFrameProvider,
_decode_frames_av,
)
@@ -64,3 +70,53 @@ def test_videoframeprovider_has_a_lock_for_concurrent_use():
)
assert lock_field is not None
assert lock_field.default_factory is threading.Lock
@pytest.fixture
def sample_video(tmp_path: Path) -> Path:
"""A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
if shutil.which("ffmpeg") is None:
pytest.skip("ffmpeg not available")
out = tmp_path / "sample.mp4"
subprocess.run(
[
"ffmpeg", "-y", "-f", "lavfi",
"-i", "testsrc=duration=3:size=160x120:rate=10",
"-pix_fmt", "yuv420p", str(out),
],
check=True,
capture_output=True,
)
return out
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
"""``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.
This is the always-available fallback: torchcodec is unusable in some
containers and lerobot's ``pyav`` backend routes through the removed
``torchvision.io.VideoReader``.
"""
timestamps = [0.0, 1.0, 2.5]
frames = _decode_frames_av(sample_video, timestamps)
assert len(frames) == len(timestamps)
for frame in frames:
assert isinstance(frame, torch.Tensor)
assert frame.dtype == torch.uint8
assert frame.shape == (3, 120, 160)
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
"""Repeated and out-of-order timestamps each resolve to the nearest frame."""
frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])
assert len(frames) == 3
assert torch.equal(frames[0], frames[2])
assert not torch.equal(frames[0], frames[1])
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
"""A missing video surfaces as an exception the caller can fall back on."""
with pytest.raises(Exception): # noqa: B017, PT011
_decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])