From 026e4c937dc06989d898b2ea54deed374cc9047f Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 23 Mar 2026 23:02:50 -0700
Subject: [PATCH] fix(viz): use PyAV for AV1 video decoding,
 AutoImageProcessor for SigLIP

- Replace cv2.VideoCapture with PyAV (the av library), which handles the
  AV1 codec properly. Decode each video once and index by frame number.
- Use AutoImageProcessor instead of AutoProcessor to avoid loading the
  SigLIP tokenizer (which requires sentencepiece).

Made-with: Cursor
---
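Reviewer notes (placed after the "---" marker, so "git am" drops them;
none of this is part of the commit):

- The PyAV change replaces per-frame seeking with sequential
  decode-once-and-index. A minimal standalone sketch of the idea, not
  the patch's exact code; the helper name "frames_by_index", the path
  "episode.mp4", and the frame numbers are illustrative:

      import av

      def frames_by_index(video_path: str, wanted: set[int]) -> dict:
          """Decode sequentially with PyAV, keeping only wanted frames."""
          out = {}
          with av.open(video_path) as container:
              stream = container.streams.video[0]
              stream.thread_type = "AUTO"  # let FFmpeg pick threaded decoding
              for i, frame in enumerate(container.decode(stream)):
                  if i in wanted:
                      out[i] = frame.to_ndarray(format="rgb24")  # (H, W, 3) RGB
          return out

      # frames = frames_by_index("episode.mp4", {0, 17, 42})

  The patched script instead decodes every frame of a video into a list
  and indexes it, which amounts to the same sequential decode when most
  frames are requested.

- The processor swap works because AutoImageProcessor instantiates only
  the image preprocessor, whereas AutoProcessor for SigLIP also builds
  the tokenizer and therefore requires sentencepiece. A quick check,
  using google/siglip-base-patch16-224 purely as an example checkpoint
  (the script takes the checkpoint name as a parameter):

      from transformers import AutoImageProcessor

      # Only the image preprocessing config is loaded; no tokenizer is
      # built, so sentencepiece is never imported.
      processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
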
""" - # Build per-episode local frame index: for each row in the dataset, - # its position within its episode unique_eps = np.unique(episode_all) ep_start: dict[int, int] = {} for ep in unique_eps: ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0]) - frames: list[np.ndarray | None] = [] - # Group by video file for efficient sequential access - jobs: list[tuple[int, int, str, float]] = [] + # Build jobs: (output_index, video_path, local_frame_number) + jobs: list[tuple[int, str, int]] = [] for out_i, global_i in enumerate(chosen_idx): ep = int(episode_all[global_i]) info = video_lookup.get(ep) if info is None: - jobs.append((out_i, -1, "", 0.0)) continue local_frame = global_i - ep_start[ep] - seek_ts = info["from_ts"] + local_frame / info["fps"] - jobs.append((out_i, global_i, str(info["video_path"]), seek_ts)) + jobs.append((out_i, str(info["video_path"]), local_frame)) - jobs.sort(key=lambda x: (x[2], x[3])) + # Group by video file, decode each video once + from collections import defaultdict - frames = [None] * len(chosen_idx) - current_cap = None - current_path = "" + video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list) + for out_i, vpath, local_frame in jobs: + video_jobs[vpath].append((out_i, local_frame)) + + frames: list[np.ndarray | None] = [None] * len(chosen_idx) extracted = 0 - for out_i, _global_i, vpath, seek_ts in jobs: - if not vpath: + n_videos = len(video_jobs) + for vi, (vpath, frame_requests) in enumerate(video_jobs.items()): + if not Path(vpath).exists(): continue - if vpath != current_path: - if current_cap is not None: - current_cap.release() - current_cap = cv2.VideoCapture(vpath) - current_path = vpath - if current_cap is None or not current_cap.isOpened(): + try: + decoded = _decode_video_frames(vpath) + except Exception as exc: + print(f" Warning: failed to decode {Path(vpath).name}: {exc}") continue - current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0) - ret, frame = current_cap.read() - if ret: - frames[out_i] = frame - extracted += 1 - if current_cap is not None: - current_cap.release() + for out_i, local_frame in frame_requests: + if 0 <= local_frame < len(decoded): + frames[out_i] = decoded[local_frame] + extracted += 1 + if (vi + 1) % 50 == 0 or (vi + 1) == n_videos: + print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)") + del decoded print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video") return frames @@ -418,33 +426,27 @@ def encode_frames_siglip( device: torch.device, ) -> np.ndarray: """ - Encode BGR frames through SigLIP vision encoder. + Encode RGB frames through SigLIP vision encoder. Returns (N, embed_dim) float32 array. Frames that are None get a zero vector. 
""" print(f" Loading SigLIP model: {model_name} …") - processor = AutoProcessor.from_pretrained(model_name) + processor = AutoImageProcessor.from_pretrained(model_name) model = AutoModel.from_pretrained(model_name).to(device).eval() embed_dim = model.config.vision_config.hidden_size n = len(frames) embeddings = np.zeros((n, embed_dim), dtype=np.float32) - # Collect valid frame indices valid_indices = [i for i, f in enumerate(frames) if f is not None] print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …") for batch_start in range(0, len(valid_indices), batch_size): batch_idx = valid_indices[batch_start : batch_start + batch_size] - pil_images = [] - for i in batch_idx: - bgr = frames[i] - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - pil_images.append(Image.fromarray(rgb)) + pil_images = [Image.fromarray(frames[i]) for i in batch_idx] inputs = processor(images=pil_images, return_tensors="pt").to(device) with torch.no_grad(): image_features = model.get_image_features(**inputs) - # L2-normalize embeddings for cosine-like KNN image_features = torch.nn.functional.normalize(image_features, dim=-1) embeddings[batch_idx] = image_features.cpu().numpy()