mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
review: decode keyframes via video_utils.decode_video_frames
Addresses three of CarolinePascal's frames.py comments (the fourth, the subprocess re-encode, waits on #3611): - replace the bespoke _decode_pyav_direct PyAV decoder with lerobot.datasets.video_utils.decode_video_frames (torchcodec backend, PyAV fallback) — torchvision's VideoReader removal no longer applies - frames flow through the provider as torch.Tensor (C, H, W uint8); PIL is materialised only at the VLM-message boundary in to_image_blocks / to_video_block, where the chat backends need it - _decode now returns exactly one frame per timestamp (or [] on failure), so frames_at pairs them with strict=True Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,13 +24,21 @@ querying the same timestamp pay decode cost once.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
@@ -45,7 +53,12 @@ class FrameProvider(Protocol):
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
@@ -58,12 +71,13 @@ class FrameProvider(Protocol):
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. The returned list is
|
||||
intended to be passed as one ``{"type":"video", "video":<list>}``
|
||||
block to a Qwen-VL-compatible model that pools temporally itself.
|
||||
Empty list if no camera available.
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@@ -176,18 +190,20 @@ class VideoFrameProvider:
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
with self._lock:
|
||||
for i, img in zip(miss_indices, decoded, strict=False):
|
||||
out[i] = img
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
@@ -195,10 +211,11 @@ class VideoFrameProvider:
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally.
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
@@ -267,13 +284,22 @@ class VideoFrameProvider:
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
return _decode_pyav_direct(video_path, shifted, self.tolerance_s)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(video_path, shifted, self.tolerance_s, return_uint8=True)
|
||||
return list(decoded)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time decoding fails so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at
|
||||
@@ -284,9 +310,7 @@ class VideoFrameProvider:
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
@@ -297,64 +321,6 @@ class VideoFrameProvider:
|
||||
return []
|
||||
|
||||
|
||||
def _decode_pyav_direct(video_path: Any, timestamps: list[float], tolerance_s: float) -> list[Any]:
|
||||
"""Decode the requested timestamps from ``video_path`` using PyAV directly.
|
||||
|
||||
Bypasses ``lerobot.datasets.video_utils.decode_video_frames`` entirely
|
||||
because its "pyav" path actually goes through
|
||||
``decode_video_frames_torchvision`` → ``torchvision.io.VideoReader``,
|
||||
which was removed in torchvision >= 0.22 (the vllm/vllm-openai:latest
|
||||
container ships with torchvision 0.25). The annotation pipeline only
|
||||
needs a handful of PIL images per (episode, ts), so we can decode them
|
||||
with PyAV without any torch dependency at all.
|
||||
|
||||
Returns one ``PIL.Image`` per requested timestamp, in the same order.
|
||||
Any timestamp the decoder couldn't reach is silently dropped (mirrors
|
||||
the previous behaviour); callers filter ``None``/missing entries.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
if not timestamps:
|
||||
return []
|
||||
|
||||
targets = sorted(set(timestamps))
|
||||
seek_to = max(0.0, min(targets) - max(0.5, tolerance_s))
|
||||
|
||||
container = av.open(str(video_path))
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
# PyAV needs the seek target in stream timebase ticks.
|
||||
seek_pts = 0 if stream.time_base is None else int(seek_to / float(stream.time_base))
|
||||
try:
|
||||
container.seek(seek_pts, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError:
|
||||
# Some streams reject the explicit seek; fall back to decoding from start.
|
||||
container.seek(0)
|
||||
|
||||
results: dict[float, Any] = {}
|
||||
target_iter = iter(targets)
|
||||
next_target = next(target_iter, None)
|
||||
for frame in container.decode(stream):
|
||||
if next_target is None:
|
||||
break
|
||||
ts = float(frame.pts * frame.time_base) if frame.pts is not None else None
|
||||
if ts is None:
|
||||
continue
|
||||
# Walk past targets we've already overshot — we keep the closest
|
||||
# frame within tolerance.
|
||||
while next_target is not None and ts >= next_target - tolerance_s:
|
||||
if abs(ts - next_target) <= tolerance_s or ts >= next_target:
|
||||
img = frame.to_image() # PIL.Image.Image (RGB)
|
||||
results.setdefault(next_target, img)
|
||||
next_target = next(target_iter, None)
|
||||
else:
|
||||
break
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
return [results[ts] for ts in timestamps if ts in results]
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
@@ -366,20 +332,38 @@ def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvi
|
||||
return provider
|
||||
|
||||
|
||||
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert PIL images to Qwen-VL-compatible content blocks."""
|
||||
return [{"type": "image", "image": img} for img in images]
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_video_block(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of PIL images as one Qwen-VL video block.
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not images:
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": list(images)}]
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -15,13 +15,12 @@
|
||||
# limitations under the License.
|
||||
"""Unit tests for :class:`VideoFrameProvider` method bindings.
|
||||
|
||||
These were prompted by a real regression: ``video_for_episode`` was
|
||||
indented one level too deep so it ended up nested *inside* the
|
||||
``_decode_pyav_direct`` helper (after that function's ``return``
|
||||
statement) — silently dead code that meant production runs with
|
||||
``use_video_url=False`` would ``AttributeError`` on
|
||||
``self.frame_provider.video_for_episode(...)``. The existing module
|
||||
tests didn't catch it because they exercise stub providers.
|
||||
These were prompted by a real regression: ``video_for_episode`` was once
|
||||
indented one level too deep so it ended up nested *inside* a module-level
|
||||
helper (after that function's ``return`` statement) — silently dead code
|
||||
that meant production runs with ``use_video_url=False`` would
|
||||
``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
|
||||
existing module tests didn't catch it because they exercise stub providers.
|
||||
|
||||
The tests below assert on the class itself (not on an instance), so a
|
||||
future reindent regression flips them to red without needing a real
|
||||
@@ -51,8 +50,8 @@ def test_episode_clip_path_is_a_method_of_videoframeprovider():
|
||||
|
||||
|
||||
def test_videoframeprovider_has_a_lock_for_concurrent_use():
|
||||
"""A ``ThreadPoolExecutor`` runs Module 1/2/3 phases concurrently;
|
||||
the cache + warn-flag accesses must be guarded.
|
||||
"""A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
|
||||
concurrently; the cache + warn-flag accesses must be guarded.
|
||||
"""
|
||||
import threading
|
||||
|
||||
|
||||
Reference in New Issue
Block a user