mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
feat(viz): add image-based consistency analysis with SigLIP
Run two parallel KNN analyses per dataset: 1. State-based: KNN in joint-state space 2. Image-based: KNN in SigLIP embedding space (google/siglip-base-patch16-224) Both measure action chunk variance among cross-episode neighbors. Comparing them reveals whether visual and proprioceptive similarity agree on where data is inconsistent. Output is a 4-row figure: state histogram, image histogram, overlaid per-episode curves, and spatial heatmap colored by image-based variance. Made-with: Cursor
This commit is contained in:
@@ -1,22 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Action-state consistency analysis for imitation learning datasets.
|
Action consistency analysis for imitation learning datasets.
|
||||||
For each frame, finds K nearest neighbors in state space (from other episodes)
|
|
||||||
and measures the variance of corresponding actions. High variance at similar
|
|
||||||
states = contradictory supervision for the policy.
|
|
||||||
|
|
||||||
Outputs a comparison figure with histograms, per-episode curves, and spatial
|
Two parallel analyses per dataset:
|
||||||
heatmaps showing where demonstrations conflict.
|
1. State-based: KNN in joint-state space → action chunk variance
|
||||||
|
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||||
|
|
||||||
|
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||||
|
agree on where the data is inconsistent — and images are what the policy
|
||||||
|
primarily sees.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from matplotlib.colors import LinearSegmentedColormap
|
from matplotlib.colors import LinearSegmentedColormap
|
||||||
|
from PIL import Image
|
||||||
from scipy.spatial import cKDTree
|
from scipy.spatial import cKDTree
|
||||||
|
from transformers import AutoModel, AutoProcessor
|
||||||
|
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||||
@@ -25,9 +31,12 @@ DATASETS = [
|
|||||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
MAX_FRAMES = 10_000
|
MAX_FRAMES = 100_000
|
||||||
K_NEIGHBORS = 50
|
K_NEIGHBORS = 50
|
||||||
ACTION_CHUNK_SIZE = 30
|
ACTION_CHUNK_SIZE = 30
|
||||||
|
CAMERA_KEY = "observation.images.base"
|
||||||
|
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||||
|
ENCODE_BATCH_SIZE = 128
|
||||||
SEED = 42
|
SEED = 42
|
||||||
DPI = 150
|
DPI = 150
|
||||||
|
|
||||||
@@ -184,14 +193,17 @@ def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[lis
|
|||||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||||
|
|
||||||
|
|
||||||
def download_data(repo_id: str) -> Path:
|
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||||
print(f" Downloading {repo_id} (parquet only) …")
|
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||||
return Path(
|
return Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
allow_patterns=["meta/**", "data/**"],
|
allow_patterns=[
|
||||||
ignore_patterns=["*.mp4", "videos/**"],
|
"meta/**",
|
||||||
|
"data/**",
|
||||||
|
f"videos/{camera_key}/**",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -203,9 +215,8 @@ def _build_action_chunks(
|
|||||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Build action chunks: for each frame, concatenate the next chunk_size actions
|
For each frame, concatenate the next chunk_size actions from the same episode.
|
||||||
from the same episode. Returns (action_chunks, valid_mask).
|
Returns (action_chunks, valid_mask).
|
||||||
Frames too close to episode end to form a full chunk are marked invalid.
|
|
||||||
"""
|
"""
|
||||||
n = len(actions)
|
n = len(actions)
|
||||||
act_dim = actions.shape[1]
|
act_dim = actions.shape[1]
|
||||||
@@ -216,7 +227,6 @@ def _build_action_chunks(
|
|||||||
end = i + chunk_size
|
end = i + chunk_size
|
||||||
if end > n:
|
if end > n:
|
||||||
continue
|
continue
|
||||||
# All frames in the chunk must belong to the same episode
|
|
||||||
if episode_ids[i] != episode_ids[end - 1]:
|
if episode_ids[i] != episode_ids[end - 1]:
|
||||||
continue
|
continue
|
||||||
chunks[i] = actions[i:end].ravel()
|
chunks[i] = actions[i:end].ravel()
|
||||||
@@ -227,8 +237,8 @@ def _build_action_chunks(
|
|||||||
|
|
||||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||||
"""
|
"""
|
||||||
Load observation.state and action columns, build action chunks of size
|
Load observation.state and action, build action chunks, subsample, normalize.
|
||||||
chunk_size (matching what the policy learns), subsample, normalize.
|
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||||
"""
|
"""
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
info = json.loads((local / "meta" / "info.json").read_text())
|
||||||
features = info.get("features", {})
|
features = info.get("features", {})
|
||||||
@@ -260,13 +270,11 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
|
|||||||
|
|
||||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||||
|
|
||||||
# Build action chunks within episode boundaries
|
|
||||||
print(" Building action chunks …")
|
print(" Building action chunks …")
|
||||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||||
valid_idx = np.where(valid)[0]
|
valid_idx = np.where(valid)[0]
|
||||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||||
|
|
||||||
# Subsample from valid frames only
|
|
||||||
if len(valid_idx) > max_frames:
|
if len(valid_idx) > max_frames:
|
||||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||||
else:
|
else:
|
||||||
@@ -277,7 +285,6 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
|
|||||||
action_raw = action_chunks[chosen]
|
action_raw = action_chunks[chosen]
|
||||||
episode_ids = episode_all[chosen]
|
episode_ids = episode_all[chosen]
|
||||||
|
|
||||||
# Z-score normalize for fair KNN distance
|
|
||||||
state_mean = state_raw.mean(axis=0)
|
state_mean = state_raw.mean(axis=0)
|
||||||
state_std = state_raw.std(axis=0)
|
state_std = state_raw.std(axis=0)
|
||||||
state_std[state_std < 1e-8] = 1.0
|
state_std[state_std < 1e-8] = 1.0
|
||||||
@@ -294,38 +301,186 @@ def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: n
|
|||||||
"action_raw": action_raw,
|
"action_raw": action_raw,
|
||||||
"action_norm": action_norm,
|
"action_norm": action_norm,
|
||||||
"episode_ids": episode_ids,
|
"episode_ids": episode_ids,
|
||||||
|
"episode_all": episode_all,
|
||||||
"left_joint_idx": left_idx,
|
"left_joint_idx": left_idx,
|
||||||
"right_joint_idx": right_idx,
|
"right_joint_idx": right_idx,
|
||||||
"n_total": n_total,
|
"n_total": n_total,
|
||||||
|
"chosen_idx": chosen,
|
||||||
|
"df": df,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Video → frame extraction ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||||
|
"""
|
||||||
|
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||||
|
"""
|
||||||
|
info = json.loads((local / "meta" / "info.json").read_text())
|
||||||
|
fps = info["fps"]
|
||||||
|
video_template = info.get(
|
||||||
|
"video_path",
|
||||||
|
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||||
|
)
|
||||||
|
|
||||||
|
ep_rows = []
|
||||||
|
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||||
|
ep_rows.append(pd.read_parquet(pq))
|
||||||
|
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||||
|
|
||||||
|
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||||
|
file_col = f"videos/{camera_key}/file_index"
|
||||||
|
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||||
|
if chunk_col not in ep_df.columns:
|
||||||
|
chunk_col = f"{camera_key}/chunk_index"
|
||||||
|
file_col = f"{camera_key}/file_index"
|
||||||
|
ts_from = f"{camera_key}/from_timestamp"
|
||||||
|
|
||||||
|
lookup: dict[int, dict] = {}
|
||||||
|
for _, row in ep_df.iterrows():
|
||||||
|
ci = int(row[chunk_col])
|
||||||
|
fi = int(row[file_col])
|
||||||
|
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||||
|
lookup[int(row["episode_index"])] = {
|
||||||
|
"video_path": local / video_rel,
|
||||||
|
"from_ts": float(row[ts_from]),
|
||||||
|
"fps": fps,
|
||||||
|
}
|
||||||
|
return lookup
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frames(
|
||||||
|
chosen_idx: np.ndarray,
|
||||||
|
episode_all: np.ndarray,
|
||||||
|
video_lookup: dict,
|
||||||
|
) -> list[np.ndarray | None]:
|
||||||
|
"""
|
||||||
|
Extract BGR frames for each chosen global index.
|
||||||
|
Uses episode boundaries + fps to compute the seek timestamp.
|
||||||
|
Returns list of (H, W, 3) BGR arrays (or None on failure).
|
||||||
|
"""
|
||||||
|
# Build per-episode local frame index: for each row in the dataset,
|
||||||
|
# its position within its episode
|
||||||
|
unique_eps = np.unique(episode_all)
|
||||||
|
ep_start: dict[int, int] = {}
|
||||||
|
for ep in unique_eps:
|
||||||
|
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
||||||
|
|
||||||
|
frames: list[np.ndarray | None] = []
|
||||||
|
# Group by video file for efficient sequential access
|
||||||
|
jobs: list[tuple[int, int, str, float]] = []
|
||||||
|
for out_i, global_i in enumerate(chosen_idx):
|
||||||
|
ep = int(episode_all[global_i])
|
||||||
|
info = video_lookup.get(ep)
|
||||||
|
if info is None:
|
||||||
|
jobs.append((out_i, -1, "", 0.0))
|
||||||
|
continue
|
||||||
|
local_frame = global_i - ep_start[ep]
|
||||||
|
seek_ts = info["from_ts"] + local_frame / info["fps"]
|
||||||
|
jobs.append((out_i, global_i, str(info["video_path"]), seek_ts))
|
||||||
|
|
||||||
|
jobs.sort(key=lambda x: (x[2], x[3]))
|
||||||
|
|
||||||
|
frames = [None] * len(chosen_idx)
|
||||||
|
current_cap = None
|
||||||
|
current_path = ""
|
||||||
|
extracted = 0
|
||||||
|
for out_i, _global_i, vpath, seek_ts in jobs:
|
||||||
|
if not vpath:
|
||||||
|
continue
|
||||||
|
if vpath != current_path:
|
||||||
|
if current_cap is not None:
|
||||||
|
current_cap.release()
|
||||||
|
current_cap = cv2.VideoCapture(vpath)
|
||||||
|
current_path = vpath
|
||||||
|
if current_cap is None or not current_cap.isOpened():
|
||||||
|
continue
|
||||||
|
current_cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||||
|
ret, frame = current_cap.read()
|
||||||
|
if ret:
|
||||||
|
frames[out_i] = frame
|
||||||
|
extracted += 1
|
||||||
|
if current_cap is not None:
|
||||||
|
current_cap.release()
|
||||||
|
|
||||||
|
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# ── SigLIP encoding ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def encode_frames_siglip(
|
||||||
|
frames: list[np.ndarray | None],
|
||||||
|
model_name: str,
|
||||||
|
batch_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Encode BGR frames through SigLIP vision encoder.
|
||||||
|
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||||
|
"""
|
||||||
|
print(f" Loading SigLIP model: {model_name} …")
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name)
|
||||||
|
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||||
|
embed_dim = model.config.vision_config.hidden_size
|
||||||
|
|
||||||
|
n = len(frames)
|
||||||
|
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||||
|
|
||||||
|
# Collect valid frame indices
|
||||||
|
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||||
|
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||||
|
|
||||||
|
for batch_start in range(0, len(valid_indices), batch_size):
|
||||||
|
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||||
|
pil_images = []
|
||||||
|
for i in batch_idx:
|
||||||
|
bgr = frames[i]
|
||||||
|
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
pil_images.append(Image.fromarray(rgb))
|
||||||
|
|
||||||
|
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features = model.get_image_features(**inputs)
|
||||||
|
# L2-normalize embeddings for cosine-like KNN
|
||||||
|
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||||
|
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||||
|
|
||||||
|
done = min(batch_start + batch_size, len(valid_indices))
|
||||||
|
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||||
|
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||||
|
|
||||||
|
del model, processor
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# ── KNN consistency ─────────────────────────────────────
|
# ── KNN consistency ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def compute_consistency(
|
def compute_consistency(
|
||||||
state_norm: np.ndarray,
|
features: np.ndarray,
|
||||||
action_norm: np.ndarray,
|
action_norm: np.ndarray,
|
||||||
episode_ids: np.ndarray,
|
episode_ids: np.ndarray,
|
||||||
k: int,
|
k: int,
|
||||||
|
label: str = "",
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
For each frame, find K nearest neighbors in state space from *other* episodes.
|
For each frame, find K nearest neighbors in feature space from other episodes.
|
||||||
Return per-frame action variance (mean across action dims).
|
Return per-frame action variance (mean across action dims).
|
||||||
"""
|
"""
|
||||||
n = len(state_norm)
|
n = len(features)
|
||||||
print(f" Building KD-tree on {n:,} state vectors …")
|
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
||||||
tree = cKDTree(state_norm)
|
tree = cKDTree(features)
|
||||||
|
|
||||||
# Query extra neighbors to have room after filtering same-episode
|
|
||||||
k_query = min(k * 3, n - 1)
|
k_query = min(k * 3, n - 1)
|
||||||
print(f" Querying {k_query} neighbors per frame …")
|
print(f" Querying {k_query} neighbors per frame …")
|
||||||
dists, indices = tree.query(state_norm, k=k_query + 1)
|
_dists, indices = tree.query(features, k=k_query + 1)
|
||||||
|
|
||||||
# indices[:, 0] is the point itself — skip it
|
|
||||||
indices = indices[:, 1:]
|
indices = indices[:, 1:]
|
||||||
|
|
||||||
print(" Computing cross-episode action variance …")
|
print(f" Computing cross-episode action variance ({label}) …")
|
||||||
variance = np.zeros(n)
|
variance = np.zeros(n)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
ep_i = episode_ids[i]
|
ep_i = episode_ids[i]
|
||||||
@@ -343,84 +498,67 @@ def compute_consistency(
|
|||||||
# ── Visualization ───────────────────────────────────────
|
# ── Visualization ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def render(results: list[dict], out_path: Path) -> None:
|
def _style_ax(ax: plt.Axes) -> None:
|
||||||
n_ds = len(results)
|
ax.set_facecolor("#0d1117")
|
||||||
fig, axes = plt.subplots(3, n_ds, figsize=(9 * n_ds, 18), facecolor="#0d1117")
|
ax.tick_params(colors="#555", labelsize=8)
|
||||||
if n_ds == 1:
|
for spine in ax.spines.values():
|
||||||
axes = axes[:, np.newaxis]
|
spine.set_color("#333")
|
||||||
|
|
||||||
headline_parts = []
|
|
||||||
|
|
||||||
for col, r in enumerate(results):
|
|
||||||
variance = r["variance"]
|
|
||||||
episode_ids = r["episode_ids"]
|
|
||||||
tcp_xz = r["tcp_xz"]
|
|
||||||
label = r["label"]
|
|
||||||
|
|
||||||
|
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||||
|
_style_ax(ax)
|
||||||
median_var = np.median(variance)
|
median_var = np.median(variance)
|
||||||
mean_var = np.mean(variance)
|
mean_var = np.mean(variance)
|
||||||
headline_parts.append(f"{label}: median={median_var:.3f}, mean={mean_var:.3f}")
|
|
||||||
|
|
||||||
# Row 0: Histogram of per-frame action variance
|
|
||||||
ax = axes[0, col]
|
|
||||||
ax.set_facecolor("#0d1117")
|
|
||||||
nonzero = variance[variance > 0]
|
nonzero = variance[variance > 0]
|
||||||
if len(nonzero) > 0:
|
if len(nonzero) > 0:
|
||||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||||
ax.hist(nonzero, bins=bins, color="#4363d8", alpha=0.8, edgecolor="#222")
|
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||||
ax.set_xscale("log")
|
ax.set_xscale("log")
|
||||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||||
ax.set_title(f"{label}\nPer-frame action variance distribution", color="white", fontsize=12, pad=10)
|
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||||
ax.tick_params(colors="#555", labelsize=8)
|
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#333")
|
|
||||||
ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
|
||||||
|
|
||||||
# Row 1: Per-episode mean inconsistency curve (sorted)
|
|
||||||
ax = axes[1, col]
|
def _plot_episode_curves(
|
||||||
ax.set_facecolor("#0d1117")
|
ax: plt.Axes,
|
||||||
|
var_state: np.ndarray,
|
||||||
|
var_image: np.ndarray,
|
||||||
|
episode_ids: np.ndarray,
|
||||||
|
title: str,
|
||||||
|
) -> None:
|
||||||
|
_style_ax(ax)
|
||||||
unique_eps = np.unique(episode_ids)
|
unique_eps = np.unique(episode_ids)
|
||||||
ep_means = np.array([variance[episode_ids == ep].mean() for ep in unique_eps])
|
|
||||||
sorted_means = np.sort(ep_means)[::-1]
|
|
||||||
ep_x = np.arange(len(sorted_means))
|
|
||||||
|
|
||||||
p90 = np.percentile(ep_means, 90)
|
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||||
above_p90 = np.sum(ep_means > p90)
|
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||||
|
|
||||||
|
sorted_s = np.sort(ep_means_s)[::-1]
|
||||||
|
sorted_i = np.sort(ep_means_i)[::-1]
|
||||||
|
ep_x = np.arange(len(unique_eps))
|
||||||
|
|
||||||
|
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||||
|
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||||
|
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||||
|
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||||
|
|
||||||
ax.fill_between(ep_x, sorted_means, alpha=0.3, color="#4363d8")
|
|
||||||
ax.plot(ep_x, sorted_means, color="#4363d8", linewidth=1.2)
|
|
||||||
ax.axhline(
|
|
||||||
np.median(ep_means), color="#ff6600", linewidth=1.5, label=f"median={np.median(ep_means):.3f}"
|
|
||||||
)
|
|
||||||
ax.axhline(
|
|
||||||
p90, color="#ff2222", linewidth=1, linestyle=":", label=f"p90={p90:.3f} ({above_p90} eps above)"
|
|
||||||
)
|
|
||||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||||
ax.set_title(
|
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||||
f"{label}\nPer-episode inconsistency ({len(unique_eps):,} episodes)",
|
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||||
color="white",
|
|
||||||
fontsize=12,
|
|
||||||
pad=10,
|
|
||||||
)
|
|
||||||
ax.tick_params(colors="#555", labelsize=8)
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#333")
|
|
||||||
ax.legend(fontsize=9, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
|
||||||
|
|
||||||
# Row 2: Spatial heatmap (XZ side view) colored by local action variance
|
|
||||||
ax = axes[2, col]
|
def _plot_heatmap(
|
||||||
ax.set_facecolor("#0d1117")
|
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||||
|
) -> None:
|
||||||
|
_style_ax(ax)
|
||||||
order = np.argsort(variance)
|
order = np.argsort(variance)
|
||||||
pts = tcp_xz[order]
|
pts = tcp_xz[order]
|
||||||
var_sorted = variance[order]
|
var_sorted = variance[order]
|
||||||
|
|
||||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||||
|
|
||||||
sc = ax.scatter(
|
sc = ax.scatter(
|
||||||
pts[:, 0],
|
pts[:, 0],
|
||||||
pts[:, 1],
|
pts[:, 1],
|
||||||
@@ -434,22 +572,59 @@ def render(results: list[dict], out_path: Path) -> None:
|
|||||||
)
|
)
|
||||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||||
ax.set_title(
|
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||||
f"{label}\nAction variance by TCP position (XZ side)",
|
|
||||||
color="white",
|
|
||||||
fontsize=12,
|
|
||||||
pad=10,
|
|
||||||
)
|
|
||||||
ax.tick_params(colors="#555", labelsize=8)
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#333")
|
|
||||||
ax.set_aspect("equal")
|
ax.set_aspect("equal")
|
||||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||||
|
|
||||||
|
|
||||||
|
def render(results: list[dict], out_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
4-row x N-column figure:
|
||||||
|
Row 0: State-based variance histogram
|
||||||
|
Row 1: Image-based variance histogram
|
||||||
|
Row 2: Per-episode curves (both overlaid)
|
||||||
|
Row 3: Spatial heatmap (image-based variance)
|
||||||
|
"""
|
||||||
|
n_ds = len(results)
|
||||||
|
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||||
|
if n_ds == 1:
|
||||||
|
axes = axes[:, np.newaxis]
|
||||||
|
|
||||||
|
headline_parts = []
|
||||||
|
for col, r in enumerate(results):
|
||||||
|
label = r["label"]
|
||||||
|
var_s = r["var_state"]
|
||||||
|
var_i = r["var_image"]
|
||||||
|
tcp_xz = r["tcp_xz"]
|
||||||
|
episode_ids = r["episode_ids"]
|
||||||
|
|
||||||
|
med_s = np.median(var_s)
|
||||||
|
med_i = np.median(var_i)
|
||||||
|
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||||
|
|
||||||
|
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||||
|
_plot_histogram(
|
||||||
|
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||||
|
)
|
||||||
|
_plot_episode_curves(
|
||||||
|
axes[2, col],
|
||||||
|
var_s,
|
||||||
|
var_i,
|
||||||
|
episode_ids,
|
||||||
|
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||||
|
)
|
||||||
|
_plot_heatmap(
|
||||||
|
axes[3, col],
|
||||||
|
fig,
|
||||||
|
tcp_xz,
|
||||||
|
var_i,
|
||||||
|
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||||
|
)
|
||||||
|
|
||||||
fig.suptitle(
|
fig.suptitle(
|
||||||
f"Action-State Consistency Analysis (action chunk = {ACTION_CHUNK_SIZE})\n"
|
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||||
+ " | ".join(headline_parts),
|
+ " | ".join(headline_parts),
|
||||||
color="white",
|
color="white",
|
||||||
fontsize=15,
|
fontsize=15,
|
||||||
@@ -465,6 +640,8 @@ def render(results: list[dict], out_path: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"Device: {device}")
|
||||||
rng = np.random.default_rng(SEED)
|
rng = np.random.default_rng(SEED)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -474,18 +651,34 @@ def main() -> None:
|
|||||||
print(f" {label}: {repo_id}")
|
print(f" {label}: {repo_id}")
|
||||||
print(f"{'=' * 60}")
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
local = download_data(repo_id)
|
local = download_data(repo_id, CAMERA_KEY)
|
||||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||||
|
|
||||||
variance = compute_consistency(
|
# --- State-based KNN ---
|
||||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS
|
var_state = compute_consistency(
|
||||||
|
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f" Variance stats: median={np.median(variance):.4f} mean={np.mean(variance):.4f} "
|
f" State variance: median={np.median(var_state):.4f} "
|
||||||
f"p90={np.percentile(variance, 90):.4f}"
|
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute FK for spatial heatmap (left arm TCP, XZ projection)
|
# --- Image-based KNN ---
|
||||||
|
print("\n Preparing image embeddings …")
|
||||||
|
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||||
|
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||||
|
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||||
|
del frames # free memory
|
||||||
|
|
||||||
|
var_image = compute_consistency(
|
||||||
|
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Image variance: median={np.median(var_image):.4f} "
|
||||||
|
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# FK for spatial heatmap
|
||||||
print(" Computing FK for spatial heatmap …")
|
print(" Computing FK for spatial heatmap …")
|
||||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||||
left_rad = _detect_and_convert(left_raw)
|
left_rad = _detect_and_convert(left_raw)
|
||||||
@@ -495,7 +688,8 @@ def main() -> None:
|
|||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"label": label,
|
"label": label,
|
||||||
"variance": variance,
|
"var_state": var_state,
|
||||||
|
"var_image": var_image,
|
||||||
"episode_ids": data["episode_ids"],
|
"episode_ids": data["episode_ids"],
|
||||||
"tcp_xz": tcp_xz,
|
"tcp_xz": tcp_xz,
|
||||||
"n_total": data["n_total"],
|
"n_total": data["n_total"],
|
||||||
@@ -505,6 +699,17 @@ def main() -> None:
|
|||||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||||
render(results, out)
|
render(results, out)
|
||||||
|
|
||||||
|
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||||
|
worst_summary = {}
|
||||||
|
for r in results:
|
||||||
|
unique_eps = np.unique(r["episode_ids"])
|
||||||
|
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||||
|
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||||
|
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||||
|
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||||
|
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||||
|
print(f"✓ Saved worst episodes: {worst_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user