diff --git a/examples/dataset/visualization_tools/action_consistency.py b/examples/dataset/visualization_tools/action_consistency.py index 70dedbce8..ba5b71865 100644 --- a/examples/dataset/visualization_tools/action_consistency.py +++ b/examples/dataset/visualization_tools/action_consistency.py @@ -1,22 +1,28 @@ """ -Action-state consistency analysis for imitation learning datasets. -For each frame, finds K nearest neighbors in state space (from other episodes) -and measures the variance of corresponding actions. High variance at similar -states = contradictory supervision for the policy. +Action consistency analysis for imitation learning datasets. -Outputs a comparison figure with histograms, per-episode curves, and spatial -heatmaps showing where demonstrations conflict. +Two parallel analyses per dataset: + 1. State-based: KNN in joint-state space → action chunk variance + 2. Image-based: KNN in SigLIP embedding space → action chunk variance + +Comparing them reveals whether visual similarity and proprioceptive similarity +agree on where the data is inconsistent — and images are what the policy +primarily sees. 
""" import json from pathlib import Path +import cv2 import matplotlib.pyplot as plt import numpy as np import pandas as pd +import torch from huggingface_hub import snapshot_download from matplotlib.colors import LinearSegmentedColormap +from PIL import Image from scipy.spatial import cKDTree +from transformers import AutoModel, AutoProcessor DATASETS = [ {"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"}, @@ -25,9 +31,12 @@ DATASETS = [ OUTPUT_DIR = Path(__file__).resolve().parent / "outputs" OUTPUT_DIR.mkdir(exist_ok=True) -MAX_FRAMES = 10_000 +MAX_FRAMES = 100_000 K_NEIGHBORS = 50 ACTION_CHUNK_SIZE = 30 +CAMERA_KEY = "observation.images.base" +ENCODER_MODEL = "google/siglip-base-patch16-224" +ENCODE_BATCH_SIZE = 128 SEED = 42 DPI = 150 @@ -184,14 +193,17 @@ def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[lis raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot") -def download_data(repo_id: str) -> Path: - print(f" Downloading {repo_id} (parquet only) …") +def download_data(repo_id: str, camera_key: str) -> Path: + print(f" Downloading {repo_id} (parquet + {camera_key} videos) …") return Path( snapshot_download( repo_id=repo_id, repo_type="dataset", - allow_patterns=["meta/**", "data/**"], - ignore_patterns=["*.mp4", "videos/**"], + allow_patterns=[ + "meta/**", + "data/**", + f"videos/{camera_key}/**", + ], ) ) @@ -203,9 +215,8 @@ def _build_action_chunks( actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int ) -> tuple[np.ndarray, np.ndarray]: """ - Build action chunks: for each frame, concatenate the next chunk_size actions - from the same episode. Returns (action_chunks, valid_mask). - Frames too close to episode end to form a full chunk are marked invalid. + For each frame, concatenate the next chunk_size actions from the same episode. + Returns (action_chunks, valid_mask). 
""" n = len(actions) act_dim = actions.shape[1] @@ -216,7 +227,6 @@ def _build_action_chunks( end = i + chunk_size if end > n: continue - # All frames in the chunk must belong to the same episode if episode_ids[i] != episode_ids[end - 1]: continue chunks[i] = actions[i:end].ravel() @@ -227,8 +237,8 @@ def _build_action_chunks( def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict: """ - Load observation.state and action columns, build action chunks of size - chunk_size (matching what the policy learns), subsample, normalize. + Load observation.state and action, build action chunks, subsample, normalize. + Also returns the original row indices (`chosen_idx`) for video frame mapping. """ info = json.loads((local / "meta" / "info.json").read_text()) features = info.get("features", {}) @@ -260,13 +270,11 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n left_idx, right_idx = _find_joint_indices(features, state_col, n_dim) - # Build action chunks within episode boundaries print(" Building action chunks …") action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size) valid_idx = np.where(valid)[0] print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}") - # Subsample from valid frames only if len(valid_idx) > max_frames: chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False)) else: @@ -277,7 +285,6 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n action_raw = action_chunks[chosen] episode_ids = episode_all[chosen] - # Z-score normalize for fair KNN distance state_mean = state_raw.mean(axis=0) state_std = state_raw.std(axis=0) state_std[state_std < 1e-8] = 1.0 @@ -294,38 +301,186 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n "action_raw": action_raw, "action_norm": action_norm, "episode_ids": episode_ids, + "episode_all": episode_all, "left_joint_idx": 
left_idx, "right_joint_idx": right_idx, "n_total": n_total, + "chosen_idx": chosen, + "df": df, } +# ── Video → frame extraction ────────────────────────────── + + +def build_video_lookup(local: Path, camera_key: str) -> dict: + """ + Build a mapping from episode_index → {video_path, fps, from_ts}. + """ + info = json.loads((local / "meta" / "info.json").read_text()) + fps = info["fps"] + video_template = info.get( + "video_path", + "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4", + ) + + ep_rows = [] + for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")): + ep_rows.append(pd.read_parquet(pq)) + ep_df = pd.concat(ep_rows, ignore_index=True) + + chunk_col = f"videos/{camera_key}/chunk_index" + file_col = f"videos/{camera_key}/file_index" + ts_from = f"videos/{camera_key}/from_timestamp" + if chunk_col not in ep_df.columns: + chunk_col = f"{camera_key}/chunk_index" + file_col = f"{camera_key}/file_index" + ts_from = f"{camera_key}/from_timestamp" + + lookup: dict[int, dict] = {} + for _, row in ep_df.iterrows(): + ci = int(row[chunk_col]) + fi = int(row[file_col]) + video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi) + lookup[int(row["episode_index"])] = { + "video_path": local / video_rel, + "from_ts": float(row[ts_from]), + "fps": fps, + } + return lookup + + +def extract_frames( + chosen_idx: np.ndarray, + episode_all: np.ndarray, + video_lookup: dict, +) -> list[np.ndarray | None]: + """ + Extract BGR frames for each chosen global index. + Uses episode boundaries + fps to compute the seek timestamp. + Returns list of (H, W, 3) BGR arrays (or None on failure). 
+ """ + # Build per-episode local frame index: for each row in the dataset, + # its position within its episode + unique_eps = np.unique(episode_all) + ep_start: dict[int, int] = {} + for ep in unique_eps: + ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0]) + + frames: list[np.ndarray | None] = [] + # Group by video file for efficient sequential access + jobs: list[tuple[int, int, str, float]] = [] + for out_i, global_i in enumerate(chosen_idx): + ep = int(episode_all[global_i]) + info = video_lookup.get(ep) + if info is None: + jobs.append((out_i, -1, "", 0.0)) + continue + local_frame = global_i - ep_start[ep] + seek_ts = info["from_ts"] + local_frame / info["fps"] + jobs.append((out_i, global_i, str(info["video_path"]), seek_ts)) + + jobs.sort(key=lambda x: (x[2], x[3])) + + frames = [None] * len(chosen_idx) + current_cap = None + current_path = "" + extracted = 0 + for out_i, _global_i, vpath, seek_ts in jobs: + if not vpath: + continue + if vpath != current_path: + if current_cap is not None: + current_cap.release() + current_cap = cv2.VideoCapture(vpath) + current_path = vpath + if current_cap is None or not current_cap.isOpened(): + continue + current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0) + ret, frame = current_cap.read() + if ret: + frames[out_i] = frame + extracted += 1 + if current_cap is not None: + current_cap.release() + + print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video") + return frames + + +# ── SigLIP encoding ───────────────────────────────────── + + +def encode_frames_siglip( + frames: list[np.ndarray | None], + model_name: str, + batch_size: int, + device: torch.device, +) -> np.ndarray: + """ + Encode BGR frames through SigLIP vision encoder. + Returns (N, embed_dim) float32 array. Frames that are None get a zero vector. 
+ """ + print(f" Loading SigLIP model: {model_name} …") + processor = AutoProcessor.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name).to(device).eval() + embed_dim = model.config.vision_config.hidden_size + + n = len(frames) + embeddings = np.zeros((n, embed_dim), dtype=np.float32) + + # Collect valid frame indices + valid_indices = [i for i, f in enumerate(frames) if f is not None] + print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …") + + for batch_start in range(0, len(valid_indices), batch_size): + batch_idx = valid_indices[batch_start : batch_start + batch_size] + pil_images = [] + for i in batch_idx: + bgr = frames[i] + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + pil_images.append(Image.fromarray(rgb)) + + inputs = processor(images=pil_images, return_tensors="pt").to(device) + with torch.no_grad(): + image_features = model.get_image_features(**inputs) + # L2-normalize embeddings for cosine-like KNN + image_features = torch.nn.functional.normalize(image_features, dim=-1) + embeddings[batch_idx] = image_features.cpu().numpy() + + done = min(batch_start + batch_size, len(valid_indices)) + if done % (batch_size * 10) == 0 or done == len(valid_indices): + print(f" {done:,} / {len(valid_indices):,} encoded") + + del model, processor + torch.cuda.empty_cache() + return embeddings + + # ── KNN consistency ───────────────────────────────────── def compute_consistency( - state_norm: np.ndarray, + features: np.ndarray, action_norm: np.ndarray, episode_ids: np.ndarray, k: int, + label: str = "", ) -> np.ndarray: """ - For each frame, find K nearest neighbors in state space from *other* episodes. + For each frame, find K nearest neighbors in feature space from other episodes. Return per-frame action variance (mean across action dims). 
""" - n = len(state_norm) - print(f" Building KD-tree on {n:,} state vectors …") - tree = cKDTree(state_norm) + n = len(features) + print(f" Building KD-tree on {n:,} vectors ({label}) …") + tree = cKDTree(features) - # Query extra neighbors to have room after filtering same-episode k_query = min(k * 3, n - 1) print(f" Querying {k_query} neighbors per frame …") - dists, indices = tree.query(state_norm, k=k_query + 1) - - # indices[:, 0] is the point itself — skip it + _dists, indices = tree.query(features, k=k_query + 1) indices = indices[:, 1:] - print(" Computing cross-episode action variance …") + print(f" Computing cross-episode action variance ({label}) …") variance = np.zeros(n) for i in range(n): ep_i = episode_ids[i] @@ -343,113 +498,133 @@ def compute_consistency( # ── Visualization ─────────────────────────────────────── +def _style_ax(ax: plt.Axes) -> None: + ax.set_facecolor("#0d1117") + ax.tick_params(colors="#555", labelsize=8) + for spine in ax.spines.values(): + spine.set_color("#333") + + +def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None: + _style_ax(ax) + median_var = np.median(variance) + mean_var = np.mean(variance) + nonzero = variance[variance > 0] + if len(nonzero) > 0: + bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60) + ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222") + ax.set_xscale("log") + ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}") + ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}") + ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10) + ax.set_ylabel("Frame count", color="#888", fontsize=10) + ax.set_title(title, color="white", fontsize=11, pad=10) + ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white") + + +def _plot_episode_curves( + ax: plt.Axes, + var_state: np.ndarray, + var_image: np.ndarray, + 
def _plot_episode_curves(
    ax: plt.Axes,
    var_state: np.ndarray,
    var_image: np.ndarray,
    episode_ids: np.ndarray,
    title: str,
) -> None:
    """
    Plot per-episode mean variance, sorted from highest to lowest, for both analyses.

    `var_state` and `var_image` are per-frame values aligned with `episode_ids`.
    Each curve is sorted independently, so a given x position does not refer to
    the same episode on both curves.
    """
    _style_ax(ax)

    # One pass over the frames instead of one boolean mask per episode.
    unique_eps, inverse = np.unique(episode_ids, return_inverse=True)
    inverse = inverse.ravel()
    counts = np.bincount(inverse, minlength=len(unique_eps))
    ep_means_s = np.bincount(inverse, weights=var_state, minlength=len(unique_eps)) / counts
    ep_means_i = np.bincount(inverse, weights=var_image, minlength=len(unique_eps)) / counts

    sorted_s = np.sort(ep_means_s)[::-1]
    sorted_i = np.sort(ep_means_i)[::-1]
    ep_x = np.arange(len(unique_eps))

    ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
    ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
    ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
    ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")

    ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
    ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
    ax.set_title(title, color="white", fontsize=11, pad=10)
    ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")


def _plot_heatmap(
    ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
) -> None:
    """
    Scatter the points in `tcp_xz` colored by `variance`, with a colorbar.

    Points are drawn in order of increasing variance so the highest-variance
    points end up on top. The color range is the 5th–95th percentile of the
    positive variances (0–1 when there are none).
    """
    _style_ax(ax)
    order = np.argsort(variance)
    pts = tcp_xz[order]
    var_sorted = variance[order]

    positive = variance[variance > 0]
    if positive.size > 0:
        vmin = np.percentile(positive, 5)
        vmax = np.percentile(positive, 95)
    else:
        vmin, vmax = 0, 1

    sc = ax.scatter(
        pts[:, 0],
        pts[:, 1],
        c=var_sorted,
        cmap=CONSISTENCY_CMAP,
        s=0.5,
        alpha=0.6,
        vmin=vmin,
        vmax=vmax,
        rasterized=True,
    )
    ax.set_xlabel("X (m)", color="#888", fontsize=10)
    ax.set_ylabel("Z (m)", color="#888", fontsize=10)
    ax.set_title(title, color="white", fontsize=11, pad=10)
    ax.set_aspect("equal")
    cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label("Action variance", color="white", fontsize=9)
    cbar.ax.tick_params(colors="#aaa", labelsize=7)
figure: + Row 0: State-based variance histogram + Row 1: Image-based variance histogram + Row 2: Per-episode curves (both overlaid) + Row 3: Spatial heatmap (image-based variance) + """ n_ds = len(results) - fig, axes = plt.subplots(3, n_ds, figsize=(9 * n_ds, 18), facecolor="#0d1117") + fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117") if n_ds == 1: axes = axes[:, np.newaxis] headline_parts = [] - for col, r in enumerate(results): - variance = r["variance"] - episode_ids = r["episode_ids"] - tcp_xz = r["tcp_xz"] label = r["label"] + var_s = r["var_state"] + var_i = r["var_image"] + tcp_xz = r["tcp_xz"] + episode_ids = r["episode_ids"] - median_var = np.median(variance) - mean_var = np.mean(variance) - headline_parts.append(f"{label}: median={median_var:.3f}, mean={mean_var:.3f}") + med_s = np.median(var_s) + med_i = np.median(var_i) + headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}") - # Row 0: Histogram of per-frame action variance - ax = axes[0, col] - ax.set_facecolor("#0d1117") - nonzero = variance[variance > 0] - if len(nonzero) > 0: - bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60) - ax.hist(nonzero, bins=bins, color="#4363d8", alpha=0.8, edgecolor="#222") - ax.set_xscale("log") - ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}") - ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}") - ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10) - ax.set_ylabel("Frame count", color="#888", fontsize=10) - ax.set_title(f"{label}\nPer-frame action variance distribution", color="white", fontsize=12, pad=10) - ax.tick_params(colors="#555", labelsize=8) - for spine in ax.spines.values(): - spine.set_color("#333") - ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white") - - # Row 1: Per-episode mean inconsistency curve (sorted) - ax = axes[1, col] - 
ax.set_facecolor("#0d1117") - unique_eps = np.unique(episode_ids) - ep_means = np.array([variance[episode_ids == ep].mean() for ep in unique_eps]) - sorted_means = np.sort(ep_means)[::-1] - ep_x = np.arange(len(sorted_means)) - - p90 = np.percentile(ep_means, 90) - above_p90 = np.sum(ep_means > p90) - - ax.fill_between(ep_x, sorted_means, alpha=0.3, color="#4363d8") - ax.plot(ep_x, sorted_means, color="#4363d8", linewidth=1.2) - ax.axhline( - np.median(ep_means), color="#ff6600", linewidth=1.5, label=f"median={np.median(ep_means):.3f}" + _plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8") + _plot_histogram( + axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b" ) - ax.axhline( - p90, color="#ff2222", linewidth=1, linestyle=":", label=f"p90={p90:.3f} ({above_p90} eps above)" + _plot_episode_curves( + axes[2, col], + var_s, + var_i, + episode_ids, + f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)", ) - ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10) - ax.set_ylabel("Mean action variance", color="#888", fontsize=10) - ax.set_title( - f"{label}\nPer-episode inconsistency ({len(unique_eps):,} episodes)", - color="white", - fontsize=12, - pad=10, + _plot_heatmap( + axes[3, col], + fig, + tcp_xz, + var_i, + f"{label}\nImage-based variance by TCP position (XZ)", ) - ax.tick_params(colors="#555", labelsize=8) - for spine in ax.spines.values(): - spine.set_color("#333") - ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white") - - # Row 2: Spatial heatmap (XZ side view) colored by local action variance - ax = axes[2, col] - ax.set_facecolor("#0d1117") - order = np.argsort(variance) - pts = tcp_xz[order] - var_sorted = variance[order] - - vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0 - vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1 - - sc = 
ax.scatter( - pts[:, 0], - pts[:, 1], - c=var_sorted, - cmap=CONSISTENCY_CMAP, - s=0.5, - alpha=0.6, - vmin=vmin, - vmax=vmax, - rasterized=True, - ) - ax.set_xlabel("X (m)", color="#888", fontsize=10) - ax.set_ylabel("Z (m)", color="#888", fontsize=10) - ax.set_title( - f"{label}\nAction variance by TCP position (XZ side)", - color="white", - fontsize=12, - pad=10, - ) - ax.tick_params(colors="#555", labelsize=8) - for spine in ax.spines.values(): - spine.set_color("#333") - ax.set_aspect("equal") - cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02) - cbar.set_label("Action variance", color="white", fontsize=9) - cbar.ax.tick_params(colors="#aaa", labelsize=7) fig.suptitle( - f"Action-State Consistency Analysis (action chunk = {ACTION_CHUNK_SIZE})\n" + f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n" + " | ".join(headline_parts), color="white", fontsize=15, @@ -465,6 +640,8 @@ def render(results: list[dict], out_path: Path) -> None: def main() -> None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") rng = np.random.default_rng(SEED) results = [] @@ -474,18 +651,34 @@ def main() -> None: print(f" {label}: {repo_id}") print(f"{'=' * 60}") - local = download_data(repo_id) + local = download_data(repo_id, CAMERA_KEY) data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng) - variance = compute_consistency( - data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS + # --- State-based KNN --- + var_state = compute_consistency( + data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state" ) print( - f" Variance stats: median={np.median(variance):.4f} mean={np.mean(variance):.4f} " - f"p90={np.percentile(variance, 90):.4f}" + f" State variance: median={np.median(var_state):.4f} " + f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}" ) - # Compute FK for spatial heatmap (left arm TCP, XZ projection) + # --- 
Image-based KNN --- + print("\n Preparing image embeddings …") + video_lookup = build_video_lookup(local, CAMERA_KEY) + frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup) + embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device) + del frames # free memory + + var_image = compute_consistency( + embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image" + ) + print( + f" Image variance: median={np.median(var_image):.4f} " + f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}" + ) + + # FK for spatial heatmap print(" Computing FK for spatial heatmap …") left_raw = data["state_raw"][:, data["left_joint_idx"]] left_rad = _detect_and_convert(left_raw) @@ -495,7 +688,8 @@ def main() -> None: results.append( { "label": label, - "variance": variance, + "var_state": var_state, + "var_image": var_image, "episode_ids": data["episode_ids"], "tcp_xz": tcp_xz, "n_total": data["n_total"], @@ -505,6 +699,17 @@ def main() -> None: out = OUTPUT_DIR / "action_consistency_comparison.jpg" render(results, out) + # Save worst-episodes summary (image-based, since that's the stronger signal) + worst_summary = {} + for r in results: + unique_eps = np.unique(r["episode_ids"]) + ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps} + ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50] + worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked] + worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json" + worst_path.write_text(json.dumps(worst_summary, indent=2)) + print(f"✓ Saved worst episodes: {worst_path}") + if __name__ == "__main__": main()