review: decode keyframes via video_utils.decode_video_frames

Addresses three of CarolinePascal's frames.py comments (the fourth, the
subprocess re-encode, waits on #3611):

- replace the bespoke _decode_pyav_direct PyAV decoder with
  lerobot.datasets.video_utils.decode_video_frames (torchcodec backend,
  PyAV fallback) — the torchvision VideoReader removal that originally
  motivated the bespoke decoder is no longer a reason to bypass it
- frames flow through the provider as torch.Tensor (C, H, W uint8); PIL
  is materialised only at the VLM-message boundary in to_image_blocks /
  to_video_block, where the chat backends need it
- _decode now returns exactly one frame per timestamp (or [] on failure),
  so frames_at pairs them with strict=True

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn Kooijmans
2026-05-18 14:00:38 +02:00
parent fd18beb3a1
commit 9dfc9084e1
2 changed files with 80 additions and 97 deletions
@@ -24,13 +24,21 @@ querying the same timestamp pay decode cost once.
from __future__ import annotations from __future__ import annotations
import logging
import threading import threading
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Protocol from typing import Any, Protocol
import PIL.Image
import torch
from lerobot.datasets.video_utils import decode_video_frames
from .reader import EpisodeRecord from .reader import EpisodeRecord
logger = logging.getLogger(__name__)
class FrameProvider(Protocol): class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps.""" """Decodes camera frames at episode-relative timestamps."""
@@ -45,7 +53,12 @@ class FrameProvider(Protocol):
timestamps: list[float], timestamps: list[float],
camera_key: str | None = None, camera_key: str | None = None,
) -> list[Any]: ) -> list[Any]:
"""Return one PIL.Image per timestamp from ``camera_key`` (or default). """Return one decoded frame per timestamp from ``camera_key`` (or default).
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
:func:`to_image_blocks` converts them to PIL only at the VLM-message
boundary.
Empty list if the camera is unavailable. ``camera_key=None`` falls back Empty list if the camera is unavailable. ``camera_key=None`` falls back
to the provider's default camera so existing single-camera callers to the provider's default camera so existing single-camera callers
@@ -58,12 +71,13 @@ class FrameProvider(Protocol):
max_frames: int, max_frames: int,
camera_key: str | None = None, camera_key: str | None = None,
) -> list[Any]: ) -> list[Any]:
"""Return up to ``max_frames`` PIL images covering the whole episode. """Return up to ``max_frames`` decoded frames covering the whole episode.
Sampling is uniform across the episode duration. The returned list is Sampling is uniform across the episode duration. Frames are
intended to be passed as one ``{"type":"video", "video":<list>}`` ``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
block to a Qwen-VL-compatible model that pools temporally itself. them into one ``{"type":"video", "video":<list>}`` block for a
Empty list if no camera available. Qwen-VL-compatible model that pools temporally itself. Empty list if
no camera available.
""" """
@@ -176,18 +190,20 @@ class VideoFrameProvider:
if misses: if misses:
decoded = self._decode(record.episode_index, misses, target) decoded = self._decode(record.episode_index, misses, target)
# decoder may return fewer frames than requested when some # ``_decode`` returns exactly one frame per requested timestamp,
# timestamps fall outside the video; pair what we have and # or an empty list if decoding failed wholesale. A partial list
# leave the rest as None to be filtered below. # would mean a frame/timestamp misalignment, so only pair them up
with self._lock: # when the counts match (``strict=True`` then guards regressions).
for i, img in zip(miss_indices, decoded, strict=False): if len(decoded) == len(miss_indices):
out[i] = img with self._lock:
key = (record.episode_index, target, round(float(timestamps[i]), 6)) for i, frame in zip(miss_indices, decoded, strict=True):
if len(self._cache) >= self.cache_size: out[i] = frame
self._cache.pop(next(iter(self._cache))) key = (record.episode_index, target, round(float(timestamps[i]), 6))
self._cache[key] = img if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = frame
# filter out any None left over from decode failures # filter out any None left over from decode failures
return [img for img in out if img is not None] return [frame for frame in out if frame is not None]
def video_for_episode( def video_for_episode(
self, self,
@@ -195,10 +211,11 @@ class VideoFrameProvider:
max_frames: int, max_frames: int,
camera_key: str | None = None, camera_key: str | None = None,
) -> list[Any]: ) -> list[Any]:
"""Return up to ``max_frames`` images uniformly sampled across the episode. """Return up to ``max_frames`` frames uniformly sampled across the episode.
The whole episode duration is covered; the model picks subtask The whole episode duration is covered; the model picks subtask
boundaries from the temporal pooling it does internally. boundaries from the temporal pooling it does internally. Frames are
``torch.Tensor`` (see :meth:`frames_at`).
""" """
target = camera_key if camera_key is not None else self.camera_key target = camera_key if camera_key is not None else self.camera_key
if max_frames <= 0 or target is None or not record.frame_timestamps: if max_frames <= 0 or target is None or not record.frame_timestamps:
@@ -267,13 +284,22 @@ class VideoFrameProvider:
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]: def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
Returns one frame per requested timestamp, or ``[]`` if decoding
failed wholesale — callers treat ``[]`` as "no frames available".
"""
ep = self._meta.episodes[episode_index] ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"] from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps] shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key) video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
try: try:
return _decode_pyav_direct(video_path, shifted, self.tolerance_s) # Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(video_path, shifted, self.tolerance_s, return_uint8=True)
return list(decoded)
except Exception as exc: except Exception as exc:
# Log loudly the first time decoding fails so a silent # Log loudly the first time decoding fails so a silent
# vqa-module no-op (every prompt skipped because frames_at # vqa-module no-op (every prompt skipped because frames_at
@@ -284,9 +310,7 @@ class VideoFrameProvider:
if not already_warned: if not already_warned:
self._warned_decode_fail = True self._warned_decode_fail = True
if not already_warned: if not already_warned:
import logging # noqa: PLC0415 logger.warning(
logging.getLogger(__name__).warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s: %s", "VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s: %s",
episode_index, episode_index,
camera_key, camera_key,
@@ -297,64 +321,6 @@ class VideoFrameProvider:
return [] return []
def _decode_pyav_direct(video_path: Any, timestamps: list[float], tolerance_s: float) -> list[Any]:
"""Decode the requested timestamps from ``video_path`` using PyAV directly.
Bypasses ``lerobot.datasets.video_utils.decode_video_frames`` entirely
because its "pyav" path actually goes through
``decode_video_frames_torchvision`` → ``torchvision.io.VideoReader``,
which was removed in torchvision >= 0.22 (the vllm/vllm-openai:latest
container ships with torchvision 0.25). The annotation pipeline only
needs a handful of PIL images per (episode, ts), so we can decode them
with PyAV without any torch dependency at all.
Returns one ``PIL.Image`` per requested timestamp, in the same order.
Any timestamp the decoder couldn't reach is silently dropped (mirrors
the previous behaviour); callers filter ``None``/missing entries.
"""
import av # noqa: PLC0415
if not timestamps:
return []
targets = sorted(set(timestamps))
seek_to = max(0.0, min(targets) - max(0.5, tolerance_s))
container = av.open(str(video_path))
try:
stream = container.streams.video[0]
# PyAV needs the seek target in stream timebase ticks.
seek_pts = 0 if stream.time_base is None else int(seek_to / float(stream.time_base))
try:
container.seek(seek_pts, any_frame=False, backward=True, stream=stream)
except av.AVError:
# Some streams reject the explicit seek; fall back to decoding from start.
container.seek(0)
results: dict[float, Any] = {}
target_iter = iter(targets)
next_target = next(target_iter, None)
for frame in container.decode(stream):
if next_target is None:
break
ts = float(frame.pts * frame.time_base) if frame.pts is not None else None
if ts is None:
continue
# Walk past targets we've already overshot — we keep the closest
# frame within tolerance.
while next_target is not None and ts >= next_target - tolerance_s:
if abs(ts - next_target) <= tolerance_s or ts >= next_target:
img = frame.to_image() # PIL.Image.Image (RGB)
results.setdefault(next_target, img)
next_target = next(target_iter, None)
else:
break
finally:
container.close()
return [results[ts] for ts in timestamps if ts in results]
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider: def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null.""" """Build a :class:`VideoFrameProvider` if videos are present, else null."""
try: try:
@@ -366,20 +332,38 @@ def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvi
return provider return provider
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]: def _frame_to_pil(frame: Any) -> Any:
"""Convert PIL images to Qwen-VL-compatible content blocks.""" """Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
return [{"type": "image", "image": img} for img in images]
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
straight from :func:`decode_video_frames`); PIL is only created here, at
the VLM-message boundary, because the chat backends expect PIL images /
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
"""
if not isinstance(frame, torch.Tensor):
return frame
array = frame.detach().cpu()
if array.ndim == 3 and array.shape[0] in (1, 3):
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
if array.shape[-1] == 1:
array = array.squeeze(-1)
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
def to_video_block(images: list[Any]) -> list[dict[str, Any]]: def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of PIL images as one Qwen-VL video block. """Convert decoded frames to Qwen-VL-compatible image content blocks."""
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of decoded frames as one Qwen-VL video block.
Returns ``[]`` when the list is empty, so the caller can splat the result Returns ``[]`` when the list is empty, so the caller can splat the result
into a content array without a separate emptiness check. into a content array without a separate emptiness check.
""" """
if not images: if not frames:
return [] return []
return [{"type": "video", "video": list(images)}] return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]: def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
+8 -9
View File
@@ -15,13 +15,12 @@
# limitations under the License. # limitations under the License.
"""Unit tests for :class:`VideoFrameProvider` method bindings. """Unit tests for :class:`VideoFrameProvider` method bindings.
These were prompted by a real regression: ``video_for_episode`` was These were prompted by a real regression: ``video_for_episode`` was once
indented one level too deep so it ended up nested *inside* the indented one level too deep so it ended up nested *inside* a module-level
``_decode_pyav_direct`` helper (after that function's ``return`` helper (after that function's ``return`` statement) — silently dead code
statement) silently dead code that meant production runs with that meant production runs with ``use_video_url=False`` would
``use_video_url=False`` would ``AttributeError`` on ``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
``self.frame_provider.video_for_episode(...)``. The existing module existing module tests didn't catch it because they exercise stub providers.
tests didn't catch it because they exercise stub providers.
The tests below assert on the class itself (not on an instance), so a The tests below assert on the class itself (not on an instance), so a
future reindent regression flips them to red without needing a real future reindent regression flips them to red without needing a real
@@ -51,8 +50,8 @@ def test_episode_clip_path_is_a_method_of_videoframeprovider():
def test_videoframeprovider_has_a_lock_for_concurrent_use(): def test_videoframeprovider_has_a_lock_for_concurrent_use():
"""A ``ThreadPoolExecutor`` runs Module 1/2/3 phases concurrently; """A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
the cache + warn-flag accesses must be guarded. concurrently; the cache + warn-flag accesses must be guarded.
""" """
import threading import threading