mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(viz): use PyAV for AV1 video decoding, AutoImageProcessor for SigLIP
- Replace cv2.VideoCapture with PyAV (av library) which handles AV1 codec properly. Decode each video once and index by frame number. - Use AutoImageProcessor instead of AutoProcessor to avoid loading the SigLIP tokenizer (which requires sentencepiece). Made-with: Cursor
This commit is contained in:
@@ -13,7 +13,7 @@ primarily sees.
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import av
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
|
|||||||
from matplotlib.colors import LinearSegmentedColormap
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from scipy.spatial import cKDTree
|
from scipy.spatial import cKDTree
|
||||||
from transformers import AutoModel, AutoProcessor
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||||
@@ -350,59 +350,67 @@ def build_video_lookup(local: Path, camera_key: str) -> dict:
|
|||||||
return lookup
|
return lookup
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
||||||
|
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
||||||
|
container = av.open(video_path)
|
||||||
|
stream = container.streams.video[0]
|
||||||
|
stream.thread_type = "AUTO"
|
||||||
|
decoded = []
|
||||||
|
for frame in container.decode(stream):
|
||||||
|
decoded.append(frame.to_ndarray(format="rgb24"))
|
||||||
|
container.close()
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
|
||||||
def extract_frames(
|
def extract_frames(
|
||||||
chosen_idx: np.ndarray,
|
chosen_idx: np.ndarray,
|
||||||
episode_all: np.ndarray,
|
episode_all: np.ndarray,
|
||||||
video_lookup: dict,
|
video_lookup: dict,
|
||||||
) -> list[np.ndarray | None]:
|
) -> list[np.ndarray | None]:
|
||||||
"""
|
"""
|
||||||
Extract BGR frames for each chosen global index.
|
Extract RGB frames for each chosen global index using PyAV.
|
||||||
Uses episode boundaries + fps to compute the seek timestamp.
|
Returns list of (H, W, 3) RGB arrays (or None on failure).
|
||||||
Returns list of (H, W, 3) BGR arrays (or None on failure).
|
|
||||||
"""
|
"""
|
||||||
# Build per-episode local frame index: for each row in the dataset,
|
|
||||||
# its position within its episode
|
|
||||||
unique_eps = np.unique(episode_all)
|
unique_eps = np.unique(episode_all)
|
||||||
ep_start: dict[int, int] = {}
|
ep_start: dict[int, int] = {}
|
||||||
for ep in unique_eps:
|
for ep in unique_eps:
|
||||||
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
||||||
|
|
||||||
frames: list[np.ndarray | None] = []
|
# Build jobs: (output_index, video_path, local_frame_number)
|
||||||
# Group by video file for efficient sequential access
|
jobs: list[tuple[int, str, int]] = []
|
||||||
jobs: list[tuple[int, int, str, float]] = []
|
|
||||||
for out_i, global_i in enumerate(chosen_idx):
|
for out_i, global_i in enumerate(chosen_idx):
|
||||||
ep = int(episode_all[global_i])
|
ep = int(episode_all[global_i])
|
||||||
info = video_lookup.get(ep)
|
info = video_lookup.get(ep)
|
||||||
if info is None:
|
if info is None:
|
||||||
jobs.append((out_i, -1, "", 0.0))
|
|
||||||
continue
|
continue
|
||||||
local_frame = global_i - ep_start[ep]
|
local_frame = global_i - ep_start[ep]
|
||||||
seek_ts = info["from_ts"] + local_frame / info["fps"]
|
jobs.append((out_i, str(info["video_path"]), local_frame))
|
||||||
jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
|
|
||||||
|
|
||||||
jobs.sort(key=lambda x: (x[2], x[3]))
|
# Group by video file, decode each video once
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
frames = [None] * len(chosen_idx)
|
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
|
||||||
current_cap = None
|
for out_i, vpath, local_frame in jobs:
|
||||||
current_path = ""
|
video_jobs[vpath].append((out_i, local_frame))
|
||||||
|
|
||||||
|
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
|
||||||
extracted = 0
|
extracted = 0
|
||||||
for out_i, _global_i, vpath, seek_ts in jobs:
|
n_videos = len(video_jobs)
|
||||||
if not vpath:
|
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
|
||||||
|
if not Path(vpath).exists():
|
||||||
continue
|
continue
|
||||||
if vpath != current_path:
|
try:
|
||||||
if current_cap is not None:
|
decoded = _decode_video_frames(vpath)
|
||||||
current_cap.release()
|
except Exception as exc:
|
||||||
current_cap = cv2.VideoCapture(vpath)
|
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
|
||||||
current_path = vpath
|
|
||||||
if current_cap is None or not current_cap.isOpened():
|
|
||||||
continue
|
continue
|
||||||
current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
for out_i, local_frame in frame_requests:
|
||||||
ret, frame = current_cap.read()
|
if 0 <= local_frame < len(decoded):
|
||||||
if ret:
|
frames[out_i] = decoded[local_frame]
|
||||||
frames[out_i] = frame
|
extracted += 1
|
||||||
extracted += 1
|
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
|
||||||
if current_cap is not None:
|
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
|
||||||
current_cap.release()
|
del decoded
|
||||||
|
|
||||||
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
||||||
return frames
|
return frames
|
||||||
@@ -418,33 +426,27 @@ def encode_frames_siglip(
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Encode BGR frames through SigLIP vision encoder.
|
Encode RGB frames through SigLIP vision encoder.
|
||||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||||
"""
|
"""
|
||||||
print(f" Loading SigLIP model: {model_name} …")
|
print(f" Loading SigLIP model: {model_name} …")
|
||||||
processor = AutoProcessor.from_pretrained(model_name)
|
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||||
embed_dim = model.config.vision_config.hidden_size
|
embed_dim = model.config.vision_config.hidden_size
|
||||||
|
|
||||||
n = len(frames)
|
n = len(frames)
|
||||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||||
|
|
||||||
# Collect valid frame indices
|
|
||||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_indices), batch_size):
|
for batch_start in range(0, len(valid_indices), batch_size):
|
||||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||||
pil_images = []
|
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||||
for i in batch_idx:
|
|
||||||
bgr = frames[i]
|
|
||||||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
||||||
pil_images.append(Image.fromarray(rgb))
|
|
||||||
|
|
||||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
image_features = model.get_image_features(**inputs)
|
image_features = model.get_image_features(**inputs)
|
||||||
# L2-normalize embeddings for cosine-like KNN
|
|
||||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user