fix(viz): use PyAV for AV1 video decoding, AutoImageProcessor for SigLIP

- Replace cv2.VideoCapture with PyAV (av library) which handles AV1
  codec properly. Decode each video once and index by frame number.
- Use AutoImageProcessor instead of AutoProcessor to avoid loading
  the SigLIP tokenizer (which requires sentencepiece).

Made-with: Cursor
This commit is contained in:
Pepijn
2026-03-23 23:02:50 -07:00
parent efe8c09fca
commit 026e4c937d
@@ -13,7 +13,7 @@ primarily sees.
import json import json
from pathlib import Path from pathlib import Path
import cv2 import av
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from PIL import Image from PIL import Image
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
from transformers import AutoModel, AutoProcessor from transformers import AutoImageProcessor, AutoModel
DATASETS = [ DATASETS = [
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"}, {"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
@@ -350,59 +350,67 @@ def build_video_lookup(local: Path, camera_key: str) -> dict:
return lookup return lookup
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
    """Decode every frame of a video file with PyAV.

    Args:
        video_path: Path of the video file to open.

    Returns:
        One RGB array per decoded frame, in the order the decoder yields them.

    Raises:
        Whatever PyAV raises when the file cannot be opened or decoded; the
        container is closed before the exception propagates.
    """
    # The context manager closes the container even when decoding fails
    # part-way through, so a corrupt file does not leak an open handle.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        # Let the decoder pick its own threading mode.
        stream.thread_type = "AUTO"
        return [frame.to_ndarray(format="rgb24") for frame in container.decode(stream)]
def extract_frames( def extract_frames(
chosen_idx: np.ndarray, chosen_idx: np.ndarray,
episode_all: np.ndarray, episode_all: np.ndarray,
video_lookup: dict, video_lookup: dict,
) -> list[np.ndarray | None]: ) -> list[np.ndarray | None]:
""" """
Extract BGR frames for each chosen global index. Extract RGB frames for each chosen global index using PyAV.
Uses episode boundaries + fps to compute the seek timestamp. Returns list of (H, W, 3) RGB arrays (or None on failure).
Returns list of (H, W, 3) BGR arrays (or None on failure).
""" """
# Build per-episode local frame index: for each row in the dataset,
# its position within its episode
unique_eps = np.unique(episode_all) unique_eps = np.unique(episode_all)
ep_start: dict[int, int] = {} ep_start: dict[int, int] = {}
for ep in unique_eps: for ep in unique_eps:
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0]) ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
frames: list[np.ndarray | None] = [] # Build jobs: (output_index, video_path, local_frame_number)
# Group by video file for efficient sequential access jobs: list[tuple[int, str, int]] = []
jobs: list[tuple[int, int, str, float]] = []
for out_i, global_i in enumerate(chosen_idx): for out_i, global_i in enumerate(chosen_idx):
ep = int(episode_all[global_i]) ep = int(episode_all[global_i])
info = video_lookup.get(ep) info = video_lookup.get(ep)
if info is None: if info is None:
jobs.append((out_i, -1, "", 0.0))
continue continue
local_frame = global_i - ep_start[ep] local_frame = global_i - ep_start[ep]
seek_ts = info["from_ts"] + local_frame / info["fps"] jobs.append((out_i, str(info["video_path"]), local_frame))
jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
jobs.sort(key=lambda x: (x[2], x[3])) # Group by video file, decode each video once
from collections import defaultdict
frames = [None] * len(chosen_idx) video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
current_cap = None for out_i, vpath, local_frame in jobs:
current_path = "" video_jobs[vpath].append((out_i, local_frame))
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
extracted = 0 extracted = 0
for out_i, _global_i, vpath, seek_ts in jobs: n_videos = len(video_jobs)
if not vpath: for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
if not Path(vpath).exists():
continue continue
if vpath != current_path: try:
if current_cap is not None: decoded = _decode_video_frames(vpath)
current_cap.release() except Exception as exc:
current_cap = cv2.VideoCapture(vpath) print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
current_path = vpath
if current_cap is None or not current_cap.isOpened():
continue continue
current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0) for out_i, local_frame in frame_requests:
ret, frame = current_cap.read() if 0 <= local_frame < len(decoded):
if ret: frames[out_i] = decoded[local_frame]
frames[out_i] = frame extracted += 1
extracted += 1 if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
if current_cap is not None: print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
current_cap.release() del decoded
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video") print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
return frames return frames
@@ -418,33 +426,27 @@ def encode_frames_siglip(
device: torch.device, device: torch.device,
) -> np.ndarray: ) -> np.ndarray:
""" """
Encode BGR frames through SigLIP vision encoder. Encode RGB frames through SigLIP vision encoder.
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector. Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
""" """
print(f" Loading SigLIP model: {model_name}") print(f" Loading SigLIP model: {model_name}")
processor = AutoProcessor.from_pretrained(model_name) processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device).eval() model = AutoModel.from_pretrained(model_name).to(device).eval()
embed_dim = model.config.vision_config.hidden_size embed_dim = model.config.vision_config.hidden_size
n = len(frames) n = len(frames)
embeddings = np.zeros((n, embed_dim), dtype=np.float32) embeddings = np.zeros((n, embed_dim), dtype=np.float32)
# Collect valid frame indices
valid_indices = [i for i, f in enumerate(frames) if f is not None] valid_indices = [i for i, f in enumerate(frames) if f is not None]
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}") print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}")
for batch_start in range(0, len(valid_indices), batch_size): for batch_start in range(0, len(valid_indices), batch_size):
batch_idx = valid_indices[batch_start : batch_start + batch_size] batch_idx = valid_indices[batch_start : batch_start + batch_size]
pil_images = [] pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
for i in batch_idx:
bgr = frames[i]
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
pil_images.append(Image.fromarray(rgb))
inputs = processor(images=pil_images, return_tensors="pt").to(device) inputs = processor(images=pil_images, return_tensors="pt").to(device)
with torch.no_grad(): with torch.no_grad():
image_features = model.get_image_features(**inputs) image_features = model.get_image_features(**inputs)
# L2-normalize embeddings for cosine-like KNN
image_features = torch.nn.functional.normalize(image_features, dim=-1) image_features = torch.nn.functional.normalize(image_features, dim=-1)
embeddings[batch_idx] = image_features.cpu().numpy() embeddings[batch_idx] = image_features.cpu().numpy()