fix(viz): use PyAV for AV1 video decoding, AutoImageProcessor for SigLIP

- Replace cv2.VideoCapture with PyAV (the av library), which handles the
  AV1 codec properly. Decode each video once and index it by frame number
  (see the sketch below).
- Use AutoImageProcessor instead of AutoProcessor to avoid loading
  the SigLIP tokenizer (which requires sentencepiece).
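
A rough sketch of the new decode/preprocess flow (the video path and the
SigLIP checkpoint name below are illustrative only, not taken from this repo):

    import av
    from transformers import AutoImageProcessor

    # Decode an (AV1-encoded) video once with PyAV, then index by frame number.
    container = av.open("episode_000000.mp4")   # illustrative path
    stream = container.streams.video[0]
    stream.thread_type = "AUTO"                 # let FFmpeg pick frame/slice threading
    frames = [f.to_ndarray(format="rgb24") for f in container.decode(stream)]
    container.close()
    tenth_frame = frames[10]                    # local frame index within the episode

    # Image-only preprocessing: no SigLIP tokenizer, so sentencepiece is not needed.
    processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
    inputs = processor(images=[tenth_frame], return_tensors="pt")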

Made-with: Cursor
Pepijn
2026-03-23 23:02:50 -07:00
parent efe8c09fca
commit 026e4c937d
@@ -13,7 +13,7 @@ primarily sees.
 import json
 from pathlib import Path
 
-import cv2
+import av
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
 from matplotlib.colors import LinearSegmentedColormap
 from PIL import Image
 from scipy.spatial import cKDTree
-from transformers import AutoModel, AutoProcessor
+from transformers import AutoImageProcessor, AutoModel
 
 DATASETS = [
     {"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
@@ -350,59 +350,67 @@ def build_video_lookup(local: Path, camera_key: str) -> dict:
     return lookup
 
 
+def _decode_video_frames(video_path: str) -> list[np.ndarray]:
+    """Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
+    container = av.open(video_path)
+    stream = container.streams.video[0]
+    stream.thread_type = "AUTO"
+    decoded = []
+    for frame in container.decode(stream):
+        decoded.append(frame.to_ndarray(format="rgb24"))
+    container.close()
+    return decoded
+
+
 def extract_frames(
     chosen_idx: np.ndarray,
     episode_all: np.ndarray,
     video_lookup: dict,
 ) -> list[np.ndarray | None]:
     """
-    Extract BGR frames for each chosen global index.
-    Uses episode boundaries + fps to compute the seek timestamp.
-    Returns list of (H, W, 3) BGR arrays (or None on failure).
+    Extract RGB frames for each chosen global index using PyAV.
+    Returns list of (H, W, 3) RGB arrays (or None on failure).
     """
     # Build per-episode local frame index: for each row in the dataset,
     # its position within its episode
     unique_eps = np.unique(episode_all)
     ep_start: dict[int, int] = {}
     for ep in unique_eps:
         ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
 
-    frames: list[np.ndarray | None] = []
-
-    # Group by video file for efficient sequential access
-    jobs: list[tuple[int, int, str, float]] = []
+    # Build jobs: (output_index, video_path, local_frame_number)
+    jobs: list[tuple[int, str, int]] = []
     for out_i, global_i in enumerate(chosen_idx):
         ep = int(episode_all[global_i])
         info = video_lookup.get(ep)
         if info is None:
-            jobs.append((out_i, -1, "", 0.0))
             continue
         local_frame = global_i - ep_start[ep]
-        seek_ts = info["from_ts"] + local_frame / info["fps"]
-        jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
+        jobs.append((out_i, str(info["video_path"]), local_frame))
 
-    jobs.sort(key=lambda x: (x[2], x[3]))
+    # Group by video file, decode each video once
+    from collections import defaultdict
 
-    frames = [None] * len(chosen_idx)
-    current_cap = None
-    current_path = ""
+    video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
+    for out_i, vpath, local_frame in jobs:
+        video_jobs[vpath].append((out_i, local_frame))
 
+    frames: list[np.ndarray | None] = [None] * len(chosen_idx)
     extracted = 0
-    for out_i, _global_i, vpath, seek_ts in jobs:
-        if not vpath:
+    n_videos = len(video_jobs)
+    for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
+        if not Path(vpath).exists():
             continue
-        if vpath != current_path:
-            if current_cap is not None:
-                current_cap.release()
-            current_cap = cv2.VideoCapture(vpath)
-            current_path = vpath
-        if current_cap is None or not current_cap.isOpened():
+        try:
+            decoded = _decode_video_frames(vpath)
+        except Exception as exc:
+            print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
             continue
-        current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
-        ret, frame = current_cap.read()
-        if ret:
-            frames[out_i] = frame
-            extracted += 1
-    if current_cap is not None:
-        current_cap.release()
+        for out_i, local_frame in frame_requests:
+            if 0 <= local_frame < len(decoded):
+                frames[out_i] = decoded[local_frame]
+                extracted += 1
+        if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
+            print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
+        del decoded
 
     print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
     return frames
@@ -418,33 +426,27 @@ def encode_frames_siglip(
     device: torch.device,
 ) -> np.ndarray:
     """
-    Encode BGR frames through SigLIP vision encoder.
+    Encode RGB frames through SigLIP vision encoder.
     Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
     """
     print(f" Loading SigLIP model: {model_name}")
-    processor = AutoProcessor.from_pretrained(model_name)
+    processor = AutoImageProcessor.from_pretrained(model_name)
     model = AutoModel.from_pretrained(model_name).to(device).eval()
     embed_dim = model.config.vision_config.hidden_size
 
     n = len(frames)
     embeddings = np.zeros((n, embed_dim), dtype=np.float32)
 
     # Collect valid frame indices
     valid_indices = [i for i, f in enumerate(frames) if f is not None]
     print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}")
 
     for batch_start in range(0, len(valid_indices), batch_size):
         batch_idx = valid_indices[batch_start : batch_start + batch_size]
-        pil_images = []
-        for i in batch_idx:
-            bgr = frames[i]
-            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
-            pil_images.append(Image.fromarray(rgb))
+        pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
         inputs = processor(images=pil_images, return_tensors="pt").to(device)
         with torch.no_grad():
             image_features = model.get_image_features(**inputs)
             # L2-normalize embeddings for cosine-like KNN
             image_features = torch.nn.functional.normalize(image_features, dim=-1)
         embeddings[batch_idx] = image_features.cpu().numpy()