From 9dfc9084e1b955d965382087d87b865b5caf9a2a Mon Sep 17 00:00:00 2001 From: Pepijn Kooijmans Date: Mon, 18 May 2026 14:00:38 +0200 Subject: [PATCH] review: decode keyframes via video_utils.decode_video_frames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses three of CarolinePascal's frames.py comments (the fourth, the subprocess re-encode, waits on #3611): - replace the bespoke _decode_pyav_direct PyAV decoder with lerobot.datasets.video_utils.decode_video_frames (torchcodec backend, PyAV fallback) — torchvision's VideoReader removal no longer applies - frames flow through the provider as torch.Tensor (C, H, W uint8); PIL is materialised only at the VLM-message boundary in to_image_blocks / to_video_block, where the chat backends need it - _decode now returns exactly one frame per timestamp (or [] on failure), so frames_at pairs them with strict=True Co-Authored-By: Claude Opus 4.7 (1M context) --- .../annotations/steerable_pipeline/frames.py | 160 ++++++++---------- tests/annotations/test_frames.py | 17 +- 2 files changed, 80 insertions(+), 97 deletions(-) diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py index fedc8109c..51092b5a7 100644 --- a/src/lerobot/annotations/steerable_pipeline/frames.py +++ b/src/lerobot/annotations/steerable_pipeline/frames.py @@ -24,13 +24,21 @@ querying the same timestamp pay decode cost once. from __future__ import annotations +import logging import threading from dataclasses import dataclass, field from pathlib import Path from typing import Any, Protocol +import PIL.Image +import torch + +from lerobot.datasets.video_utils import decode_video_frames + from .reader import EpisodeRecord +logger = logging.getLogger(__name__) + class FrameProvider(Protocol): """Decodes camera frames at episode-relative timestamps.""" @@ -45,7 +53,12 @@ class FrameProvider(Protocol): timestamps: list[float], camera_key: str | None = None, ) -> list[Any]: - """Return one PIL.Image per timestamp from ``camera_key`` (or default). + """Return one decoded frame per timestamp from ``camera_key`` (or default). + + Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape + :func:`lerobot.datasets.video_utils.decode_video_frames` returns. + :func:`to_image_blocks` converts them to PIL only at the VLM-message + boundary. Empty list if the camera is unavailable. ``camera_key=None`` falls back to the provider's default camera so existing single-camera callers @@ -58,12 +71,13 @@ class FrameProvider(Protocol): max_frames: int, camera_key: str | None = None, ) -> list[Any]: - """Return up to ``max_frames`` PIL images covering the whole episode. + """Return up to ``max_frames`` decoded frames covering the whole episode. - Sampling is uniform across the episode duration. The returned list is - intended to be passed as one ``{"type":"video", "video":}`` - block to a Qwen-VL-compatible model that pools temporally itself. - Empty list if no camera available. + Sampling is uniform across the episode duration. Frames are + ``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps + them into one ``{"type":"video", "video":}`` block for a + Qwen-VL-compatible model that pools temporally itself. Empty list if + no camera available. """ @@ -176,18 +190,20 @@ class VideoFrameProvider: if misses: decoded = self._decode(record.episode_index, misses, target) - # decoder may return fewer frames than requested when some - # timestamps fall outside the video; pair what we have and - # leave the rest as None to be filtered below. - with self._lock: - for i, img in zip(miss_indices, decoded, strict=False): - out[i] = img - key = (record.episode_index, target, round(float(timestamps[i]), 6)) - if len(self._cache) >= self.cache_size: - self._cache.pop(next(iter(self._cache))) - self._cache[key] = img + # ``_decode`` returns exactly one frame per requested timestamp, + # or an empty list if decoding failed wholesale. A partial list + # would mean a frame/timestamp misalignment, so only pair them up + # when the counts match (``strict=True`` then guards regressions). + if len(decoded) == len(miss_indices): + with self._lock: + for i, frame in zip(miss_indices, decoded, strict=True): + out[i] = frame + key = (record.episode_index, target, round(float(timestamps[i]), 6)) + if len(self._cache) >= self.cache_size: + self._cache.pop(next(iter(self._cache))) + self._cache[key] = frame # filter out any None left over from decode failures - return [img for img in out if img is not None] + return [frame for frame in out if frame is not None] def video_for_episode( self, @@ -195,10 +211,11 @@ class VideoFrameProvider: max_frames: int, camera_key: str | None = None, ) -> list[Any]: - """Return up to ``max_frames`` images uniformly sampled across the episode. + """Return up to ``max_frames`` frames uniformly sampled across the episode. The whole episode duration is covered; the model picks subtask - boundaries from the temporal pooling it does internally. + boundaries from the temporal pooling it does internally. Frames are + ``torch.Tensor`` (see :meth:`frames_at`). """ target = camera_key if camera_key is not None else self.camera_key if max_frames <= 0 or target is None or not record.frame_timestamps: @@ -267,13 +284,22 @@ class VideoFrameProvider: return out_path if out_path.exists() and out_path.stat().st_size > 0 else None def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]: + """Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors. + + Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames` + (torchcodec by default, PyAV fallback) rather than a bespoke decoder. + Returns one frame per requested timestamp, or ``[]`` if decoding + failed wholesale — callers treat ``[]`` as "no frames available". + """ ep = self._meta.episodes[episode_index] from_timestamp = ep[f"videos/{camera_key}/from_timestamp"] shifted = [from_timestamp + ts for ts in timestamps] video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key) try: - return _decode_pyav_direct(video_path, shifted, self.tolerance_s) + # Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp. + decoded = decode_video_frames(video_path, shifted, self.tolerance_s, return_uint8=True) + return list(decoded) except Exception as exc: # Log loudly the first time decoding fails so a silent # vqa-module no-op (every prompt skipped because frames_at @@ -284,9 +310,7 @@ class VideoFrameProvider: if not already_warned: self._warned_decode_fail = True if not already_warned: - import logging # noqa: PLC0415 - - logging.getLogger(__name__).warning( + logger.warning( "VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s: %s", episode_index, camera_key, @@ -297,64 +321,6 @@ class VideoFrameProvider: return [] -def _decode_pyav_direct(video_path: Any, timestamps: list[float], tolerance_s: float) -> list[Any]: - """Decode the requested timestamps from ``video_path`` using PyAV directly. - - Bypasses ``lerobot.datasets.video_utils.decode_video_frames`` entirely - because its "pyav" path actually goes through - ``decode_video_frames_torchvision`` → ``torchvision.io.VideoReader``, - which was removed in torchvision >= 0.22 (the vllm/vllm-openai:latest - container ships with torchvision 0.25). The annotation pipeline only - needs a handful of PIL images per (episode, ts), so we can decode them - with PyAV without any torch dependency at all. - - Returns one ``PIL.Image`` per requested timestamp, in the same order. - Any timestamp the decoder couldn't reach is silently dropped (mirrors - the previous behaviour); callers filter ``None``/missing entries. - """ - import av # noqa: PLC0415 - - if not timestamps: - return [] - - targets = sorted(set(timestamps)) - seek_to = max(0.0, min(targets) - max(0.5, tolerance_s)) - - container = av.open(str(video_path)) - try: - stream = container.streams.video[0] - # PyAV needs the seek target in stream timebase ticks. - seek_pts = 0 if stream.time_base is None else int(seek_to / float(stream.time_base)) - try: - container.seek(seek_pts, any_frame=False, backward=True, stream=stream) - except av.AVError: - # Some streams reject the explicit seek; fall back to decoding from start. - container.seek(0) - - results: dict[float, Any] = {} - target_iter = iter(targets) - next_target = next(target_iter, None) - for frame in container.decode(stream): - if next_target is None: - break - ts = float(frame.pts * frame.time_base) if frame.pts is not None else None - if ts is None: - continue - # Walk past targets we've already overshot — we keep the closest - # frame within tolerance. - while next_target is not None and ts >= next_target - tolerance_s: - if abs(ts - next_target) <= tolerance_s or ts >= next_target: - img = frame.to_image() # PIL.Image.Image (RGB) - results.setdefault(next_target, img) - next_target = next(target_iter, None) - else: - break - finally: - container.close() - - return [results[ts] for ts in timestamps if ts in results] - - def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider: """Build a :class:`VideoFrameProvider` if videos are present, else null.""" try: @@ -366,20 +332,38 @@ def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvi return provider -def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]: - """Convert PIL images to Qwen-VL-compatible content blocks.""" - return [{"type": "image", "image": img} for img in images] +def _frame_to_pil(frame: Any) -> Any: + """Materialise a decoded frame as a ``PIL.Image`` for the VLM message. + + Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8, + straight from :func:`decode_video_frames`); PIL is only created here, at + the VLM-message boundary, because the chat backends expect PIL images / + data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched. + """ + if not isinstance(frame, torch.Tensor): + return frame + array = frame.detach().cpu() + if array.ndim == 3 and array.shape[0] in (1, 3): + array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C) + if array.shape[-1] == 1: + array = array.squeeze(-1) + return PIL.Image.fromarray(array.to(torch.uint8).numpy()) -def to_video_block(images: list[Any]) -> list[dict[str, Any]]: - """Wrap a list of PIL images as one Qwen-VL video block. +def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]: + """Convert decoded frames to Qwen-VL-compatible image content blocks.""" + return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames] + + +def to_video_block(frames: list[Any]) -> list[dict[str, Any]]: + """Wrap a list of decoded frames as one Qwen-VL video block. Returns ``[]`` when the list is empty, so the caller can splat the result into a content array without a separate emptiness check. """ - if not images: + if not frames: return [] - return [{"type": "video", "video": list(images)}] + return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}] def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]: diff --git a/tests/annotations/test_frames.py b/tests/annotations/test_frames.py index af2833c23..c0ed96ab3 100644 --- a/tests/annotations/test_frames.py +++ b/tests/annotations/test_frames.py @@ -15,13 +15,12 @@ # limitations under the License. """Unit tests for :class:`VideoFrameProvider` method bindings. -These were prompted by a real regression: ``video_for_episode`` was -indented one level too deep so it ended up nested *inside* the -``_decode_pyav_direct`` helper (after that function's ``return`` -statement) — silently dead code that meant production runs with -``use_video_url=False`` would ``AttributeError`` on -``self.frame_provider.video_for_episode(...)``. The existing module -tests didn't catch it because they exercise stub providers. +These were prompted by a real regression: ``video_for_episode`` was once +indented one level too deep so it ended up nested *inside* a module-level +helper (after that function's ``return`` statement) — silently dead code +that meant production runs with ``use_video_url=False`` would +``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The +existing module tests didn't catch it because they exercise stub providers. The tests below assert on the class itself (not on an instance), so a future reindent regression flips them to red without needing a real @@ -51,8 +50,8 @@ def test_episode_clip_path_is_a_method_of_videoframeprovider(): def test_videoframeprovider_has_a_lock_for_concurrent_use(): - """A ``ThreadPoolExecutor`` runs Module 1/2/3 phases concurrently; - the cache + warn-flag accesses must be guarded. + """A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases + concurrently; the cache + warn-flag accesses must be guarded. """ import threading