Files
lerobot/examples/dataset/visualization_tools/action_consistency.py
T
Pepijn 58eecad8a4 feat(viz): add image-based consistency analysis with SigLIP
Run two parallel KNN analyses per dataset:
  1. State-based: KNN in joint-state space
  2. Image-based: KNN in SigLIP embedding space (google/siglip-base-patch16-224)

Both measure action chunk variance among cross-episode neighbors.
Comparing them reveals whether visual and proprioceptive similarity
agree on where data is inconsistent.

Output is a 4-row figure: state histogram, image histogram,
overlaid per-episode curves, and spatial heatmap colored by
image-based variance.

Made-with: Cursor
2026-03-23 20:29:36 -07:00

716 lines
25 KiB
Python

"""
Action consistency analysis for imitation learning datasets.
Two parallel analyses per dataset:
1. State-based: KNN in joint-state space → action chunk variance
2. Image-based: KNN in SigLIP embedding space → action chunk variance
Comparing them reveals whether visual similarity and proprioceptive similarity
agree on where the data is inconsistent — and images are what the policy
primarily sees.
"""
import json
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from scipy.spatial import cKDTree
from transformers import AutoModel, AutoProcessor
# ── Configuration ───────────────────────────────────────
# Datasets to compare. Each entry becomes one column of the output figure;
# "label" is used in plot titles and as the key in the worst-episodes JSON.
DATASETS = [
    {"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
    {"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
]
# Figures and JSON summaries are written next to this script (created at import time).
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
# Upper bound on frames kept per dataset (sampled without replacement).
MAX_FRAMES = 100_000
# Number of cross-episode neighbours whose action chunks are compared per frame.
K_NEIGHBORS = 50
# Number of consecutive actions concatenated into one chunk per frame.
ACTION_CHUNK_SIZE = 30
# Camera stream used for the image-based analysis.
CAMERA_KEY = "observation.images.base"
# Hugging Face model id of the image encoder.
ENCODER_MODEL = "google/siglip-base-patch16-224"
# Images per encoder forward pass.
ENCODE_BATCH_SIZE = 128
# Seed for the frame-subsampling RNG.
SEED = 42
# Resolution of the saved figure.
DPI = 150
# Colour ramp for the spatial heatmap: low variance (dark green) → high variance (red).
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
    "consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
)
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
# Each entry is (rpy, xyz, axis), consumed in order by batch_fk():
#   rpy, xyz - fixed link transform built by _tf(rpy, xyz); rpy is
#              roll/pitch/yaw in radians, xyz an offset (presumably metres,
#              as in the URDF - the heatmap axes are labelled in m).
#   axis     - axis of the revolute joint applied after the link transform,
#              or None for a purely fixed link. Every entry with an axis
#              consumes the next joint-angle column (7 joints per arm here).
# The trailing fixed links presumably offset to the tool point (main() calls
# the result the TCP) - TODO confirm against the URDF.
LEFT_CHAIN = [
    ((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
    ((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
    ((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
    ((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
    ((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
    ((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
    ((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
    ((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
    ((0, 0, 0), (0, 0, 0.1001), None),
    ((0, 0, 0), (0, 0, 0.08), None),
]
# Right arm: differs from LEFT_CHAIN in the sign of the roll on the first and
# third links, the sign of the base y-offset, and the 7th joint axis.
RIGHT_CHAIN = [
    ((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
    ((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
    ((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
    ((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
    ((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
    ((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
    ((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
    ((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
    ((0, 0, 0), (0, 0, 0.1001), None),
    ((0, 0, 0), (0, 0, 0.08), None),
]
# ── FK math ─────────────────────────────────────────────
def _rot_x(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
def _rot_y(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
def _rot_z(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
    """Build a 4x4 homogeneous transform from URDF-style ``rpy`` and ``xyz``.

    The rotation block is Rz(yaw) @ Ry(pitch) @ Rx(roll) (angles in radians);
    the translation column is ``xyz``.
    """
    roll, pitch, yaw = rpy
    rotation = _rot_z(yaw) @ _rot_y(pitch) @ _rot_x(roll)
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = xyz
    return transform
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
n = len(angles)
ax = np.asarray(axis, dtype=np.float64)
ax = ax / np.linalg.norm(ax)
x, y, z = ax
c = np.cos(angles)
s = np.sin(angles)
t = 1 - c
rot = np.zeros((n, 4, 4))
rot[:, 0, 0] = t * x * x + c
rot[:, 0, 1] = t * x * y - s * z
rot[:, 0, 2] = t * x * z + s * y
rot[:, 1, 0] = t * x * y + s * z
rot[:, 1, 1] = t * y * y + c
rot[:, 1, 2] = t * y * z - s * x
rot[:, 2, 0] = t * x * z - s * y
rot[:, 2, 1] = t * y * z + s * x
rot[:, 2, 2] = t * z * z + c
rot[:, 3, 3] = 1.0
return rot
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
    """Vectorised forward kinematics for a serial chain.

    Walks ``chain`` (a list of (rpy, xyz, axis) entries): every entry applies
    its fixed link transform, and entries with a non-None axis then apply a
    revolute rotation driven by the next column of ``joint_angles``.

    Args:
        chain: list of (rpy, xyz, axis) tuples.
        joint_angles: (N, J) joint angles in radians, one column per joint.

    Returns:
        (N, 3) positions of the chain's final frame.
    """
    num_samples = joint_angles.shape[0]
    pose = np.tile(np.eye(4), (num_samples, 1, 1))
    joint_col = 0
    for link_rpy, link_xyz, joint_axis in chain:
        pose = pose @ _tf(link_rpy, link_xyz)
        if joint_axis is None:
            continue
        joint_rot = _batch_axis_rot(joint_axis, joint_angles[:, joint_col])
        pose = np.einsum("nij,njk->nik", pose, joint_rot)
        joint_col += 1
    return pose[:, :3, 3]
# ── Data helpers ────────────────────────────────────────
def _flatten_names(obj: object) -> list[str]:
if isinstance(obj, dict):
out: list[str] = []
for v in obj.values():
out.extend(_flatten_names(v))
return out
if isinstance(obj, (list, tuple)):
out = []
for item in obj:
if isinstance(item, (list, tuple, dict)):
out.extend(_flatten_names(item))
else:
out.append(str(item))
return out
return [str(obj)]
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
mx = np.max(np.abs(vals))
if mx > 360:
print(f" Unit detection: servo ticks (max={mx:.0f})")
return (vals - 2048) / 2048 * np.pi
if mx > 6.3:
print(f" Unit detection: degrees (max={mx:.1f})")
return np.deg2rad(vals)
print(f" Unit detection: radians (max={mx:.3f})")
return vals.astype(np.float64)
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
    """Locate the 7 left-arm and 7 right-arm joint columns in the state vector.

    Name matching is tried first: if the feature spec lists exactly ``n_dim``
    names, then for each joint j = 1..7 the first not-yet-used name containing
    ``left_joint_<j>`` / ``right_joint_<j>`` (case-insensitive substring) is
    taken. If that does not give 7 per side, fall back to a positional layout:
    [0:7] / [8:15] when ``n_dim >= 16`` (index 7 skipped - presumably a gripper,
    TODO confirm), otherwise [0:7] / [7:14] when ``n_dim >= 14``.

    Raises:
        RuntimeError: if name matching fails and ``n_dim < 14``.
    """
    spec = features.get("observation.state", features.get(state_col, {}))
    names = _flatten_names(spec.get("names", []))
    if names and len(names) == n_dim:
        lowered = [nm.lower() for nm in names]
        print(f" Feature names: {names[:4]}{names[-4:]}")
        matches: dict[str, list[int]] = {"left": [], "right": []}
        for joint in range(1, 8):
            for side, picked in matches.items():
                needle = f"{side}_joint_{joint}"
                hit = next(
                    (i for i, nm in enumerate(lowered) if needle in nm and i not in picked),
                    None,
                )
                if hit is not None:
                    picked.append(hit)
        left, right = matches["left"], matches["right"]
        if len(left) == 7 and len(right) == 7:
            print(f" Matched by name: left={left} right={right}")
            return left, right
    if n_dim >= 16:
        print(" Falling back to positional: [0:7]=left, [8:15]=right")
        return list(range(7)), list(range(8, 15))
    if n_dim >= 14:
        print(" Falling back to positional: [0:7]=left, [7:14]=right")
        return list(range(7)), list(range(7, 14))
    raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
def download_data(repo_id: str, camera_key: str) -> Path:
    """Download the dataset snapshot needed for the analysis and return its path.

    Only files matching ``meta/**``, ``data/**`` and ``videos/<camera_key>/**``
    are requested from the Hub dataset repo.
    """
    print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
    patterns = ["meta/**", "data/**", f"videos/{camera_key}/**"]
    snapshot_dir = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        allow_patterns=patterns,
    )
    return Path(snapshot_dir)
# ── Data loading ────────────────────────────────────────
def _build_action_chunks(
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
) -> tuple[np.ndarray, np.ndarray]:
"""
For each frame, concatenate the next chunk_size actions from the same episode.
Returns (action_chunks, valid_mask).
"""
n = len(actions)
act_dim = actions.shape[1]
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
valid = np.zeros(n, dtype=bool)
for i in range(n):
end = i + chunk_size
if end > n:
continue
if episode_ids[i] != episode_ids[end - 1]:
continue
chunks[i] = actions[i:end].ravel()
valid[i] = True
return chunks, valid
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
    """
    Load observation.state and action, build action chunks, subsample, normalize.
    Also returns the original row indices (`chosen_idx`) for video frame mapping.

    A frame is usable when it has a full ``chunk_size``-step action chunk inside
    its own episode. Up to ``max_frames`` usable frames are sampled without
    replacement using ``rng``. States and action chunks are z-scored per
    dimension over the sampled frames (near-constant dims keep std = 1).

    Only the validity mask is computed for every row; action chunks are
    gathered for the sampled rows only, so memory does not scale with
    ``n_total * chunk_size * act_dim``.

    Returns:
        dict with state_raw/state_norm, action_raw/action_norm (flattened
        chunks), episode_ids (sampled), episode_all (every row), the left/right
        joint indices, n_total, chosen_idx and the full dataframe ``df``.

    Raises:
        ValueError: if ``chunk_size < 1``.
        RuntimeError: if a required column is missing or no frame has a full chunk.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    info = json.loads((local / "meta" / "info.json").read_text())
    features = info.get("features", {})
    dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
    df = pd.concat(dfs, ignore_index=True)
    n_total = len(df)
    print(f" Total frames: {n_total:,}")
    # Prefer the exact column name; a substring match is only a fallback so a
    # sibling such as "observation.state.<x>" cannot shadow the real column.
    if "observation.state" in df.columns:
        state_col = "observation.state"
    else:
        state_col = next((c for c in df.columns if "observation.state" in c), None)
    action_col = "action" if "action" in df.columns else None
    if state_col is None:
        raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
    if action_col is None:
        raise RuntimeError(f"No action column. Available: {list(df.columns)}")
    ep_col = "episode_index" if "episode_index" in df.columns else None
    if ep_col is None:
        raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
    state_all = np.stack(df[state_col].values).astype(np.float64)
    action_all = np.stack(df[action_col].values).astype(np.float64)
    episode_all = df[ep_col].values.astype(np.int64)
    n_dim = state_all.shape[1]
    act_dim = action_all.shape[1]
    print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
    print(f" Action chunk dim: {chunk_size * act_dim}")
    left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
    print(" Building action chunks …")
    # Row i has a full chunk iff row i + chunk_size - 1 exists and belongs to the
    # same episode (episodes are assumed to be stored as contiguous row runs).
    valid = np.zeros(n_total, dtype=bool)
    n_starts = n_total - chunk_size + 1
    if n_starts > 0:
        valid[:n_starts] = episode_all[:n_starts] == episode_all[chunk_size - 1 :]
    valid_idx = np.where(valid)[0]
    print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
    if len(valid_idx) == 0:
        raise RuntimeError(f"No frame has a full {chunk_size}-step action chunk within its episode")
    if len(valid_idx) > max_frames:
        chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
    else:
        chosen = valid_idx
    print(f" Using {len(chosen):,} frames")
    state_raw = state_all[chosen]
    # Row r = action_all[chosen[r] : chosen[r] + chunk_size] flattened time-major
    # (identical to action_all[i:i + chunk_size].ravel()).
    window = chosen[:, None] + np.arange(chunk_size)
    action_raw = action_all[window].reshape(len(chosen), chunk_size * act_dim)
    episode_ids = episode_all[chosen]
    state_mean = state_raw.mean(axis=0)
    state_std = state_raw.std(axis=0)
    state_std[state_std < 1e-8] = 1.0
    state_norm = (state_raw - state_mean) / state_std
    action_mean = action_raw.mean(axis=0)
    action_std = action_raw.std(axis=0)
    action_std[action_std < 1e-8] = 1.0
    action_norm = (action_raw - action_mean) / action_std
    return {
        "state_raw": state_raw,
        "state_norm": state_norm,
        "action_raw": action_raw,
        "action_norm": action_norm,
        "episode_ids": episode_ids,
        "episode_all": episode_all,
        "left_joint_idx": left_idx,
        "right_joint_idx": right_idx,
        "n_total": n_total,
        "chosen_idx": chosen,
        "df": df,
    }
# ── Video → frame extraction ──────────────────────────────
def build_video_lookup(local: Path, camera_key: str) -> dict:
    """
    Build a mapping from episode_index → {video_path, fps, from_ts}.

    Reads ``meta/info.json`` (``fps`` and the optional ``video_path`` template)
    and every ``meta/episodes/**/*.parquet`` file. The per-episode video columns
    are read as ``videos/<camera_key>/{chunk_index,file_index,from_timestamp}``,
    or as ``<camera_key>/...`` when the ``videos/`` prefixed columns are absent.
    """
    meta_dir = local / "meta"
    info = json.loads((meta_dir / "info.json").read_text())
    fps = info["fps"]
    template = info.get(
        "video_path",
        "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
    )
    episodes = pd.concat(
        [pd.read_parquet(pq) for pq in sorted((meta_dir / "episodes").glob("**/*.parquet"))],
        ignore_index=True,
    )
    prefix = f"videos/{camera_key}"
    if f"{prefix}/chunk_index" not in episodes.columns:
        prefix = camera_key
    return {
        int(row["episode_index"]): {
            "video_path": local
            / template.format(
                video_key=camera_key,
                chunk_index=int(row[f"{prefix}/chunk_index"]),
                file_index=int(row[f"{prefix}/file_index"]),
            ),
            "from_ts": float(row[f"{prefix}/from_timestamp"]),
            "fps": fps,
        }
        for _, row in episodes.iterrows()
    }
def extract_frames(
chosen_idx: np.ndarray,
episode_all: np.ndarray,
video_lookup: dict,
) -> list[np.ndarray | None]:
"""
Extract BGR frames for each chosen global index.
Uses episode boundaries + fps to compute the seek timestamp.
Returns list of (H, W, 3) BGR arrays (or None on failure).
"""
# Build per-episode local frame index: for each row in the dataset,
# its position within its episode
unique_eps = np.unique(episode_all)
ep_start: dict[int, int] = {}
for ep in unique_eps:
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
frames: list[np.ndarray | None] = []
# Group by video file for efficient sequential access
jobs: list[tuple[int, int, str, float]] = []
for out_i, global_i in enumerate(chosen_idx):
ep = int(episode_all[global_i])
info = video_lookup.get(ep)
if info is None:
jobs.append((out_i, -1, "", 0.0))
continue
local_frame = global_i - ep_start[ep]
seek_ts = info["from_ts"] + local_frame / info["fps"]
jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
jobs.sort(key=lambda x: (x[2], x[3]))
frames = [None] * len(chosen_idx)
current_cap = None
current_path = ""
extracted = 0
for out_i, _global_i, vpath, seek_ts in jobs:
if not vpath:
continue
if vpath != current_path:
if current_cap is not None:
current_cap.release()
current_cap = cv2.VideoCapture(vpath)
current_path = vpath
if current_cap is None or not current_cap.isOpened():
continue
current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
ret, frame = current_cap.read()
if ret:
frames[out_i] = frame
extracted += 1
if current_cap is not None:
current_cap.release()
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
return frames
# ── SigLIP encoding ─────────────────────────────────────
def encode_frames_siglip(
    frames: list[np.ndarray | None],
    model_name: str,
    batch_size: int,
    device: torch.device,
) -> np.ndarray:
    """
    Encode BGR frames through SigLIP vision encoder.

    Frames are converted BGR -> RGB, preprocessed by the model's processor and
    embedded with ``model.get_image_features``; each embedding is L2-normalised
    so Euclidean KNN behaves like cosine similarity.

    Args:
        frames: BGR images, or None for frames that could not be decoded.
        model_name: Hugging Face model id loaded with AutoModel/AutoProcessor.
        batch_size: images per forward pass.
        device: torch device used for inference.

    Returns:
        (N, D) float32 array. D is taken from the encoder's actual output, so
        models whose image features are projected (D != vision hidden size)
        also work. Frames that are None get a zero vector.
    """
    print(f" Loading SigLIP model: {model_name}")
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    n = len(frames)
    embeddings = None
    valid_indices = [i for i, f in enumerate(frames) if f is not None]
    print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}")
    for batch_start in range(0, len(valid_indices), batch_size):
        batch_idx = valid_indices[batch_start : batch_start + batch_size]
        pil_images = [Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) for i in batch_idx]
        inputs = processor(images=pil_images, return_tensors="pt").to(device)
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            # L2-normalize embeddings for cosine-like KNN
            image_features = torch.nn.functional.normalize(image_features, dim=-1)
        batch_np = image_features.float().cpu().numpy()
        if embeddings is None:
            # Size the output from what the encoder actually returns.
            embeddings = np.zeros((n, batch_np.shape[-1]), dtype=np.float32)
        embeddings[batch_idx] = batch_np
        done = min(batch_start + batch_size, len(valid_indices))
        if done % (batch_size * 10) == 0 or done == len(valid_indices):
            print(f" {done:,} / {len(valid_indices):,} encoded")
    if embeddings is None:
        # Nothing was encoded; fall back to the configured vision width.
        embeddings = np.zeros((n, model.config.vision_config.hidden_size), dtype=np.float32)
    del model, processor
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return embeddings
# ── KNN consistency ─────────────────────────────────────
def compute_consistency(
    features: np.ndarray,
    action_norm: np.ndarray,
    episode_ids: np.ndarray,
    k: int,
    label: str = "",
) -> np.ndarray:
    """
    For each frame, find K nearest neighbors in feature space from other episodes.
    Return per-frame action variance (mean across action dims).

    Args:
        features: (N, D) search vectors (normalised states or image embeddings).
        action_norm: (N, A) normalised action chunks aligned with ``features``.
        episode_ids: (N,) episode of each row. Neighbours from the same episode,
            including the query frame itself, are ignored.
        k: maximum number of cross-episode neighbours used per frame.
        label: tag used only in progress messages.

    Returns:
        (N,) float64 array. Each entry is the variance of the nearest (up to
        ``k``) cross-episode neighbours' action chunks, averaged over action
        dimensions. The candidates are the ``min(3k, N-1) + 1`` nearest rows.
        Frames with fewer than two such neighbours get 0. If N < 2 or k < 2,
        all entries are 0.
    """
    n = len(features)
    variance = np.zeros(n)
    if n < 2 or k < 2:
        # A variance needs at least two neighbours; this also avoids a
        # degenerate KD-tree query (k=1 returns 1-D arrays).
        return variance
    print(f" Building KD-tree on {n:,} vectors ({label}) …")
    tree = cKDTree(features)
    # Over-fetch so that, after dropping same-episode hits, up to k remain.
    k_query = min(k * 3, n - 1)
    print(f" Querying {k_query} neighbors per frame …")
    _dists, indices = tree.query(features, k=k_query + 1, workers=-1)
    # The query point is NOT sliced off by position. With duplicate vectors it
    # need not be in column 0, and slicing would drop a genuine distance-0
    # neighbour. It shares its own episode, so the filter below removes it.
    print(f" Computing cross-episode action variance ({label}) …")
    for i in range(n):
        neighbors = indices[i]
        cross_ep = neighbors[episode_ids[neighbors] != episode_ids[i]][:k]
        if len(cross_ep) < 2:
            continue
        variance[i] = np.mean(np.var(action_norm[cross_ep], axis=0))
    return variance
# ── Visualization ───────────────────────────────────────
def _style_ax(ax: plt.Axes) -> None:
    """Apply the shared dark theme: panel background, tick colours and spine colours."""
    ax.set_facecolor("#0d1117")
    ax.tick_params(colors="#555", labelsize=8)
    for side in ax.spines:
        ax.spines[side].set_color("#333")
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
    """Draw a log-binned histogram of per-frame variance with median/mean markers.

    Only strictly positive values are histogrammed (60 log-spaced bins, lower
    edge clipped to at least 1e-6, log x-axis). The median and mean lines are
    computed over all values, zeros included.
    """
    _style_ax(ax)
    median_var = np.median(variance)
    mean_var = np.mean(variance)
    positive = variance[variance > 0]
    if positive.size:
        lo_exp = np.log10(positive.min().clip(1e-6))
        hi_exp = np.log10(positive.max())
        ax.hist(positive, bins=np.logspace(lo_exp, hi_exp, 60), color=color, alpha=0.8, edgecolor="#222")
        ax.set_xscale("log")
    ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
    ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
    ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
    ax.set_ylabel("Frame count", color="#888", fontsize=10)
    ax.set_title(title, color="white", fontsize=11, pad=10)
    ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_episode_curves(
    ax: plt.Axes,
    var_state: np.ndarray,
    var_image: np.ndarray,
    episode_ids: np.ndarray,
    title: str,
) -> None:
    """Overlay per-episode mean variance for the state and image analyses.

    Each series is averaged per episode and sorted in descending order
    independently, so position x is a rank within that series, not a fixed
    episode. Legend labels show each series' median episode mean.
    """
    _style_ax(ax)
    episodes = np.unique(episode_ids)
    ranks = np.arange(len(episodes))
    series = (
        (var_state, "#4363d8", "State"),
        (var_image, "#e6194b", "Image"),
    )
    for values, color, name in series:
        per_episode = np.array([values[episode_ids == ep].mean() for ep in episodes])
        ranked = np.sort(per_episode)[::-1]
        ax.fill_between(ranks, ranked, alpha=0.2, color=color)
        ax.plot(ranks, ranked, color=color, linewidth=1.2, label=f"{name} (med={np.median(per_episode):.3f})")
    ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
    ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
    ax.set_title(title, color="white", fontsize=11, pad=10)
    ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_heatmap(
    ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
) -> None:
    """Scatter TCP positions in the XZ plane coloured by per-frame variance.

    Points are drawn in ascending variance order so high-variance frames sit on
    top. The colour range is the 5th-95th percentile of the positive values
    (0-1 if none are positive). A colourbar is attached to ``fig``.
    """
    _style_ax(ax)
    positive = variance[variance > 0]
    if positive.size:
        vmin = np.percentile(positive, 5)
        vmax = np.percentile(positive, 95)
    else:
        vmin, vmax = 0, 1
    draw_order = np.argsort(variance)
    sc = ax.scatter(
        tcp_xz[draw_order, 0],
        tcp_xz[draw_order, 1],
        c=variance[draw_order],
        cmap=CONSISTENCY_CMAP,
        s=0.5,
        alpha=0.6,
        vmin=vmin,
        vmax=vmax,
        rasterized=True,
    )
    ax.set_xlabel("X (m)", color="#888", fontsize=10)
    ax.set_ylabel("Z (m)", color="#888", fontsize=10)
    ax.set_title(title, color="white", fontsize=11, pad=10)
    ax.set_aspect("equal")
    cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label("Action variance", color="white", fontsize=9)
    cbar.ax.tick_params(colors="#aaa", labelsize=7)
def render(results: list[dict], out_path: Path) -> None:
    """
    4-row x N-column figure:
    Row 0: State-based variance histogram
    Row 1: Image-based variance histogram
    Row 2: Per-episode curves (both overlaid)
    Row 3: Spatial heatmap (image-based variance)

    The figure is saved to ``out_path`` at ``DPI`` and then closed.
    """
    n_ds = len(results)
    fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
    if n_ds == 1:
        # plt.subplots returns a 1-D array for a single column; keep 2-D indexing.
        axes = axes[:, np.newaxis]
    summaries = []
    for col, res in enumerate(results):
        name = res["label"]
        var_s, var_i = res["var_state"], res["var_image"]
        eps = res["episode_ids"]
        summaries.append(f"{name}: state={np.median(var_s):.3f}, image={np.median(var_i):.3f}")
        state_ax, image_ax, curve_ax, map_ax = axes[:, col]
        _plot_histogram(state_ax, var_s, f"{name}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
        _plot_histogram(
            image_ax, var_i, f"{name}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
        )
        n_eps = len(np.unique(eps))
        _plot_episode_curves(
            curve_ax, var_s, var_i, eps, f"{name}\nPer-episode inconsistency ({n_eps:,} episodes)"
        )
        _plot_heatmap(map_ax, fig, res["tcp_xz"], var_i, f"{name}\nImage-based variance by TCP position (XZ)")
    headline = f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
    fig.suptitle(headline + " | ".join(summaries), color="white", fontsize=15, y=0.99)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"\n✓ Saved: {out_path}")
# ── Main ────────────────────────────────────────────────
def _analyze_dataset(ds: dict, rng: np.random.Generator, device: torch.device) -> dict:
    """Run the state- and image-based consistency analyses for one DATASETS entry.

    Returns the per-dataset result dict consumed by render() and the
    worst-episode summary.
    """
    repo_id, label = ds["repo_id"], ds["label"]
    print(f"\n{'=' * 60}")
    print(f" {label}: {repo_id}")
    print(f"{'=' * 60}")
    local = download_data(repo_id, CAMERA_KEY)
    data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
    # --- State-based KNN ---
    var_state = compute_consistency(
        data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
    )
    print(
        f" State variance: median={np.median(var_state):.4f} "
        f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
    )
    # --- Image-based KNN ---
    print("\n Preparing image embeddings …")
    video_lookup = build_video_lookup(local, CAMERA_KEY)
    frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
    has_frame = np.array([f is not None for f in frames], dtype=bool)
    embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
    del frames  # free memory
    # Frames that failed to decode come back as identical all-zero vectors. In
    # the KD-tree they would only match each other and yield meaningless
    # variance, so they are excluded from the image search. They keep variance 0,
    # the same value used for "not enough cross-episode neighbours".
    var_image = np.zeros(len(embeddings))
    n_missing = int((~has_frame).sum())
    if n_missing:
        print(f" Excluding {n_missing:,} frames without a decoded image from image KNN")
    if has_frame.sum() >= 2:
        var_image[has_frame] = compute_consistency(
            embeddings[has_frame],
            data["action_norm"][has_frame],
            data["episode_ids"][has_frame],
            K_NEIGHBORS,
            "image",
        )
    print(
        f" Image variance: median={np.median(var_image):.4f} "
        f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
    )
    # FK for spatial heatmap (left arm only)
    print(" Computing FK for spatial heatmap …")
    left_raw = data["state_raw"][:, data["left_joint_idx"]]
    left_rad = _detect_and_convert(left_raw)
    left_tcp = batch_fk(LEFT_CHAIN, left_rad)
    tcp_xz = left_tcp[:, [0, 2]]
    return {
        "label": label,
        "var_state": var_state,
        "var_image": var_image,
        "episode_ids": data["episode_ids"],
        "tcp_xz": tcp_xz,
        "n_total": data["n_total"],
    }


def _write_worst_episodes(results: list[dict], path: Path, top_n: int = 50) -> None:
    """Save, per dataset, the ``top_n`` episodes with the highest mean image-based variance.

    The image-based variance is used here (the original author's choice, as
    the stronger signal). The JSON maps label -> [{"episode", "mean_variance"}, ...].
    """
    summary = {}
    for r in results:
        eps = r["episode_ids"]
        ep_means = {int(ep): float(r["var_image"][eps == ep].mean()) for ep in np.unique(eps)}
        ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:top_n]
        summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
    path.write_text(json.dumps(summary, indent=2))
    print(f"✓ Saved worst episodes: {path}")


def main() -> None:
    """Analyse every entry in DATASETS, render the comparison figure and the worst-episodes JSON."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # One generator shared across datasets, so subsampling is reproducible for a fixed DATASETS order.
    rng = np.random.default_rng(SEED)
    results = [_analyze_dataset(ds, rng, device) for ds in DATASETS]
    out = OUTPUT_DIR / "action_consistency_comparison.jpg"
    render(results, out)
    _write_worst_episodes(results, OUTPUT_DIR / "action_consistency_worst_episodes.json")


if __name__ == "__main__":
    main()