Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-15 16:49:55 +00:00)
fix(viz): use PyAV for AV1 video decoding, AutoImageProcessor for SigLIP
- Replace cv2.VideoCapture with PyAV (the av library), which handles the AV1 codec properly. Decode each video once and index frames by frame number instead of seeking by timestamp.
- Use AutoImageProcessor instead of AutoProcessor to avoid loading the SigLIP tokenizer (which requires sentencepiece).

Made-with: Cursor
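For context, a minimal sketch of the failure this commit fixes, assuming an AV1-encoded episode video (the path below is hypothetical): OpenCV builds without an AV1-capable decoder return no frames from cv2.VideoCapture, while PyAV decodes the same file through FFmpeg's decoders.

    # Hypothetical repro; the path and the AV1-codec assumption are illustrative.
    import av
    import cv2

    path = "videos/episode_000000.mp4"  # hypothetical AV1-encoded video

    cap = cv2.VideoCapture(path)
    ok, _ = cap.read()  # often False when OpenCV lacks an AV1 decoder
    cap.release()
    print("cv2 read first frame:", ok)

    with av.open(path) as container:  # PyAV calls FFmpeg's decoders directly
        frame = next(container.decode(container.streams.video[0]))
        print("PyAV first frame shape:", frame.to_ndarray(format="rgb24").shape)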
@@ -13,7 +13,7 @@ primarily sees.
 import json
 from pathlib import Path
 
-import cv2
+import av
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
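The swapped-in dependency can be sanity-checked before running the script; a trivial check (PyAV is commonly installed with `pip install av`; exact pinning is an assumption):

    # Verify the av package is importable and report its version.
    import av
    print(av.__version__)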
@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
 from matplotlib.colors import LinearSegmentedColormap
 from PIL import Image
 from scipy.spatial import cKDTree
-from transformers import AutoModel, AutoProcessor
+from transformers import AutoImageProcessor, AutoModel
 
 DATASETS = [
     {"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
@@ -350,59 +350,67 @@ def build_video_lookup(local: Path, camera_key: str) -> dict:
     return lookup
 
 
+def _decode_video_frames(video_path: str) -> list[np.ndarray]:
+    """Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
+    container = av.open(video_path)
+    stream = container.streams.video[0]
+    stream.thread_type = "AUTO"
+    decoded = []
+    for frame in container.decode(stream):
+        decoded.append(frame.to_ndarray(format="rgb24"))
+    container.close()
+    return decoded
+
+
 def extract_frames(
     chosen_idx: np.ndarray,
     episode_all: np.ndarray,
     video_lookup: dict,
 ) -> list[np.ndarray | None]:
     """
-    Extract BGR frames for each chosen global index.
-    Uses episode boundaries + fps to compute the seek timestamp.
-    Returns list of (H, W, 3) BGR arrays (or None on failure).
+    Extract RGB frames for each chosen global index using PyAV.
+    Returns list of (H, W, 3) RGB arrays (or None on failure).
     """
     # Build per-episode local frame index: for each row in the dataset,
     # its position within its episode
     unique_eps = np.unique(episode_all)
     ep_start: dict[int, int] = {}
     for ep in unique_eps:
         ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
 
-    frames: list[np.ndarray | None] = []
-    # Group by video file for efficient sequential access
-    jobs: list[tuple[int, int, str, float]] = []
+    # Build jobs: (output_index, video_path, local_frame_number)
+    jobs: list[tuple[int, str, int]] = []
     for out_i, global_i in enumerate(chosen_idx):
         ep = int(episode_all[global_i])
         info = video_lookup.get(ep)
         if info is None:
-            jobs.append((out_i, -1, "", 0.0))
             continue
         local_frame = global_i - ep_start[ep]
-        seek_ts = info["from_ts"] + local_frame / info["fps"]
-        jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
+        jobs.append((out_i, str(info["video_path"]), local_frame))
 
-    jobs.sort(key=lambda x: (x[2], x[3]))
+    # Group by video file, decode each video once
+    from collections import defaultdict
 
-    frames = [None] * len(chosen_idx)
-    current_cap = None
-    current_path = ""
+    video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
+    for out_i, vpath, local_frame in jobs:
+        video_jobs[vpath].append((out_i, local_frame))
+
+    frames: list[np.ndarray | None] = [None] * len(chosen_idx)
     extracted = 0
-    for out_i, _global_i, vpath, seek_ts in jobs:
-        if not vpath:
+    n_videos = len(video_jobs)
+    for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
+        if not Path(vpath).exists():
             continue
-        if vpath != current_path:
-            if current_cap is not None:
-                current_cap.release()
-            current_cap = cv2.VideoCapture(vpath)
-            current_path = vpath
-        if current_cap is None or not current_cap.isOpened():
+        try:
+            decoded = _decode_video_frames(vpath)
+        except Exception as exc:
+            print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
             continue
-        current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
-        ret, frame = current_cap.read()
-        if ret:
-            frames[out_i] = frame
-            extracted += 1
-    if current_cap is not None:
-        current_cap.release()
+        for out_i, local_frame in frame_requests:
+            if 0 <= local_frame < len(decoded):
+                frames[out_i] = decoded[local_frame]
+                extracted += 1
+        if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
+            print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
+        del decoded
 
     print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
     return frames
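A usage sketch for the rewritten extract_frames; the arrays and lookup below are hypothetical stand-ins for what the script derives from the dataset and from build_video_lookup:

    import numpy as np

    episode_all = np.array([0, 0, 0, 1, 1, 1])  # episode id per row (hypothetical)
    chosen_idx = np.array([1, 5])               # global row indices to visualize
    video_lookup = {                            # per-episode info; paths hypothetical
        0: {"video_path": "videos/episode_000000.mp4"},
        1: {"video_path": "videos/episode_000001.mp4"},
    }

    frames = extract_frames(chosen_idx, episode_all, video_lookup)
    # frames[0] is local frame 1 of episode 0; frames[1] is local frame 2 of
    # episode 1. Each entry is an (H, W, 3) RGB array, or None on failure.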
@@ -418,33 +426,27 @@ def encode_frames_siglip(
     device: torch.device,
 ) -> np.ndarray:
     """
-    Encode BGR frames through SigLIP vision encoder.
+    Encode RGB frames through SigLIP vision encoder.
     Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
     """
     print(f" Loading SigLIP model: {model_name} …")
-    processor = AutoProcessor.from_pretrained(model_name)
+    processor = AutoImageProcessor.from_pretrained(model_name)
     model = AutoModel.from_pretrained(model_name).to(device).eval()
     embed_dim = model.config.vision_config.hidden_size
 
     n = len(frames)
     embeddings = np.zeros((n, embed_dim), dtype=np.float32)
 
     # Collect valid frame indices
     valid_indices = [i for i, f in enumerate(frames) if f is not None]
     print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
 
     for batch_start in range(0, len(valid_indices), batch_size):
         batch_idx = valid_indices[batch_start : batch_start + batch_size]
-        pil_images = []
-        for i in batch_idx:
-            bgr = frames[i]
-            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
-            pil_images.append(Image.fromarray(rgb))
+        pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
 
         inputs = processor(images=pil_images, return_tensors="pt").to(device)
         with torch.no_grad():
             image_features = model.get_image_features(**inputs)
         # L2-normalize embeddings for cosine-like KNN
         image_features = torch.nn.functional.normalize(image_features, dim=-1)
         embeddings[batch_idx] = image_features.cpu().numpy()
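A standalone sketch of the processor swap, assuming a public SigLIP checkpoint (the model name is illustrative): AutoImageProcessor loads only the image-preprocessing config, so the SigLIP tokenizer and its sentencepiece dependency are never touched.

    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel

    model_name = "google/siglip-base-patch16-224"  # assumed checkpoint
    processor = AutoImageProcessor.from_pretrained(model_name)  # no tokenizer loaded
    model = AutoModel.from_pretrained(model_name).eval()

    image = Image.new("RGB", (256, 256))  # placeholder RGB frame
    inputs = processor(images=[image], return_tensors="pt")
    with torch.no_grad():
        feats = model.get_image_features(**inputs)
    feats = torch.nn.functional.normalize(feats, dim=-1)  # unit norm for cosine KNN
    print(feats.shape)  # torch.Size([1, embed_dim])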