From 06385902df51b940b4f54cc1bbc9b23a1929c55f Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 13 Mar 2026 09:28:26 -0700 Subject: [PATCH] Add create reward visualization and multimodal analysis tool --- .../chunk_multimodality_analysis.py | 659 ++++++++++++++++++ .../create_progress_videos.py | 471 +++++++++++++ 2 files changed, 1130 insertions(+) create mode 100644 examples/dataset/visualization_tools/chunk_multimodality_analysis.py create mode 100644 examples/dataset/visualization_tools/create_progress_videos.py diff --git a/examples/dataset/visualization_tools/chunk_multimodality_analysis.py b/examples/dataset/visualization_tools/chunk_multimodality_analysis.py new file mode 100644 index 000000000..837fad5a8 --- /dev/null +++ b/examples/dataset/visualization_tools/chunk_multimodality_analysis.py @@ -0,0 +1,659 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets. + +Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the +atomic unit, tagged by the SARM progress score at its start frame. For each +progress band, compares the full vs HQ dataset on: + + 1. Intra-band action variance + 2. Progress delta per chunk + 3. GMM + BIC optimal K (number of distinct strategies) + 4. 
PCA embedding (visual cluster inspection) + +Usage: + python chunk_multimodality_analysis.py \\ + --full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\ + --hq-dataset lerobot-data-collection/level2_final_quality3 \\ + --output-dir ./chunk_analysis +""" + +from __future__ import annotations + +import argparse +import logging +from collections import defaultdict +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import gaussian_kde +from sklearn.decomposition import PCA +from sklearn.mixture import GaussianMixture +from sklearn.preprocessing import StandardScaler + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +# ── Visual style ────────────────────────────────────────────────────────── + +BG = "#0e1117" +CARD = "#1a1d27" +BORDER = "#2a2d3a" +SUB = "#8b8fa8" +TEXT = "#e8eaf0" +C_FULL = "#f7934f" +C_HQ = "#4dc98a" + + +def _style_ax(ax: plt.Axes) -> None: + ax.set_facecolor(CARD) + ax.tick_params(colors=SUB, labelsize=8) + for spine in ax.spines.values(): + spine.set_color(BORDER) + + +def _save(fig: plt.Figure, path: Path) -> None: + fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG) + plt.close(fig) + logger.info("Saved %s", path) + + +# ── Step 0: Load episodes ──────────────────────────────────────────────── + +def load_episodes( + repo_id: str, + n_joints: int = 16, + max_episodes: int | None = None, +) -> list[dict]: + dataset = LeRobotDataset(repo_id, download_videos=False) + raw = dataset.hf_dataset + episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []}) + + for row in raw: + ep = int(row["episode_index"]) + if max_episodes is not None and ep >= max_episodes: + continue + action = np.array(row["action"], dtype=np.float32)[:n_joints] + episodes[ep]["actions"].append(action) + progress = float(row.get("next.reward", float("nan"))) + episodes[ep]["progress"].append(progress) + + result = [] + for ep_id, d in sorted(episodes.items()): + result.append({ + "episode": ep_id, + "actions": np.stack(d["actions"]), + "progress": np.array(d["progress"]), + }) + return result + + +# ── Step 1: Filter short episodes ──────────────────────────────────────── + +def auto_length_threshold( + episodes_full: list[dict], episodes_hq: list[dict] +) -> int: + all_lengths = np.array( + [e["actions"].shape[0] for e in episodes_full + episodes_hq] + ) + kde = gaussian_kde(all_lengths, bw_method=0.25) + xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300) + return int(xs[np.argmin(kde(xs))]) + + +def plot_length_distribution( + episodes_full: list[dict], + episodes_hq: list[dict], + threshold: int, + out_path: Path, +) -> None: + lens_full = np.array([e["actions"].shape[0] for e in episodes_full]) + lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq]) + all_lens = np.concatenate([lens_full, lens_hq]) + + fig, ax = plt.subplots(figsize=(10, 5)) + fig.patch.set_facecolor(BG) + _style_ax(ax) + + bins = np.linspace(all_lens.min(), all_lens.max(), 50) + ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed") + ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ") + + xs = np.linspace(all_lens.min(), all_lens.max(), 300) + kde = gaussian_kde(all_lens, bw_method=0.25) + ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)") + + ax.axvline(threshold, color="#ff4b4b", 
ls="--", lw=1.5, label=f"Threshold = {threshold}") + ax.set_xlabel("Episode length (frames)", color=SUB) + ax.set_ylabel("Count", color=SUB) + ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13) + ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8) + _save(fig, out_path) + + +def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]: + kept = [e for e in episodes if e["actions"].shape[0] >= min_length] + logger.info("Kept %d / %d episodes (min_length=%d)", len(kept), len(episodes), min_length) + return kept + + +# ── Step 2: Extract chunks ─────────────────────────────────────────────── + +def extract_chunks( + episodes: list[dict], + chunk_size: int = 30, + chunk_stride: int = 15, +) -> list[dict]: + chunks = [] + for ep in episodes: + actions = ep["actions"] + T = len(actions) + + prog = np.clip(np.nan_to_num(ep["progress"], nan=0.0), 0.0, 1.0) + prog = np.maximum.accumulate(prog) + + for t in range(0, T - chunk_size, chunk_stride): + chunk = actions[t : t + chunk_size] + p_start = float(prog[t]) + p_end = float(prog[min(t + chunk_size, T - 1)]) + + chunks.append({ + "action_mean": chunk.mean(axis=0).astype(np.float32), + "action_flat": chunk.flatten().astype(np.float32), + "progress_start": p_start, + "progress_delta": p_end - p_start, + "episode": ep["episode"], + }) + return chunks + + +# ── Step 3: Adaptive progress bands ───────────────────────────────────── + +def fit_adaptive_bands( + chunks: list[dict], min_per_band: int = 20 +) -> list[tuple[float, float]]: + prog_vals = np.array([c["progress_start"] for c in chunks]) + fine_edges = np.linspace(0.0, 1.0, 11) + + band_edges: list[tuple[float, float]] = [] + i = 0 + while i < len(fine_edges) - 1: + lo, hi = fine_edges[i], fine_edges[i + 1] + j = i + 1 + while ( + np.sum((prog_vals >= lo) & (prog_vals < hi)) < min_per_band + and j < len(fine_edges) - 1 + ): + j += 1 + hi = fine_edges[j] + band_edges.append((lo, hi)) + i = j + return band_edges + + +def assign_bands( + chunks: list[dict], band_edges: list[tuple[float, float]] +) -> list[dict]: + n = len(band_edges) + for c in chunks: + p = c["progress_start"] + c["band"] = next( + (bi for bi, (lo, hi) in enumerate(band_edges) if p < hi), + n - 1, + ) + return chunks + + +def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]: + out: dict[int, list[dict]] = {b: [] for b in range(n_bands)} + for c in chunks: + out[c["band"]].append(c) + return out + + +# ── Step 4: Intra-band action variance ────────────────────────────────── + +def band_variance_matrix( + bands: dict[int, list[dict]], n_bands: int, n_joints: int +) -> np.ndarray: + var_mat = np.full((n_bands, n_joints), np.nan) + for b, clist in bands.items(): + if len(clist) < 3: + continue + means = np.stack([c["action_mean"] for c in clist]) + var_mat[b] = np.var(means, axis=0) + return var_mat + + +def plot_variance_heatmap( + var_full: np.ndarray, + var_hq: np.ndarray, + band_edges: list[tuple[float, float]], + out_path: Path, +) -> None: + n_bands = var_full.shape[0] + vmin = 0.0 + vmax = max(np.nanmax(var_full), np.nanmax(var_hq)) + + band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] + joint_labels = [f"J{j}" for j in range(var_full.shape[1])] + + fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]}) + fig.patch.set_facecolor(BG) + fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98) + + for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]): + ax = 
axes[ax_idx] + _style_ax(ax) + im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax) + ax.set_yticks(range(n_bands)) + ax.set_yticklabels(band_labels, fontsize=7, color=SUB) + ax.set_xticks(range(var_full.shape[1])) + ax.set_xticklabels(joint_labels, fontsize=7, color=SUB) + ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11) + fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02) + + ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8) + ax_bar = axes[2] + _style_ax(ax_bar) + colors = [ + "#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ + for r in ratio + ] + ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER) + ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8) + ax_bar.set_xticks(range(n_bands)) + ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB) + ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9) + ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11) + + fig.tight_layout(rect=[0, 0, 1, 0.96]) + _save(fig, out_path) + + +# ── Step 5: Progress delta per band ────────────────────────────────────── + +def plot_progress_delta( + bands_full: dict[int, list[dict]], + bands_hq: dict[int, list[dict]], + band_edges: list[tuple[float, float]], + out_path: Path, +) -> None: + n_bands = len(band_edges) + band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] + x = np.arange(n_bands) + w = 0.35 + + means_full, stds_full = [], [] + means_hq, stds_hq = [], [] + all_deltas_full, all_deltas_hq = [], [] + + for b in range(n_bands): + df = np.array([c["progress_delta"] for c in bands_full.get(b, [])]) + dh = np.array([c["progress_delta"] for c in bands_hq.get(b, [])]) + means_full.append(np.mean(df) if len(df) > 0 else 0) + stds_full.append(np.std(df) if len(df) > 0 else 0) + means_hq.append(np.mean(dh) if len(dh) > 0 else 0) + stds_hq.append(np.std(dh) if len(dh) > 0 else 0) + all_deltas_full.extend(df.tolist()) + all_deltas_hq.extend(dh.tolist()) + + fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]}) + fig.patch.set_facecolor(BG) + fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14) + + _style_ax(ax_bar) + ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER, + capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB}) + ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER, + capsize=3, label="HQ", error_kw={"ecolor": SUB}) + ax_bar.set_xticks(x) + ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30) + ax_bar.set_ylabel("Mean progress Δ", color=SUB) + ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8) + + _style_ax(ax_viol) + data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)] + if all(len(d) > 0 for d in data_viol): + parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True) + for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]): + pc.set_facecolor(c) + pc.set_alpha(0.7) + for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"): + if key in parts: + parts[key].set_color(SUB) + ax_viol.set_xticks([0, 1]) + ax_viol.set_xticklabels(["Full", "HQ"], color=SUB) + ax_viol.set_ylabel("Progress Δ", color=SUB) + ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10) + + fig.tight_layout() + _save(fig, out_path) + + +# ── Step 6: GMM + BIC per band ────────────────────────────────────────── + +def gmm_optimal_k( + band_chunks: list[dict], + 
pca_components: int = 15, + max_k: int = 7, + seed: int = 42, +) -> int | None: + if len(band_chunks) < 20: + return None + X = np.stack([c["action_flat"] for c in band_chunks]) + X = StandardScaler().fit_transform(X) + n = min(pca_components, X.shape[1], X.shape[0] - 1) + X_r = PCA(n_components=n, random_state=seed).fit_transform(X) + bics = [] + for k in range(1, min(max_k + 1, len(X_r) // 6)): + gmm = GaussianMixture( + n_components=k, covariance_type="full", + n_init=5, max_iter=300, random_state=seed, + ) + gmm.fit(X_r) + bics.append((k, gmm.bic(X_r))) + if not bics: + return None + return min(bics, key=lambda x: x[1])[0] + + +def plot_gmm_bic( + bands_full: dict[int, list[dict]], + bands_hq: dict[int, list[dict]], + band_edges: list[tuple[float, float]], + seed: int, + out_path: Path, +) -> tuple[list[int | None], list[int | None]]: + n_bands = len(band_edges) + ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)] + ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)] + + band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] + + fig, ax = plt.subplots(figsize=(10, 5)) + fig.patch.set_facecolor(BG) + _style_ax(ax) + + xs = np.arange(n_bands) + valid_full = [(i, k) for i, k in enumerate(ks_full) if k is not None] + valid_hq = [(i, k) for i, k in enumerate(ks_hq) if k is not None] + + if valid_full: + xi, yi = zip(*valid_full) + ax.plot(xi, yi, "o-", color=C_FULL, label="Full/Mixed", lw=2, markersize=7) + if valid_hq: + xi, yi = zip(*valid_hq) + ax.plot(xi, yi, "o-", color=C_HQ, label="HQ", lw=2, markersize=7) + + if valid_full and valid_hq: + all_x = sorted(set([i for i, _ in valid_full]) & set([i for i, _ in valid_hq])) + if len(all_x) >= 2: + kf_interp = {i: k for i, k in valid_full} + kh_interp = {i: k for i, k in valid_hq} + shared_x = [i for i in all_x if i in kf_interp and i in kh_interp] + yf = [kf_interp[i] for i in shared_x] + yh = [kh_interp[i] for i in shared_x] + ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT) + + ax.set_xticks(xs) + ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30) + ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB) + ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13) + ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9) + ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True)) + fig.tight_layout() + _save(fig, out_path) + return ks_full, ks_hq + + +# ── Step 7: PCA scatter per band ──────────────────────────────────────── + +def plot_pca_scatter( + bands_full: dict[int, list[dict]], + bands_hq: dict[int, list[dict]], + band_edges: list[tuple[float, float]], + out_path: Path, +) -> None: + n_plot = min(4, len(band_edges)) + fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7)) + fig.patch.set_facecolor(BG) + fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14) + + if n_plot == 1: + axes = axes.reshape(2, 1) + + for col, b in enumerate(range(n_plot)): + cf = bands_full.get(b, []) + ch = bands_hq.get(b, []) + lo, hi = band_edges[b] + + for row, (clist, color, label) in enumerate([ + (cf, C_FULL, "Full/Mixed"), (ch, C_HQ, "HQ") + ]): + ax = axes[row, col] + _style_ax(ax) + if row == 0: + ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10) + if col == 0: + ax.set_ylabel(label, color=SUB, fontsize=9) + + if len(cf) < 3 or len(ch) < 3: + ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes, + ha="center", va="center", color=SUB, fontsize=9) + continue + + X_full_b = 
np.stack([c["action_flat"] for c in cf]) + X_hq_b = np.stack([c["action_flat"] for c in ch]) + X_all = np.vstack([X_full_b, X_hq_b]) + X_all = StandardScaler().fit_transform(X_all) + X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all) + + X_2d_full = X_2d[: len(cf)] + X_2d_hq = X_2d[len(cf) :] + + pts = X_2d_full if row == 0 else X_2d_hq + ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none") + + fig.tight_layout(rect=[0, 0, 1, 0.95]) + _save(fig, out_path) + + +# ── Plot 1: Chunk counts per band ─────────────────────────────────────── + +def plot_chunk_counts( + bands_full: dict[int, list[dict]], + bands_hq: dict[int, list[dict]], + band_edges: list[tuple[float, float]], + out_path: Path, +) -> None: + n_bands = len(band_edges) + band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] + x = np.arange(n_bands) + w = 0.35 + + counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)] + counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)] + + fig, ax = plt.subplots(figsize=(10, 5)) + fig.patch.set_facecolor(BG) + _style_ax(ax) + + ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed") + ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ") + ax.set_xticks(x) + ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30) + ax.set_ylabel("Chunk count", color=SUB) + ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13) + ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8) + fig.tight_layout() + _save(fig, out_path) + + +# ── Summary figure ─────────────────────────────────────────────────────── + +def plot_summary( + var_full: np.ndarray, + var_hq: np.ndarray, + band_edges: list[tuple[float, float]], + ks_full: list[int | None], + ks_hq: list[int | None], + bands_full: dict[int, list[dict]], + bands_hq: dict[int, list[dict]], + out_path: Path, +) -> None: + ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8) + valid_ratio = ratio[~np.isnan(ratio)] + mean_ratio = float(np.mean(valid_ratio)) if len(valid_ratio) > 0 else float("nan") + peak_idx = int(np.argmax(valid_ratio)) if len(valid_ratio) > 0 else 0 + peak_ratio = float(valid_ratio[peak_idx]) if len(valid_ratio) > 0 else float("nan") + lo, hi = band_edges[peak_idx] + peak_band = f"{lo:.0%}–{hi:.0%}" + + valid_kf = [k for k in ks_full if k is not None] + valid_kh = [k for k in ks_hq if k is not None] + mean_k_full = np.mean(valid_kf) if valid_kf else float("nan") + mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan") + + n_bands = len(band_edges) + deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])] + deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])] + mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan") + mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan") + + rows = [ + ("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"), + ("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"), + ("Mean GMM K — Full", f"{mean_k_full:.1f}"), + ("Mean GMM K — HQ", f"{mean_k_hq:.1f}"), + ("Mean progress Δ — Full", f"{mean_delta_full:.4f}"), + ("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"), + ] + + fig, ax = plt.subplots(figsize=(8, 3)) + fig.patch.set_facecolor(BG) + ax.set_facecolor(CARD) + ax.axis("off") + + table = ax.table( + cellText=[[m, v] for m, v in rows], + colLabels=["Metric", "Value"], + loc="center", + cellLoc="left", + ) + 
table.auto_set_font_size(False) + table.set_fontsize(10) + for key, cell in table.get_celld().items(): + cell.set_edgecolor(BORDER) + cell.set_facecolor(CARD) + cell.set_text_props(color=TEXT) + if key[0] == 0: + cell.set_text_props(color=TEXT, fontweight="bold") + table.scale(1, 1.6) + ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15) + fig.tight_layout() + _save(fig, out_path) + + for metric, value in rows: + logger.info(" %s: %s", metric, value) + + +# ── Main ───────────────────────────────────────────────────────────────── + +def main(args: argparse.Namespace) -> None: + out = Path(args.output_dir) + out.mkdir(parents=True, exist_ok=True) + + logger.info("Loading FULL dataset: %s", args.full_dataset) + episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes) + logger.info("Loading HQ dataset: %s", args.hq_dataset) + episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes) + logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq)) + + # Step 1: length threshold + filter + if args.min_episode_length is not None: + threshold = args.min_episode_length + else: + threshold = auto_length_threshold(episodes_full, episodes_hq) + logger.info("Episode length threshold: %d", threshold) + + plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png") + episodes_full = filter_episodes(episodes_full, threshold) + episodes_hq = filter_episodes(episodes_hq, threshold) + + # Step 2: extract chunks + chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride) + chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride) + logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq)) + + # Step 3: adaptive bands (fit on full, apply to both) + band_edges = fit_adaptive_bands(chunks_full, args.min_chunks_per_band) + n_bands = len(band_edges) + logger.info("Adaptive bands (%d): %s", n_bands, + [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]) + + chunks_full = assign_bands(chunks_full, band_edges) + chunks_hq = assign_bands(chunks_hq, band_edges) + bands_full = split_by_band(chunks_full, n_bands) + bands_hq = split_by_band(chunks_hq, n_bands) + + # Plot 1: chunk counts + plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png") + + # Step 4: variance heatmap + var_full = band_variance_matrix(bands_full, n_bands, args.n_joints) + var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints) + plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png") + + # Step 5: progress delta + plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png") + + # Step 6: GMM BIC + ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png") + + # Step 7: PCA scatter + plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png") + + # Summary + plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq, + bands_full, bands_hq, out / "6_summary.png") + + logger.info("All figures saved to %s", out) + + +if __name__ == "__main__": + p = argparse.ArgumentParser( + description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--full-dataset", default="lerobot-data-collection/level12_rac_2_2026-02-08_1") + p.add_argument("--hq-dataset", default="lerobot-data-collection/level2_final_quality3") 
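+    # Both dataset arguments are Hugging Face Hub repo ids; the defaults mirror
+    # the example invocation in the module docstring.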
+    p.add_argument("--output-dir", default="./chunk_analysis")
+    p.add_argument("--chunk-size", type=int, default=30)
+    p.add_argument("--chunk-stride", type=int, default=15)
+    p.add_argument("--min-chunks-per-band", type=int, default=20)
+    p.add_argument("--max-episodes", type=int, default=500)
+    p.add_argument("--n-joints", type=int, default=16)
+    p.add_argument("--min-episode-length", type=int, default=None,
+                   help="Override auto-detected length filter threshold")
+    p.add_argument("--seed", type=int, default=42)
+    args = p.parse_args()
+    main(args)
diff --git a/examples/dataset/visualization_tools/create_progress_videos.py b/examples/dataset/visualization_tools/create_progress_videos.py
new file mode 100644
index 000000000..ec352c655
--- /dev/null
+++ b/examples/dataset/visualization_tools/create_progress_videos.py
@@ -0,0 +1,471 @@
+"""
+Create videos with a sarm_progress overlay for specified episodes; the final
+artifact per episode is a GIF. Downloads datasets from HuggingFace, extracts
+the episode video + progress data, and draws the progress line directly on
+each frame (no panel, no axes).
+"""
+
+import json
+import subprocess
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+from huggingface_hub import snapshot_download
+
+
+# ─────────────────────── Config ───────────────────────
+DATASETS = [
+    {"repo_id": "lerobot-data-collection/level1_rac3_rtc_s6_1", "episode": 0},
+]
+OUTPUT_DIR = Path("./progress_videos")
+OUTPUT_DIR.mkdir(exist_ok=True)
+
+# Progress line spans the full video height
+GRAPH_Y_TOP_FRAC = 0.01
+GRAPH_Y_BOT_FRAC = 0.99
+LINE_THICKNESS = 3
+SHADOW_THICKNESS = 6  # white edge thickness
+REF_ALPHA = 0.45  # opacity of the 1.0 reference line
+FILL_ALPHA = 0.55  # opacity of the grey fill under the line
+SCORE_FONT_SCALE = 0.8
+TASK_FONT_SCALE = 0.55
+
+
+# ─────────────────────── Helpers ──────────────────────
+
+def download_episode(repo_id: str, episode: int) -> Path:
+    """Download only the files needed for this episode."""
+    # We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
+    # We'll download meta + sarm first, then figure out chunks. 
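+    # Phase 1 (below): metadata-only snapshot; videos are skipped via
+    # ignore_patterns so this stays small even for large datasets. The episode's
+    # chunked .mp4 is fetched later by download_video().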
+ print(f"\n[1/5] Downloading metadata for {repo_id} …") + local = Path( + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + allow_patterns=["meta/**", "sarm_progress.parquet"], + ignore_patterns=["*.mp4"], + ) + ) + return local + + +def load_episode_meta(local: Path, episode: int) -> dict: + """Read info.json + episode-level parquet to get fps, video paths, timestamps.""" + info = json.loads((local / "meta" / "info.json").read_text()) + fps = info["fps"] + features = info["features"] + + # Find video keys (keys whose dtype=="video") + video_keys = [k for k, v in features.items() if v.get("dtype") == "video"] + if not video_keys: + raise RuntimeError("No video keys found in dataset features") + first_cam = video_keys[0] + print(f" fps={fps} first_camera='{first_cam}' all_cams={video_keys}") + + # Load all episode-meta parquet files and find our episode + ep_rows = [] + for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")): + df = pd.read_parquet(pq) + ep_rows.append(df) + ep_df = pd.concat(ep_rows, ignore_index=True) + row = ep_df[ep_df["episode_index"] == episode] + if row.empty: + raise RuntimeError(f"Episode {episode} not found in episode metadata") + row = row.iloc[0] + + # Extract video chunk/file index for first camera + cam_key = first_cam.replace(".", "/") # some datasets store as nested key + # Try both dot and slash variants of the key + chunk_col = f"videos/{first_cam}/chunk_index" + file_col = f"videos/{first_cam}/file_index" + ts_col = f"videos/{first_cam}/from_timestamp" + to_col = f"videos/{first_cam}/to_timestamp" + + # Some datasets use different column naming + if chunk_col not in row.index: + # Try without the 'videos/' prefix + chunk_col = f"{first_cam}/chunk_index" + file_col = f"{first_cam}/file_index" + ts_col = f"{first_cam}/from_timestamp" + to_col = f"{first_cam}/to_timestamp" + if chunk_col not in row.index: + raise RuntimeError(f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}") + + chunk_idx = int(row[chunk_col]) + file_idx = int(row[file_col]) + from_ts = float(row[ts_col]) + to_ts = float(row[to_col]) + + video_template = info.get("video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4") + video_rel = video_template.format( + video_key=first_cam, + chunk_index=chunk_idx, + file_index=file_idx, + ) + + # Load task name for this episode + # tasks.parquet uses the task string as the row index; task_index column holds the int id + task_name = "" + try: + # Prefer the 'tasks' list directly on the episode row + if "tasks" in row.index and row["tasks"] is not None: + tasks_val = row["tasks"] + if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0: + task_name = str(tasks_val[0]) + else: + task_name = str(tasks_val).strip("[]'") + else: + tasks_pq = local / "meta" / "tasks.parquet" + if tasks_pq.exists(): + tasks_df = pd.read_parquet(tasks_pq) + # Row index is the task string; task_index column is the int + task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0 + match = tasks_df[tasks_df["task_index"] == task_idx] + if not match.empty: + task_name = str(match.index[0]) + print(f" Task name: '{task_name}'") + except Exception as e: + print(f" WARNING: could not load task name: {e}") + + return { + "fps": fps, + "first_cam": first_cam, + "video_rel": video_rel, + "chunk_index": chunk_idx, + "file_index": file_idx, + "from_ts": from_ts, + "to_ts": to_ts, + "task_name": task_name, + } + + +def download_video(repo_id: str, local: Path, 
video_rel: str) -> Path: + """Download the specific video file if not already present.""" + video_path = local / video_rel + if video_path.exists(): + print(f" Video already cached: {video_path}") + return video_path + print(f"[2/5] Downloading video file {video_rel} …") + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + local_dir=str(local), + allow_patterns=[video_rel], + ) + if not video_path.exists(): + raise RuntimeError(f"Video not found after download: {video_path}") + return video_path + + +def load_progress(local: Path, episode: int) -> np.ndarray | None: + """Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress).""" + pq_path = local / "sarm_progress.parquet" + if not pq_path.exists(): + print(" WARNING: sarm_progress.parquet not found, trying data parquet …") + return None + df = pd.read_parquet(pq_path) + print(f" sarm_progress.parquet columns: {list(df.columns)}") + ep_df = df[df["episode_index"] == episode].copy() + if ep_df.empty: + print(f" WARNING: No sarm_progress rows for episode {episode}") + return None + ep_df = ep_df.sort_values("frame_index") + + # Prefer dense, fall back to sparse + if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any(): + prog_col = "progress_dense" + elif "progress_sparse" in ep_df.columns: + prog_col = "progress_sparse" + else: + # Last resort: any column with 'progress' in the name + prog_cols = [c for c in ep_df.columns if "progress" in c.lower()] + if not prog_cols: + return None + prog_col = prog_cols[0] + + print(f" Using progress column: '{prog_col}'") + return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values + + +def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path: + """Use ffmpeg to cut the episode segment from the combined video file.""" + duration = to_ts - from_ts + print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …") + cmd = [ + "ffmpeg", "-y", + "-ss", str(from_ts), + "-i", str(video_path), + "-t", str(duration), + "-c:v", "libx264", + "-preset", "fast", + "-crf", "18", + "-an", + str(out_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}") + return out_path + + +def precompute_pixels( + progress_data: np.ndarray, + n_frames: int, + frame_w: int, + frame_h: int, +) -> np.ndarray: + """ + Map each progress sample to pixel coordinates. + Returns array of shape (N, 2) with (x, y) in pixel space. + x spans full video width; y maps progress [0,1] to graph band. 
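+    # Worked example (assuming a 480-px-tall frame): y_top = int(480*0.01) = 4,
+    # y_bot = int(480*0.99) = 475, so progress p maps to y = 475 - p*471
+    # (p=1 lands at the top of the band, p=0 at the bottom).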
+ """ + frame_indices = progress_data[:, 0].astype(float) + progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0) + n = len(frame_indices) + + y_top = int(frame_h * GRAPH_Y_TOP_FRAC) + y_bot = int(frame_h * GRAPH_Y_BOT_FRAC) + graph_h = y_bot - y_top + + xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int) + # progress=1 → y_top, progress=0 → y_bot + ys = (y_bot - progress_vals * graph_h).astype(int) + + return np.stack([xs, ys], axis=1) # (N, 2) + + +def progress_color(t: float) -> tuple[int, int, int]: + """Interpolate BGR color red→green based on normalised position t in [0,1].""" + r = int(255 * (1.0 - t)) + g = int(255 * t) + return (0, g, r) # BGR + + +def prerender_fill( + pixels: np.ndarray, + frame_w: int, + frame_h: int, +) -> np.ndarray: + """Pre-render the full grey fill polygon under the curve as a BGRA image.""" + y_bot = int(frame_h * GRAPH_Y_BOT_FRAC) + fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8) + poly = np.concatenate([ + pixels, + [[pixels[-1][0], y_bot], [pixels[0][0], y_bot]], + ], axis=0).astype(np.int32) + cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA))) + return fill_img + + +def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None: + """Blend overlay onto base in-place, but only for x < x_max.""" + if x_max <= 0: + return + roi_b = base[:, :x_max] + roi_o = overlay_bgra[:, :x_max] + alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0 + roi_b[:] = np.clip( + roi_o[:, :, :3].astype(np.float32) * alpha + + roi_b.astype(np.float32) * (1.0 - alpha), + 0, 255, + ).astype(np.uint8) + + +def draw_text_outlined( + frame: np.ndarray, + text: str, + pos: tuple[int, int], + font_scale: float, + thickness: int = 1, +) -> None: + """Draw text with a dark outline for readability on any background.""" + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA) + cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA) + + +def composite_video( + clip_path: Path, + progress_data: np.ndarray, + out_path: Path, + fps: float, + frame_h: int, + frame_w: int, + task_name: str = "", +) -> Path: + """Read clip frames, draw gradient progress line with fill + labels, export as GIF.""" + n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT)) + pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h) + + y_top = int(frame_h * GRAPH_Y_TOP_FRAC) + y_bot = int(frame_h * GRAPH_Y_BOT_FRAC) + y_ref = y_top + + # Pre-render fill polygon (line is drawn per-frame with live color) + fill_img = prerender_fill(pixels, frame_w, frame_h) + + # 1.0 reference line overlay (full width, drawn once) + ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8) + cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), + (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA) + + frame_indices = progress_data[:, 0].astype(int) + progress_vals = progress_data[:, 1].astype(float) + + print(f"[4/4] Compositing {n_total} frames …") + cap = cv2.VideoCapture(str(clip_path)) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4") + writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h)) + + fi = 0 + while True: + ret, frame = cap.read() + if not ret: + break + + n_drawn = int(np.searchsorted(frame_indices, fi, side="right")) + x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0 + + # 1. 
reference line (full width, always) + alpha_composite(frame, ref_img, frame_w) + + # 2. grey fill under curve up to current x + alpha_composite(frame, fill_img, x_cur) + + # 3. progress line — single color that transitions red→green over time + if n_drawn >= 2: + t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1) + line_col = progress_color(t_cur) + pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32) + cv2.polylines(frame, [pts], isClosed=False, + color=(255, 255, 255), thickness=SHADOW_THICKNESS, + lineType=cv2.LINE_AA) + cv2.polylines(frame, [pts], isClosed=False, + color=line_col, thickness=LINE_THICKNESS, + lineType=cv2.LINE_AA) + + # 4. score — bottom right + if n_drawn > 0: + score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1]) + score_text = f"{score:.2f}" + (tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, + SCORE_FONT_SCALE, 2) + sx = frame_w - tw - 12 + sy = frame_h - 12 + # coloured score matching current gradient position + t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1) + score_col = progress_color(t_cur) + cv2.putText(frame, score_text, (sx, sy), cv2.FONT_HERSHEY_SIMPLEX, + SCORE_FONT_SCALE, (0, 0, 0), 4, cv2.LINE_AA) + cv2.putText(frame, score_text, (sx, sy), cv2.FONT_HERSHEY_SIMPLEX, + SCORE_FONT_SCALE, score_col, 2, cv2.LINE_AA) + + # 5. task name — top centre + if task_name: + (tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, + TASK_FONT_SCALE, 1) + tx = max((frame_w - tw) // 2, 4) + draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE) + + writer.write(frame) + fi += 1 + if fi % 100 == 0: + print(f" Frame {fi}/{n_total} …", end="\r") + + cap.release() + writer.release() + print() + + # Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB) + gif_path = out_path.with_suffix(".gif") + palette = out_path.parent / "_palette.png" + r1 = subprocess.run([ + "ffmpeg", "-y", "-i", str(tmp_path), + "-vf", f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff", + "-update", "1", + str(palette), + ], capture_output=True, text=True) + if r1.returncode != 0: + print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}") + r2 = subprocess.run([ + "ffmpeg", "-y", + "-i", str(tmp_path), "-i", str(palette), + "-filter_complex", + f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3", + str(gif_path), + ], capture_output=True, text=True) + if r2.returncode != 0: + print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}") + tmp_path.unlink(missing_ok=True) + palette.unlink(missing_ok=True) + return gif_path + + +# ─────────────────────── Main ────────────────────────── + +def process_dataset(repo_id: str, episode: int): + safe_name = repo_id.replace("/", "_") + print(f"\n{'='*60}") + print(f"Processing: {repo_id} | episode {episode}") + print(f"{'='*60}") + + # 1. Download metadata + local = download_episode(repo_id, episode) + print(f" Local cache: {local}") + + # 2. Read episode metadata + ep_meta = load_episode_meta(local, episode) + print(f" Episode meta: {ep_meta}") + + # 3. Download video file + video_path = download_video(repo_id, local, ep_meta["video_rel"]) + + # 4. Extract clip + clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4" + extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path) + + # 5. Load progress data + progress_data = load_progress(local, episode) + if progress_data is None: + print(" ERROR: Could not load sarm_progress data. 
Skipping overlay.") + return + + n_progress = len(progress_data) + print(f" Progress frames: {n_progress}") + + # 6. Get clip dimensions + cap = cv2.VideoCapture(str(clip_path)) + frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"] + cap.release() + print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps") + + # 7. Composite (draw line directly on frames) + out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4" + final = composite_video(clip_path, progress_data, out_path, actual_fps, frame_h, frame_w, + task_name=ep_meta.get("task_name", "")) + clip_path.unlink(missing_ok=True) + print(f"\n✓ Done: {final}") + return final + + +if __name__ == "__main__": + results = [] + for cfg in DATASETS: + try: + out = process_dataset(cfg["repo_id"], cfg["episode"]) + if out: + results.append(out) + except Exception as e: + print(f"\nERROR processing {cfg['repo_id']}: {e}") + import traceback; traceback.print_exc() + + print("\n" + "="*60) + print("Output files:") + for r in results: + print(f" {r}")
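+
+    # Each DATASETS entry (top of file) yields one <repo>_ep<N>_progress.gif in
+    # OUTPUT_DIR; to batch-render, append more dicts there, e.g.
+    # {"repo_id": "<org>/<dataset>", "episode": 3} (placeholder repo id).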