feat(viz): add image-based consistency analysis with SigLIP

Run two parallel KNN analyses per dataset:
  1. State-based: KNN in joint-state space
  2. Image-based: KNN in SigLIP embedding space (google/siglip-base-patch16-224)

Both measure action chunk variance among cross-episode neighbors.
Comparing them reveals whether visual and proprioceptive similarity
agree on where data is inconsistent.
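
Both analyses share one per-frame metric. A minimal sketch with hypothetical
names (the script itself over-queries, drops the self-match, and keeps only
cross-episode neighbors before taking the variance):

    import numpy as np
    from scipy.spatial import cKDTree

    def frame_variance(feats, action_chunks, episode_ids, k=50):
        # feats: joint states or SigLIP embeddings, one row per frame
        tree = cKDTree(feats)
        kq = min(3 * k + 1, len(feats))
        _dists, idx = tree.query(feats, k=kq)   # first neighbor is the frame itself
        idx = idx[:, 1:]
        out = np.zeros(len(feats))
        for i, nbrs in enumerate(idx):
            nbrs = nbrs[episode_ids[nbrs] != episode_ids[i]][:k]  # cross-episode only
            if len(nbrs) >= 2:
                out[i] = action_chunks[nbrs].var(axis=0).mean()   # mean over chunk dims
        return out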

The output is a 4-row figure (one column per dataset): state-based histogram,
image-based histogram, overlaid per-episode curves, and a spatial heatmap
colored by image-based variance.

Made-with: Cursor
Pepijn
2026-03-23 20:29:36 -07:00
parent c7fd1f47d1
commit 58eecad8a4
@@ -1,22 +1,28 @@
"""
Action-state consistency analysis for imitation learning datasets.
For each frame, finds K nearest neighbors in state space (from other episodes)
and measures the variance of corresponding actions. High variance at similar
states = contradictory supervision for the policy.
Action consistency analysis for imitation learning datasets.
Outputs a comparison figure with histograms, per-episode curves, and spatial
heatmaps showing where demonstrations conflict.
Two parallel analyses per dataset:
1. State-based: KNN in joint-state space → action chunk variance
2. Image-based: KNN in SigLIP embedding space → action chunk variance
Comparing them reveals whether visual similarity and proprioceptive similarity
agree on where the data is inconsistent — and images are what the policy
primarily sees.
"""
import json
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from scipy.spatial import cKDTree
from transformers import AutoModel, AutoProcessor
DATASETS = [
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
@@ -25,9 +31,12 @@ DATASETS = [
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
MAX_FRAMES = 10_000
MAX_FRAMES = 100_000
K_NEIGHBORS = 50
ACTION_CHUNK_SIZE = 30
CAMERA_KEY = "observation.images.base"
ENCODER_MODEL = "google/siglip-base-patch16-224"
ENCODE_BATCH_SIZE = 128
SEED = 42
DPI = 150
@@ -184,14 +193,17 @@ def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[lis
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
def download_data(repo_id: str) -> Path:
print(f" Downloading {repo_id} (parquet only) …")
def download_data(repo_id: str, camera_key: str) -> Path:
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
return Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "data/**"],
ignore_patterns=["*.mp4", "videos/**"],
allow_patterns=[
"meta/**",
"data/**",
f"videos/{camera_key}/**",
],
)
)
@@ -203,9 +215,8 @@ def _build_action_chunks(
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Build action chunks: for each frame, concatenate the next chunk_size actions
from the same episode. Returns (action_chunks, valid_mask).
Frames too close to episode end to form a full chunk are marked invalid.
For each frame, concatenate the next chunk_size actions from the same episode.
Returns (action_chunks, valid_mask).
"""
n = len(actions)
act_dim = actions.shape[1]
@@ -216,7 +227,6 @@ def _build_action_chunks(
end = i + chunk_size
if end > n:
continue
# All frames in the chunk must belong to the same episode
if episode_ids[i] != episode_ids[end - 1]:
continue
chunks[i] = actions[i:end].ravel()
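# Each valid row is the (chunk_size, act_dim) window starting at frame i,
# flattened to a vector of length chunk_size * act_dim.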
@@ -227,8 +237,8 @@ def _build_action_chunks(
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
"""
Load observation.state and action columns, build action chunks of size
chunk_size (matching what the policy learns), subsample, normalize.
Load observation.state and action, build action chunks, subsample, normalize.
Also returns the original row indices (`chosen_idx`) for video frame mapping.
"""
info = json.loads((local / "meta" / "info.json").read_text())
features = info.get("features", {})
@@ -260,13 +270,11 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
# Build action chunks within episode boundaries
print(" Building action chunks …")
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
valid_idx = np.where(valid)[0]
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
# Subsample from valid frames only
if len(valid_idx) > max_frames:
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
else:
@@ -277,7 +285,6 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
action_raw = action_chunks[chosen]
episode_ids = episode_all[chosen]
# Z-score normalize for fair KNN distance
state_mean = state_raw.mean(axis=0)
state_std = state_raw.std(axis=0)
state_std[state_std < 1e-8] = 1.0
@@ -294,38 +301,186 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
"action_raw": action_raw,
"action_norm": action_norm,
"episode_ids": episode_ids,
"episode_all": episode_all,
"left_joint_idx": left_idx,
"right_joint_idx": right_idx,
"n_total": n_total,
"chosen_idx": chosen,
"df": df,
}
# ── Video → frame extraction ──────────────────────────────
def build_video_lookup(local: Path, camera_key: str) -> dict:
"""
Build a mapping from episode_index → {video_path, fps, from_ts}.
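Example entry (hypothetical values):
    7: {"video_path": <local>/videos/<camera_key>/chunk-000/file-002.mp4,
        "from_ts": 37.5, "fps": 30}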
"""
info = json.loads((local / "meta" / "info.json").read_text())
fps = info["fps"]
video_template = info.get(
"video_path",
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
)
ep_rows = []
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
ep_rows.append(pd.read_parquet(pq))
ep_df = pd.concat(ep_rows, ignore_index=True)
chunk_col = f"videos/{camera_key}/chunk_index"
file_col = f"videos/{camera_key}/file_index"
ts_from = f"videos/{camera_key}/from_timestamp"
if chunk_col not in ep_df.columns:
chunk_col = f"{camera_key}/chunk_index"
file_col = f"{camera_key}/file_index"
ts_from = f"{camera_key}/from_timestamp"
lookup: dict[int, dict] = {}
for _, row in ep_df.iterrows():
ci = int(row[chunk_col])
fi = int(row[file_col])
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
lookup[int(row["episode_index"])] = {
"video_path": local / video_rel,
"from_ts": float(row[ts_from]),
"fps": fps,
}
return lookup
def extract_frames(
chosen_idx: np.ndarray,
episode_all: np.ndarray,
video_lookup: dict,
) -> list[np.ndarray | None]:
"""
Extract BGR frames for each chosen global index.
Uses episode boundaries + fps to compute the seek timestamp.
Returns list of (H, W, 3) BGR arrays (or None on failure).
"""
# Build per-episode local frame index: for each row in the dataset,
# its position within its episode
unique_eps = np.unique(episode_all)
ep_start: dict[int, int] = {}
for ep in unique_eps:
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
# Group by video file for efficient sequential access
jobs: list[tuple[int, int, str, float]] = []
for out_i, global_i in enumerate(chosen_idx):
ep = int(episode_all[global_i])
info = video_lookup.get(ep)
if info is None:
jobs.append((out_i, -1, "", 0.0))
continue
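# A frame's index within its episode maps to a seek time of
# from_timestamp + local_frame / fps, e.g. frame 120 at 30 fps lands 4.0 s
# after the episode's start offset in its video file.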
local_frame = global_i - ep_start[ep]
seek_ts = info["from_ts"] + local_frame / info["fps"]
jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
jobs.sort(key=lambda x: (x[2], x[3]))
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
current_cap = None
current_path = ""
extracted = 0
for out_i, _global_i, vpath, seek_ts in jobs:
if not vpath:
continue
if vpath != current_path:
if current_cap is not None:
current_cap.release()
current_cap = cv2.VideoCapture(vpath)
current_path = vpath
if current_cap is None or not current_cap.isOpened():
continue
current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
ret, frame = current_cap.read()
if ret:
frames[out_i] = frame
extracted += 1
if current_cap is not None:
current_cap.release()
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
return frames
# ── SigLIP encoding ─────────────────────────────────────
def encode_frames_siglip(
frames: list[np.ndarray | None],
model_name: str,
batch_size: int,
device: torch.device,
) -> np.ndarray:
"""
Encode BGR frames through SigLIP vision encoder.
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
"""
print(f" Loading SigLIP model: {model_name}")
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device).eval()
embed_dim = model.config.vision_config.hidden_size
n = len(frames)
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
# Collect valid frame indices
valid_indices = [i for i, f in enumerate(frames) if f is not None]
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}")
for batch_start in range(0, len(valid_indices), batch_size):
batch_idx = valid_indices[batch_start : batch_start + batch_size]
pil_images = []
for i in batch_idx:
bgr = frames[i]
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
pil_images.append(Image.fromarray(rgb))
inputs = processor(images=pil_images, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**inputs)
# L2-normalize embeddings for cosine-like KNN
image_features = torch.nn.functional.normalize(image_features, dim=-1)
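# On unit vectors ||a - b||^2 = 2 - 2 a.b, so Euclidean KNN over these
# embeddings (in compute_consistency) ranks neighbors by cosine similarity.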
embeddings[batch_idx] = image_features.cpu().numpy()
done = min(batch_start + batch_size, len(valid_indices))
if done % (batch_size * 10) == 0 or done == len(valid_indices):
print(f" {done:,} / {len(valid_indices):,} encoded")
del model, processor
torch.cuda.empty_cache()
return embeddings
# ── KNN consistency ─────────────────────────────────────
def compute_consistency(
state_norm: np.ndarray,
features: np.ndarray,
action_norm: np.ndarray,
episode_ids: np.ndarray,
k: int,
label: str = "",
) -> np.ndarray:
"""
For each frame, find K nearest neighbors in state space from *other* episodes.
For each frame, find K nearest neighbors in feature space from other episodes.
Return per-frame action variance (mean across action dims).
"""
n = len(state_norm)
print(f" Building KD-tree on {n:,} state vectors …")
tree = cKDTree(state_norm)
n = len(features)
print(f" Building KD-tree on {n:,} vectors ({label})")
tree = cKDTree(features)
# Query extra neighbors to have room after filtering same-episode
k_query = min(k * 3, n - 1)
print(f" Querying {k_query} neighbors per frame …")
dists, indices = tree.query(state_norm, k=k_query + 1)
# indices[:, 0] is the point itself — skip it
_dists, indices = tree.query(features, k=k_query + 1)
indices = indices[:, 1:]
print(" Computing cross-episode action variance …")
print(f" Computing cross-episode action variance ({label})")
variance = np.zeros(n)
for i in range(n):
ep_i = episode_ids[i]
@@ -343,113 +498,133 @@ def compute_consistency(
# ── Visualization ───────────────────────────────────────
def _style_ax(ax: plt.Axes) -> None:
ax.set_facecolor("#0d1117")
ax.tick_params(colors="#555", labelsize=8)
for spine in ax.spines.values():
spine.set_color("#333")
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
_style_ax(ax)
median_var = np.median(variance)
mean_var = np.mean(variance)
nonzero = variance[variance > 0]
if len(nonzero) > 0:
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
ax.set_xscale("log")
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
ax.set_ylabel("Frame count", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_episode_curves(
ax: plt.Axes,
var_state: np.ndarray,
var_image: np.ndarray,
episode_ids: np.ndarray,
title: str,
) -> None:
_style_ax(ax)
unique_eps = np.unique(episode_ids)
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
sorted_s = np.sort(ep_means_s)[::-1]
sorted_i = np.sort(ep_means_i)[::-1]
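# Each curve is ranked independently, so a given x position generally refers
# to different episodes in the state and image curves.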
ep_x = np.arange(len(unique_eps))
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_heatmap(
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
) -> None:
_style_ax(ax)
order = np.argsort(variance)
pts = tcp_xz[order]
var_sorted = variance[order]
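# Low-variance points are drawn first so high-variance hotspots end up on top;
# the color range is clamped to the 5th-95th percentile of nonzero variance
# so outliers don't wash out the map.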
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
sc = ax.scatter(
pts[:, 0],
pts[:, 1],
c=var_sorted,
cmap=CONSISTENCY_CMAP,
s=0.5,
alpha=0.6,
vmin=vmin,
vmax=vmax,
rasterized=True,
)
ax.set_xlabel("X (m)", color="#888", fontsize=10)
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.set_aspect("equal")
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label("Action variance", color="white", fontsize=9)
cbar.ax.tick_params(colors="#aaa", labelsize=7)
def render(results: list[dict], out_path: Path) -> None:
"""
4-row x N-column figure:
Row 0: State-based variance histogram
Row 1: Image-based variance histogram
Row 2: Per-episode curves (both overlaid)
Row 3: Spatial heatmap (image-based variance)
"""
n_ds = len(results)
fig, axes = plt.subplots(3, n_ds, figsize=(9 * n_ds, 18), facecolor="#0d1117")
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
if n_ds == 1:
axes = axes[:, np.newaxis]
headline_parts = []
for col, r in enumerate(results):
variance = r["variance"]
episode_ids = r["episode_ids"]
tcp_xz = r["tcp_xz"]
label = r["label"]
var_s = r["var_state"]
var_i = r["var_image"]
tcp_xz = r["tcp_xz"]
episode_ids = r["episode_ids"]
median_var = np.median(variance)
mean_var = np.mean(variance)
headline_parts.append(f"{label}: median={median_var:.3f}, mean={mean_var:.3f}")
med_s = np.median(var_s)
med_i = np.median(var_i)
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
# Row 0: Histogram of per-frame action variance
ax = axes[0, col]
ax.set_facecolor("#0d1117")
nonzero = variance[variance > 0]
if len(nonzero) > 0:
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
ax.hist(nonzero, bins=bins, color="#4363d8", alpha=0.8, edgecolor="#222")
ax.set_xscale("log")
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
ax.set_ylabel("Frame count", color="#888", fontsize=10)
ax.set_title(f"{label}\nPer-frame action variance distribution", color="white", fontsize=12, pad=10)
ax.tick_params(colors="#555", labelsize=8)
for spine in ax.spines.values():
spine.set_color("#333")
ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
# Row 1: Per-episode mean inconsistency curve (sorted)
ax = axes[1, col]
ax.set_facecolor("#0d1117")
unique_eps = np.unique(episode_ids)
ep_means = np.array([variance[episode_ids == ep].mean() for ep in unique_eps])
sorted_means = np.sort(ep_means)[::-1]
ep_x = np.arange(len(sorted_means))
p90 = np.percentile(ep_means, 90)
above_p90 = np.sum(ep_means > p90)
ax.fill_between(ep_x, sorted_means, alpha=0.3, color="#4363d8")
ax.plot(ep_x, sorted_means, color="#4363d8", linewidth=1.2)
ax.axhline(
np.median(ep_means), color="#ff6600", linewidth=1.5, label=f"median={np.median(ep_means):.3f}"
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
_plot_histogram(
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
)
ax.axhline(
p90, color="#ff2222", linewidth=1, linestyle=":", label=f"p90={p90:.3f} ({above_p90} eps above)"
_plot_episode_curves(
axes[2, col],
var_s,
var_i,
episode_ids,
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
)
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
ax.set_title(
f"{label}\nPer-episode inconsistency ({len(unique_eps):,} episodes)",
color="white",
fontsize=12,
pad=10,
_plot_heatmap(
axes[3, col],
fig,
tcp_xz,
var_i,
f"{label}\nImage-based variance by TCP position (XZ)",
)
ax.tick_params(colors="#555", labelsize=8)
for spine in ax.spines.values():
spine.set_color("#333")
ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
# Row 2: Spatial heatmap (XZ side view) colored by local action variance
ax = axes[2, col]
ax.set_facecolor("#0d1117")
order = np.argsort(variance)
pts = tcp_xz[order]
var_sorted = variance[order]
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
sc = ax.scatter(
pts[:, 0],
pts[:, 1],
c=var_sorted,
cmap=CONSISTENCY_CMAP,
s=0.5,
alpha=0.6,
vmin=vmin,
vmax=vmax,
rasterized=True,
)
ax.set_xlabel("X (m)", color="#888", fontsize=10)
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
ax.set_title(
f"{label}\nAction variance by TCP position (XZ side)",
color="white",
fontsize=12,
pad=10,
)
ax.tick_params(colors="#555", labelsize=8)
for spine in ax.spines.values():
spine.set_color("#333")
ax.set_aspect("equal")
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label("Action variance", color="white", fontsize=9)
cbar.ax.tick_params(colors="#aaa", labelsize=7)
fig.suptitle(
f"Action-State Consistency Analysis (action chunk = {ACTION_CHUNK_SIZE})\n"
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
+ " | ".join(headline_parts),
color="white",
fontsize=15,
@@ -465,6 +640,8 @@ def render(results: list[dict], out_path: Path) -> None:
def main() -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
rng = np.random.default_rng(SEED)
results = []
@@ -474,18 +651,34 @@ def main() -> None:
print(f" {label}: {repo_id}")
print(f"{'=' * 60}")
local = download_data(repo_id)
local = download_data(repo_id, CAMERA_KEY)
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
variance = compute_consistency(
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS
# --- State-based KNN ---
var_state = compute_consistency(
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
)
print(
f" Variance stats: median={np.median(variance):.4f} mean={np.mean(variance):.4f} "
f"p90={np.percentile(variance, 90):.4f}"
f" State variance: median={np.median(var_state):.4f} "
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
)
# Compute FK for spatial heatmap (left arm TCP, XZ projection)
# --- Image-based KNN ---
print("\n Preparing image embeddings …")
video_lookup = build_video_lookup(local, CAMERA_KEY)
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
del frames # free memory
var_image = compute_consistency(
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
)
print(
f" Image variance: median={np.median(var_image):.4f} "
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
)
# FK for spatial heatmap
print(" Computing FK for spatial heatmap …")
left_raw = data["state_raw"][:, data["left_joint_idx"]]
left_rad = _detect_and_convert(left_raw)
@@ -495,7 +688,8 @@ def main() -> None:
results.append(
{
"label": label,
"variance": variance,
"var_state": var_state,
"var_image": var_image,
"episode_ids": data["episode_ids"],
"tcp_xz": tcp_xz,
"n_total": data["n_total"],
@@ -505,6 +699,17 @@ def main() -> None:
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
render(results, out)
# Save worst-episodes summary (image-based, since that's the stronger signal)
worst_summary = {}
for r in results:
unique_eps = np.unique(r["episode_ids"])
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
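# JSON shape (hypothetical numbers): {"HQ curated": [{"episode": 123, "mean_variance": 0.42}, ...]}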
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
worst_path.write_text(json.dumps(worst_summary, indent=2))
print(f"✓ Saved worst episodes: {worst_path}")
if __name__ == "__main__":
main()