Files
lerobot/scripts/multimodal_analysis.py
T
pepijn c9cfc88602 feat: add benchmark orchestration, LIBERO-plus install parity, and eval hardening
- Add lerobot-benchmark CLI for multi-benchmark train/eval workflows
- Add benchmark_training.mdx documentation
- Add libero-plus pip extra alias with EGL probe deps matching standard libero
- Harden libero.py: wand mock, init-state fallback, renderer EGL→OSMesa fallback
- Add multimodal_analysis.py script and SLURM training template

Made-with: Cursor
2026-03-15 05:52:53 +00:00

690 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
atomic unit, tagged by the SARM progress score at its start frame. For each
progress band, compares the full vs HQ dataset on:
1. Intra-band action variance
2. Progress delta per chunk
3. GMM + BIC optimal K (number of distinct strategies)
4. PCA embedding (visual cluster inspection)
Usage:
python multimodal_analysis.py \\
--full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
--hq-dataset lerobot-data-collection/level2_final_quality3 \\
--output-dir ./chunk_analysis
"""
from __future__ import annotations
import argparse
import logging
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from scipy.stats import gaussian_kde
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Configure logging once at import so every message carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# ── Visual style ──────────────────────────────────────────────────────────
# Hex colours shared by all figures produced by this script.
BG = "#0e1117"  # figure background (also used as the saved-image facecolor)
CARD = "#1a1d27"  # axes, legend and table-cell background
BORDER = "#2a2d3a"  # spines, bar edges, legend and table-cell edges
SUB = "#8b8fa8"  # tick labels, axis labels and reference lines
TEXT = "#e8eaf0"  # titles, legend text and table text
C_FULL = "#f7934f"  # series colour for the full/mixed dataset
C_HQ = "#4dc98a"  # series colour for the HQ dataset
def _style_ax(ax: plt.Axes) -> None:
    """Apply the shared dark theme to a single axes.

    Colours every spine with ``BORDER``, the tick marks/labels with ``SUB``
    (label size 8) and the axes background with ``CARD``.
    """
    for edge in ax.spines.values():
        edge.set_color(BORDER)
    ax.tick_params(colors=SUB, labelsize=8)
    ax.set_facecolor(CARD)
def _save(fig: plt.Figure, path: Path) -> None:
    """Write *fig* to *path*, close the figure, and log where it went.

    The image is saved at 150 dpi with a tight bounding box and the ``BG``
    colour as its facecolor.
    """
    save_options = {"dpi": 150, "bbox_inches": "tight", "facecolor": BG}
    fig.savefig(path, **save_options)
    # Close explicitly so figures do not accumulate across the many plots.
    plt.close(fig)
    logger.info("Saved %s", path)
# ── Step 0: Load episodes ────────────────────────────────────────────────
def _load_sarm_progress(repo_id: str) -> pd.DataFrame | None:
    """Best-effort download of ``sarm_progress.parquet`` from a Hub dataset repo.

    Returns a frame with the columns ``episode_index``, ``frame_index`` and
    ``progress``, where ``progress`` is taken from ``progress_sparse`` when
    that column exists and from ``progress_dense`` otherwise. Any failure
    (download, parsing, missing columns) is logged as a warning and results
    in ``None``.
    """
    try:
        parquet_path = hf_hub_download(
            repo_type="dataset",
            repo_id=repo_id,
            filename="sarm_progress.parquet",
        )
        table = pd.read_parquet(parquet_path)
        # First matching candidate wins, so the sparse column takes priority.
        source_col = next(
            (name for name in ("progress_sparse", "progress_dense") if name in table.columns),
            None,
        )
        if source_col is None:
            logger.warning("sarm_progress.parquet has no progress columns — ignoring")
            return None
        logger.info("Loaded SARM progress (%s) for %s (%d rows)", source_col, repo_id, len(table))
        renamed = table.rename(columns={source_col: "progress"})
        return renamed[["episode_index", "frame_index", "progress"]]
    except Exception as exc:
        logger.warning("Could not load sarm_progress.parquet for %s: %s", repo_id, exc)
        return None
def load_episodes(
    repo_id: str,
    n_joints: int = 16,
    max_episodes: int | None = None,
) -> list[dict]:
    """Load per-episode action and progress arrays for one dataset.

    Rows are read from ``LeRobotDataset(repo_id, download_videos=False).hf_dataset``
    and grouped by ``episode_index``. If SARM progress could be loaded for the
    repo, each frame's progress is looked up by ``(episode_index, frame_index)``;
    otherwise progress is a linear ramp from 0 to 1 over the episode.

    Args:
        repo_id: Hub dataset repository id.
        n_joints: Number of leading action dimensions to keep per frame.
        max_episodes: If given, only episodes with index ``< max_episodes``
            are kept.

    Returns:
        One dict per episode, sorted by episode index, with keys
        ``"episode"`` (int), ``"actions"`` (stacked float32 array, one row per
        frame) and ``"progress"`` (float32 array, one value per frame).
    """
    dataset = LeRobotDataset(repo_id, download_videos=False)
    raw = dataset.hf_dataset
    sarm_df = _load_sarm_progress(repo_id)

    # Build per-episode {frame_index: progress} lookups from the SARM parquet.
    sarm_by_ep: dict[int, dict[int, float]] = {}
    if sarm_df is not None:
        if max_episodes is not None:
            sarm_df = sarm_df[sarm_df["episode_index"] < max_episodes]
        for ep_id, grp in sarm_df.groupby("episode_index"):
            sarm_by_ep[int(ep_id)] = dict(
                zip(grp["frame_index"].astype(int), grp["progress"].astype(float))
            )

    episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
    for row in raw:
        ep = int(row["episode_index"])
        if max_episodes is not None and ep >= max_episodes:
            continue
        action = np.array(row["action"], dtype=np.float32)[:n_joints]
        episodes[ep]["actions"].append(action)
        fi = int(row["frame_index"])
        # Frames without a SARM entry are recorded as NaN and cleaned up below.
        episodes[ep]["progress"].append(sarm_by_ep.get(ep, {}).get(fi, float("nan")))

    # Fix: the original referenced an undefined name (`sarm_lookup`) here,
    # which raised NameError on every call.
    has_sarm = bool(sarm_by_ep)

    result = []
    for ep_id, d in sorted(episodes.items()):
        actions = np.stack(d["actions"])
        n_frames = len(actions)
        if has_sarm:
            prog = np.array(d["progress"], dtype=np.float32)
            # Missing values become 0, values are clamped to [0, 1], and the
            # running maximum makes progress non-decreasing within an episode.
            prog = np.clip(np.nan_to_num(prog, nan=0.0), 0.0, 1.0)
            prog = np.maximum.accumulate(prog)
        else:
            prog = np.linspace(0.0, 1.0, n_frames, dtype=np.float32)
        result.append({"episode": ep_id, "actions": actions, "progress": prog})

    src = "SARM" if has_sarm else "time-based"
    logger.info("Progress source: %s", src)
    return result
# ── Step 1: Filter short episodes ────────────────────────────────────────
def auto_length_threshold(
    episodes_full: list[dict], episodes_hq: list[dict]
) -> int:
    """Pick a minimum-episode-length cutoff from the combined length distribution.

    A Gaussian KDE (bandwidth factor 0.25) is fitted to the lengths of all
    episodes from both datasets; the threshold is the point of lowest density
    between the shortest length and the 40th percentile.

    Args:
        episodes_full: Episodes of the full dataset; each needs an ``"actions"``
            array whose first dimension is the number of frames.
        episodes_hq: Episodes of the HQ dataset, same layout.

    Returns:
        The length threshold in frames. With no episodes this is ``0``; when a
        KDE cannot be fitted (a single episode, or all lengths identical) it is
        the shortest length, so nothing is filtered out.
    """
    all_lengths = np.array(
        [e["actions"].shape[0] for e in episodes_full + episodes_hq]
    )
    if all_lengths.size == 0:
        return 0
    shortest = all_lengths.min()
    # gaussian_kde needs at least two points with non-zero spread; otherwise it
    # raises, so fall back to a threshold that keeps every episode.
    if all_lengths.size < 2 or shortest == all_lengths.max():
        return int(shortest)
    kde = gaussian_kde(all_lengths, bw_method=0.25)
    xs = np.linspace(shortest, np.percentile(all_lengths, 40), 300)
    return int(xs[np.argmin(kde(xs))])
def plot_length_distribution(
    episodes_full: list[dict],
    episodes_hq: list[dict],
    threshold: int,
    out_path: Path,
) -> None:
    """Plot overlaid episode-length histograms with a KDE curve and the cutoff.

    Both datasets share 49 equal-width bins spanning the combined length
    range. The KDE of the combined lengths is scaled to histogram counts and
    drawn on top, and ``threshold`` is marked with a dashed vertical line.

    Args:
        episodes_full: Episodes of the full dataset (each with an ``"actions"``
            array whose first dimension is the frame count).
        episodes_hq: Episodes of the HQ dataset, same layout.
        threshold: Length cutoff to mark on the plot.
        out_path: Destination image path.
    """
    lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
    lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
    all_lens = np.concatenate([lens_full, lens_hq])

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    # Bin range: the observed span, widened when it is empty or zero-width so
    # the bin edges stay strictly increasing.
    if all_lens.size:
        lo, hi = float(all_lens.min()), float(all_lens.max())
    else:
        lo, hi = 0.0, 1.0
    has_spread = lo < hi
    if not has_spread:
        lo, hi = lo - 0.5, hi + 0.5
    bins = np.linspace(lo, hi, 50)

    ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
    ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")

    # gaussian_kde cannot be fitted to fewer than two distinct values, so the
    # overlay is skipped in that case instead of crashing.
    if all_lens.size >= 2 and has_spread:
        xs = np.linspace(all_lens.min(), all_lens.max(), 300)
        kde = gaussian_kde(all_lens, bw_method=0.25)
        # density * N * bin_width converts the density to expected bin counts.
        ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)")

    ax.axvline(threshold, color="#ff4b4b", ls="--", lw=1.5, label=f"Threshold = {threshold}")
    ax.set_xlabel("Episode length (frames)", color=SUB)
    ax.set_ylabel("Count", color=SUB)
    ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
    _save(fig, out_path)
def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]:
    """Return the episodes that have at least ``min_length`` frames.

    The frame count is the first dimension of each episode's ``"actions"``
    array. The number of kept episodes is logged.
    """
    survivors: list[dict] = []
    for episode in episodes:
        n_frames = episode["actions"].shape[0]
        if n_frames >= min_length:
            survivors.append(episode)
    logger.info("Kept %d / %d episodes (min_length=%d)", len(survivors), len(episodes), min_length)
    return survivors
# ── Step 2: Extract chunks ───────────────────────────────────────────────
def extract_chunks(
    episodes: list[dict],
    chunk_size: int = 30,
    chunk_stride: int = 15,
) -> list[dict]:
    """Slide a fixed-size window over every episode and summarise each window.

    Args:
        episodes: Episode dicts with ``"actions"`` (one row per frame),
            ``"progress"`` (one value per frame) and ``"episode"``.
        chunk_size: Number of consecutive frames per chunk.
        chunk_stride: Step between consecutive chunk start frames.

    Returns:
        One dict per chunk with keys ``"action_mean"`` (per-dimension mean over
        the chunk), ``"action_flat"`` (the chunk flattened to 1-D),
        ``"progress_start"`` (progress at the first frame),
        ``"progress_delta"`` (progress at frame ``t + chunk_size``, clamped to
        the last frame, minus ``progress_start``) and ``"episode"``.

    Raises:
        ValueError: If ``chunk_size`` or ``chunk_stride`` is not positive.
    """
    if chunk_size <= 0 or chunk_stride <= 0:
        raise ValueError(
            f"chunk_size and chunk_stride must be positive, got {chunk_size} and {chunk_stride}"
        )

    chunks = []
    for ep in episodes:
        actions = ep["actions"]
        n_frames = len(actions)
        prog = ep["progress"]
        last_frame = n_frames - 1
        # "+ 1" keeps the final complete window starting at n_frames - chunk_size
        # (the original range stopped one start position short and dropped it).
        for t in range(0, n_frames - chunk_size + 1, chunk_stride):
            chunk = actions[t : t + chunk_size]
            p_start = float(prog[t])
            # For the final window t + chunk_size == n_frames, hence the clamp.
            p_end = float(prog[min(t + chunk_size, last_frame)])
            chunks.append({
                "action_mean": chunk.mean(axis=0).astype(np.float32),
                "action_flat": chunk.flatten().astype(np.float32),
                "progress_start": p_start,
                "progress_delta": p_end - p_start,
                "episode": ep["episode"],
            })
    return chunks
# ── Step 3: Adaptive progress bands ─────────────────────────────────────
def make_bands(n_bands: int = 5) -> list[tuple[float, float]]:
    """Split the progress range [0, 1] into ``n_bands`` equal-width ``(lo, hi)`` bands."""
    boundaries = np.linspace(0.0, 1.0, n_bands + 1).tolist()
    # Pair each boundary with its successor: (b0, b1), (b1, b2), ...
    return list(zip(boundaries[:-1], boundaries[1:]))
def assign_bands(
    chunks: list[dict], band_edges: list[tuple[float, float]]
) -> list[dict]:
    """Tag every chunk, in place, with the index of its progress band.

    A chunk goes into the first band whose upper edge is strictly greater than
    its ``"progress_start"``; if no band qualifies it goes into the last band.
    The same list object is returned for convenience.
    """
    fallback_band = len(band_edges) - 1
    for chunk in chunks:
        start = chunk["progress_start"]
        band_idx = fallback_band
        for idx, (_, upper) in enumerate(band_edges):
            if start < upper:
                band_idx = idx
                break
        chunk["band"] = band_idx
    return chunks
def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]:
    """Group chunks by their ``"band"`` index.

    Every band index in ``range(n_bands)`` is present in the result, with an
    empty list when no chunk fell into it.
    """
    grouped: dict[int, list[dict]] = {}
    for band_idx in range(n_bands):
        grouped[band_idx] = []
    for chunk in chunks:
        grouped[chunk["band"]].append(chunk)
    return grouped
# ── Step 4: Intra-band action variance ──────────────────────────────────
def band_variance_matrix(
    bands: dict[int, list[dict]], n_bands: int, n_joints: int
) -> np.ndarray:
    """Compute, per band, the variance of chunk-mean actions for every joint.

    Returns an ``(n_bands, n_joints)`` array. Bands holding fewer than three
    chunks are left as NaN.
    """
    matrix = np.empty((n_bands, n_joints))
    matrix.fill(np.nan)
    for band_idx in bands:
        members = bands[band_idx]
        if len(members) >= 3:
            stacked_means = np.stack([member["action_mean"] for member in members])
            matrix[band_idx] = stacked_means.var(axis=0)
    return matrix
def plot_variance_heatmap(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Plot per-band, per-joint variance heatmaps and a Full/HQ ratio bar chart.

    Panels A and B show ``var_full`` and ``var_hq`` on a shared colour scale.
    Panel C shows, per band, the joint-averaged variance of the full dataset
    divided by that of the HQ dataset.

    Args:
        var_full: 2-D variance matrix (bands x joints) for the full dataset;
            entries may be NaN.
        var_hq: Variance matrix for the HQ dataset (expected to have the same
            shape as ``var_full``).
        band_edges: ``(lo, hi)`` progress bounds per band, used for tick labels.
        out_path: Destination image path.
    """
    n_bands, n_joints = var_full.shape[0], var_full.shape[1]
    vmin = 0.0
    # Shared colour-scale maximum over all non-NaN entries. Pooling the values
    # avoids the NaN that max(nanmax(a), nanmax(b)) yields when one matrix is
    # entirely NaN; with no data at all fall back to 1.0.
    observed = np.concatenate([var_full[~np.isnan(var_full)], var_hq[~np.isnan(var_hq)]])
    vmax = float(observed.max()) if observed.size else 1.0

    # An en dash separates the bounds so labels read "0%–20%", not "0%20%".
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    joint_labels = [f"J{j}" for j in range(n_joints)]

    fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98)

    for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]):
        ax = axes[ax_idx]
        _style_ax(ax)
        im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax)
        ax.set_yticks(range(n_bands))
        ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
        ax.set_xticks(range(n_joints))
        ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
        ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11)
        fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)

    with np.errstate(invalid="ignore"):
        mean_full = np.nanmean(var_full, axis=1)
        mean_hq = np.nanmean(var_hq, axis=1)
        # The small epsilon guards against division by a zero HQ variance.
        ratio = np.where(
            np.isnan(mean_full) | np.isnan(mean_hq),
            np.nan,
            mean_full / (mean_hq + 1e-8),
        )

    ax_bar = axes[2]
    _style_ax(ax_bar)
    # Red above 2x, amber above 1.2x, otherwise the HQ colour (NaN falls here).
    colors = [
        "#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ
        for r in ratio
    ]
    ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER)
    ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
    ax_bar.set_xticks(range(n_bands))
    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
    ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
    ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    _save(fig, out_path)
# ── Step 5: Progress delta per band ──────────────────────────────────────
def plot_progress_delta(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Plot how much progress a chunk makes, per band and overall.

    The left panel shows the mean and standard deviation of
    ``"progress_delta"`` per band for both datasets as grouped bars (empty
    bands are drawn as 0). The right panel is a violin plot of all deltas per
    dataset and is only drawn when both datasets have at least one chunk.

    Args:
        bands_full: Chunks of the full dataset keyed by band index.
        bands_hq: Chunks of the HQ dataset keyed by band index.
        band_edges: ``(lo, hi)`` progress bounds per band, used for tick labels.
        out_path: Destination image path.
    """
    n_bands = len(band_edges)
    # An en dash separates the bounds so labels read "0%–20%", not "0%20%".
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    x = np.arange(n_bands)
    w = 0.35

    means_full, stds_full = [], []
    means_hq, stds_hq = [], []
    all_deltas_full, all_deltas_hq = [], []
    for b in range(n_bands):
        deltas_f = np.array([c["progress_delta"] for c in bands_full.get(b, [])])
        deltas_h = np.array([c["progress_delta"] for c in bands_hq.get(b, [])])
        means_full.append(np.mean(deltas_f) if len(deltas_f) > 0 else 0)
        stds_full.append(np.std(deltas_f) if len(deltas_f) > 0 else 0)
        means_hq.append(np.mean(deltas_h) if len(deltas_h) > 0 else 0)
        stds_hq.append(np.std(deltas_h) if len(deltas_h) > 0 else 0)
        all_deltas_full.extend(deltas_f.tolist())
        all_deltas_hq.extend(deltas_h.tolist())

    fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)

    _style_ax(ax_bar)
    ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER,
               capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB})
    ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER,
               capsize=3, label="HQ", error_kw={"ecolor": SUB})
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax_bar.set_ylabel("Mean progress Δ", color=SUB)
    ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)

    _style_ax(ax_viol)
    data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)]
    # The violin plot is skipped entirely when either dataset has no chunks.
    if all(len(d) > 0 for d in data_viol):
        parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True)
        for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]):
            pc.set_facecolor(c)
            pc.set_alpha(0.7)
        for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
            if key in parts:
                parts[key].set_color(SUB)
    ax_viol.set_xticks([0, 1])
    ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
    ax_viol.set_ylabel("Progress Δ", color=SUB)
    ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)

    fig.tight_layout()
    _save(fig, out_path)
# ── Step 6: GMM + BIC per band ──────────────────────────────────────────
def gmm_optimal_k(
band_chunks: list[dict],
pca_components: int = 15,
max_k: int = 12,
seed: int = 42,
) -> int | None:
if len(band_chunks) < 20:
return None
X = np.stack([c["action_flat"] for c in band_chunks])
X = StandardScaler().fit_transform(X)
n = min(pca_components, X.shape[1], X.shape[0] - 1)
X_r = PCA(n_components=n, random_state=seed).fit_transform(X)
bics = []
for k in range(1, min(max_k + 1, len(X_r) // 6)):
gmm = GaussianMixture(
n_components=k, covariance_type="full",
n_init=5, max_iter=300, random_state=seed,
)
gmm.fit(X_r)
bics.append((k, gmm.bic(X_r)))
if not bics:
return None
return min(bics, key=lambda x: x[1])[0]
def plot_gmm_bic(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    seed: int,
    out_path: Path,
) -> tuple[list[int | None], list[int | None]]:
    """Compute and plot the BIC-optimal GMM component count per band.

    Args:
        bands_full: Chunks of the full dataset keyed by band index.
        bands_hq: Chunks of the HQ dataset keyed by band index.
        band_edges: ``(lo, hi)`` progress bounds per band, used for tick labels.
        seed: Random seed forwarded to ``gmm_optimal_k``.
        out_path: Destination image path.

    Returns:
        ``(ks_full, ks_hq)``: one optimal K per band for each dataset, with
        ``None`` where ``gmm_optimal_k`` returned no estimate.
    """
    n_bands = len(band_edges)
    ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
    ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]
    # An en dash separates the bounds so labels read "0%–20%", not "0%20%".
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    # Band index -> K, only for bands where an estimate exists.
    k_full_by_band = {i: k for i, k in enumerate(ks_full) if k is not None}
    k_hq_by_band = {i: k for i, k in enumerate(ks_hq) if k is not None}

    if k_full_by_band:
        ax.plot(list(k_full_by_band), list(k_full_by_band.values()), "o-",
                color=C_FULL, label="Full/Mixed", lw=2, markersize=7)
    if k_hq_by_band:
        ax.plot(list(k_hq_by_band), list(k_hq_by_band.values()), "o-",
                color=C_HQ, label="HQ", lw=2, markersize=7)

    # Shade the gap between the curves over bands where both have an estimate.
    shared_x = sorted(k_full_by_band.keys() & k_hq_by_band.keys())
    if len(shared_x) >= 2:
        yf = [k_full_by_band[i] for i in shared_x]
        yh = [k_hq_by_band[i] for i in shared_x]
        ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT)

    ax.set_xticks(np.arange(n_bands))
    ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
    ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))

    fig.tight_layout()
    _save(fig, out_path)
    return ks_full, ks_hq
# ── Step 7: PCA scatter per band ────────────────────────────────────────
def plot_pca_scatter(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Scatter-plot a shared 2-D PCA projection of action chunks per band.

    Only the first four bands (or fewer, if there are fewer) are drawn, one
    column each: the full dataset in the top row and the HQ dataset in the
    bottom row. Within a band both datasets are standardised and projected
    together, so the two rows share the same axes. Bands where either dataset
    has fewer than three chunks show a "Too few chunks" note instead.

    Args:
        bands_full: Chunks of the full dataset keyed by band index.
        bands_hq: Chunks of the HQ dataset keyed by band index.
        band_edges: ``(lo, hi)`` progress bounds per band, used for titles.
        out_path: Destination image path.
    """
    n_plot = min(4, len(band_edges))
    fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
    fig.patch.set_facecolor(BG)
    fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)
    if n_plot == 1:
        # With a single column subplots() returns a 1-D array; restore 2-D indexing.
        axes = axes.reshape(2, 1)

    for col in range(n_plot):
        cf = bands_full.get(col, [])
        ch = bands_hq.get(col, [])
        lo, hi = band_edges[col]

        # Fit the scaler and PCA once per band; the original refitted the
        # identical projection for each of the two rows.
        enough = len(cf) >= 3 and len(ch) >= 3
        projected = None
        if enough:
            X_full_b = np.stack([c["action_flat"] for c in cf])
            X_hq_b = np.stack([c["action_flat"] for c in ch])
            X_all = StandardScaler().fit_transform(np.vstack([X_full_b, X_hq_b]))
            X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all)
            # Rows were stacked full-first, so split at len(cf).
            projected = (X_2d[: len(cf)], X_2d[len(cf) :])

        for row, (color, label) in enumerate([(C_FULL, "Full/Mixed"), (C_HQ, "HQ")]):
            ax = axes[row, col]
            _style_ax(ax)
            if row == 0:
                # An en dash separates the bounds: "0%–20%", not "0%20%".
                ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
            if col == 0:
                ax.set_ylabel(label, color=SUB, fontsize=9)
            if projected is None:
                ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes,
                        ha="center", va="center", color=SUB, fontsize=9)
                continue
            pts = projected[row]
            ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none")

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    _save(fig, out_path)
# ── Plot 1: Chunk counts per band ───────────────────────────────────────
def plot_chunk_counts(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Plot the number of chunks per progress band as grouped bars.

    Args:
        bands_full: Chunks of the full dataset keyed by band index.
        bands_hq: Chunks of the HQ dataset keyed by band index.
        band_edges: ``(lo, hi)`` progress bounds per band, used for tick labels.
        out_path: Destination image path.
    """
    n_bands = len(band_edges)
    # An en dash separates the bounds so labels read "0%–20%", not "0%20%".
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    x = np.arange(n_bands)
    w = 0.35

    # Bands missing from a mapping count as empty.
    counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)]
    counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)
    ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed")
    ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ")
    ax.set_xticks(x)
    ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Chunk count", color=SUB)
    ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)

    fig.tight_layout()
    _save(fig, out_path)
# ── Summary figure ───────────────────────────────────────────────────────
def plot_summary(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    ks_full: list[int | None],
    ks_hq: list[int | None],
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    out_path: Path,
) -> None:
    """Render a table of headline statistics and log the same rows.

    Reported metrics: mean and peak Full/HQ variance ratio across bands, the
    mean BIC-optimal GMM K per dataset, and the mean progress delta per
    dataset.

    Args:
        var_full: Per-band, per-joint variance matrix of the full dataset.
        var_hq: Per-band, per-joint variance matrix of the HQ dataset.
        band_edges: ``(lo, hi)`` progress bounds per band.
        ks_full: Optimal K per band for the full dataset (``None`` if unknown).
        ks_hq: Optimal K per band for the HQ dataset (``None`` if unknown).
        bands_full: Chunks of the full dataset keyed by band index.
        bands_hq: Chunks of the HQ dataset keyed by band index.
        out_path: Destination image path.
    """
    with np.errstate(invalid="ignore"):
        mean_full = np.nanmean(var_full, axis=1)
        mean_hq = np.nanmean(var_hq, axis=1)
        # The small epsilon guards against division by a zero HQ variance.
        ratio = np.where(
            np.isnan(mean_full) | np.isnan(mean_hq),
            np.nan,
            mean_full / (mean_hq + 1e-8),
        )

    # Keep the ORIGINAL band indices of the valid ratios. The previous code
    # took argmax over the NaN-filtered array and used that position to index
    # band_edges, which names the wrong band whenever an earlier band is NaN.
    valid_bands = np.where(~np.isnan(ratio))[0]
    if len(valid_bands) > 0:
        valid_ratio = ratio[valid_bands]
        mean_ratio = float(np.mean(valid_ratio))
        peak_idx = int(valid_bands[int(np.argmax(valid_ratio))])
        peak_ratio = float(ratio[peak_idx])
        lo, hi = band_edges[peak_idx]
        # An en dash separates the bounds: "0%–20%", not "0%20%".
        peak_text = f"{peak_ratio:.2f}x at {lo:.0%}–{hi:.0%}"
    else:
        # No band has a ratio, so there is no peak band to name.
        mean_ratio = float("nan")
        peak_text = "n/a"

    valid_kf = [k for k in ks_full if k is not None]
    valid_kh = [k for k in ks_hq if k is not None]
    mean_k_full = np.mean(valid_kf) if valid_kf else float("nan")
    mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan")

    n_bands = len(band_edges)
    deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])]
    deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])]
    mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan")
    mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan")

    rows = [
        ("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"),
        ("Peak variance ratio", peak_text),
        ("Mean GMM K — Full", f"{mean_k_full:.1f}"),
        ("Mean GMM K — HQ", f"{mean_k_hq:.1f}"),
        ("Mean progress Δ — Full", f"{mean_delta_full:.4f}"),
        ("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"),
    ]

    fig, ax = plt.subplots(figsize=(8, 3))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(CARD)
    ax.axis("off")
    table = ax.table(
        cellText=[[m, v] for m, v in rows],
        colLabels=["Metric", "Value"],
        loc="center",
        cellLoc="left",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    for key, cell in table.get_celld().items():
        cell.set_edgecolor(BORDER)
        cell.set_facecolor(CARD)
        cell.set_text_props(color=TEXT)
        if key[0] == 0:
            # Row 0 holds the column labels; render them bold.
            cell.set_text_props(color=TEXT, fontweight="bold")
    table.scale(1, 1.6)
    ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)

    fig.tight_layout()
    _save(fig, out_path)

    for metric, value in rows:
        logger.info(" %s: %s", metric, value)
# ── Main ─────────────────────────────────────────────────────────────────
def main(args: argparse.Namespace) -> None:
    """Run the whole analysis pipeline and write every figure to ``args.output_dir``.

    Steps: load both datasets, filter out short episodes, cut the remaining
    episodes into chunks, bucket the chunks into equal-width progress bands,
    then produce the chunk-count, variance, progress-delta, GMM-BIC, PCA and
    summary figures.

    Args:
        args: Parsed command-line options. Reads ``output_dir``,
            ``full_dataset``, ``hq_dataset``, ``n_joints``, ``max_episodes``,
            ``min_episode_length``, ``chunk_size``, ``chunk_stride``,
            ``n_bands`` and ``seed``.
    """
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    logger.info("Loading FULL dataset: %s", args.full_dataset)
    episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
    logger.info("Loading HQ dataset: %s", args.hq_dataset)
    episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
    logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))

    # Step 1: length threshold + filter (an explicit override beats auto-detection)
    if args.min_episode_length is not None:
        threshold = args.min_episode_length
    else:
        threshold = auto_length_threshold(episodes_full, episodes_hq)
    logger.info("Episode length threshold: %d", threshold)
    # The distribution is plotted before filtering so short episodes are visible.
    plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png")
    episodes_full = filter_episodes(episodes_full, threshold)
    episodes_hq = filter_episodes(episodes_hq, threshold)

    # Step 2: extract chunks
    chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
    chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
    logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))

    # Step 3: fixed equal-width bands over episode-relative progress
    band_edges = make_bands(args.n_bands)
    n_bands = len(band_edges)
    # An en dash separates the bounds so the log reads "0%–20%", not "0%20%".
    logger.info("Progress bands (%d): %s", n_bands,
                [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges])
    chunks_full = assign_bands(chunks_full, band_edges)
    chunks_hq = assign_bands(chunks_hq, band_edges)
    bands_full = split_by_band(chunks_full, n_bands)
    bands_hq = split_by_band(chunks_hq, n_bands)

    # Plot 1: chunk counts
    plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png")

    # Step 4: variance heatmap
    var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
    var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
    plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png")

    # Step 5: progress delta
    plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png")

    # Step 6: GMM BIC
    ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png")

    # Step 7: PCA scatter
    plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png")

    # Summary
    plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq,
                 bands_full, bands_hq, out / "6_summary.png")
    logger.info("All figures saved to %s", out)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run the pipeline.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
    )
    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = [
        ("--full-dataset", {"default": "lerobot-data-collection/level12_rac_2_2026-02-08_1"}),
        ("--hq-dataset", {"default": "lerobot-data-collection/level2_final_quality3_trim_0_hil_data"}),
        ("--output-dir", {"default": "./chunk_analysis"}),
        ("--chunk-size", {"type": int, "default": 30}),
        ("--chunk-stride", {"type": int, "default": 15}),
        ("--n-bands", {"type": int, "default": 5, "help": "Number of equal-width progress bands"}),
        ("--max-episodes", {"type": int, "default": 500}),
        ("--n-joints", {"type": int, "default": 16}),
        (
            "--min-episode-length",
            {"type": int, "default": None, "help": "Override auto-detected length filter threshold"},
        ),
        ("--seed", {"type": int, "default": 42}),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)