mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
fix(annotate): decode keyframes via PyAV directly
The pyav fallback routed through lerobot's decode_video_frames(backend= "pyav"), which uses torchvision.io.VideoReader — removed in torchvision 0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26) both torchcodec and that path fail, leaving interjection/vqa prompts without visual context. Add _decode_frames_av: a self-contained PyAV decoder that picks the nearest frame per timestamp. It is the always-available tail of the decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -303,19 +303,21 @@ class VideoFrameProvider:
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# When no backend is pinned, try the platform default first and fall
|
||||
# back to ``pyav`` if it raises — torchcodec is broken in some
|
||||
# containers (e.g. vllm-openai), where pyav decodes the same file fine.
|
||||
# Build the decoder chain. torchcodec is fast but unusable in some
|
||||
# containers (vllm-openai: "Operation not permitted"); lerobot's
|
||||
# ``pyav`` backend routes through ``torchvision.io.VideoReader``,
|
||||
# removed in torchvision 0.23+. ``_decode_frames_av`` talks to the
|
||||
# ``av`` package directly and is the always-available fallback.
|
||||
if self.video_backend:
|
||||
backends: list[str | None] = [self.video_backend]
|
||||
chain = [self.video_backend]
|
||||
else:
|
||||
backends = [None]
|
||||
if get_safe_default_codec() != "pyav":
|
||||
backends.append("pyav")
|
||||
chain = (["torchcodec"] if get_safe_default_codec() == "torchcodec" else []) + ["pyav"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in backends:
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
@@ -339,7 +341,7 @@ class VideoFrameProvider:
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
backends,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
@@ -359,6 +361,42 @@ def make_frame_provider(
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly so keyframe extraction keeps working
|
||||
on modern torch/torchvision stacks and in containers where torchcodec
|
||||
cannot decode. Returns one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
|
||||
@@ -29,12 +29,18 @@ LeRobot dataset on disk.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
|
||||
VideoFrameProvider,
|
||||
_decode_frames_av,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,3 +70,53 @@ def test_videoframeprovider_has_a_lock_for_concurrent_use():
|
||||
)
|
||||
assert lock_field is not None
|
||||
assert lock_field.default_factory is threading.Lock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_video(tmp_path: Path) -> Path:
|
||||
"""A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
pytest.skip("ffmpeg not available")
|
||||
out = tmp_path / "sample.mp4"
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-y", "-f", "lavfi",
|
||||
"-i", "testsrc=duration=3:size=160x120:rate=10",
|
||||
"-pix_fmt", "yuv420p", str(out),
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
|
||||
"""``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.
|
||||
|
||||
This is the always-available fallback: torchcodec is unusable in some
|
||||
containers and lerobot's ``pyav`` backend routes through the removed
|
||||
``torchvision.io.VideoReader``.
|
||||
"""
|
||||
timestamps = [0.0, 1.0, 2.5]
|
||||
frames = _decode_frames_av(sample_video, timestamps)
|
||||
|
||||
assert len(frames) == len(timestamps)
|
||||
for frame in frames:
|
||||
assert isinstance(frame, torch.Tensor)
|
||||
assert frame.dtype == torch.uint8
|
||||
assert frame.shape == (3, 120, 160)
|
||||
|
||||
|
||||
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
|
||||
"""Repeated and out-of-order timestamps each resolve to the nearest frame."""
|
||||
frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])
|
||||
|
||||
assert len(frames) == 3
|
||||
assert torch.equal(frames[0], frames[2])
|
||||
assert not torch.equal(frames[0], frames[1])
|
||||
|
||||
|
||||
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
|
||||
"""A missing video surfaces as an exception the caller can fall back on."""
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
_decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])
|
||||
|
||||
Reference in New Issue
Block a user