From 219c08ccb8f894bbc10daa786f68e7c694190c6d Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Wed, 18 Mar 2026 10:13:46 -0700
Subject: [PATCH] add example for creating progress video for sarm

---
 .../chunk_multimodality_analysis.py          | 659 ------------------
 .../create_progress_videos.py                |  12 +-
 2 files changed, 9 insertions(+), 662 deletions(-)
 delete mode 100644 examples/dataset/visualization_tools/chunk_multimodality_analysis.py

diff --git a/examples/dataset/visualization_tools/chunk_multimodality_analysis.py b/examples/dataset/visualization_tools/chunk_multimodality_analysis.py
deleted file mode 100644
index 837fad5a8..000000000
--- a/examples/dataset/visualization_tools/chunk_multimodality_analysis.py
+++ /dev/null
@@ -1,659 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
-
-Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
-atomic unit, tagged by the SARM progress score at its start frame. For each
-progress band, compares the full vs HQ dataset on:
-
-  1. Intra-band action variance
-  2. Progress delta per chunk
-  3. GMM + BIC optimal K (number of distinct strategies)
-  4. PCA embedding (visual cluster inspection)
-
-Usage:
-    python chunk_multimodality_analysis.py \\
-        --full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
-        --hq-dataset lerobot-data-collection/level2_final_quality3 \\
-        --output-dir ./chunk_analysis
-"""
-
-from __future__ import annotations
-
-import argparse
-import logging
-from collections import defaultdict
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.stats import gaussian_kde
-from sklearn.decomposition import PCA
-from sklearn.mixture import GaussianMixture
-from sklearn.preprocessing import StandardScaler
-
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
-logger = logging.getLogger(__name__)
-
-# ── Visual style ──────────────────────────────────────────────────────────
-
-BG = "#0e1117"
-CARD = "#1a1d27"
-BORDER = "#2a2d3a"
-SUB = "#8b8fa8"
-TEXT = "#e8eaf0"
-C_FULL = "#f7934f"
-C_HQ = "#4dc98a"
-
-
-def _style_ax(ax: plt.Axes) -> None:
-    ax.set_facecolor(CARD)
-    ax.tick_params(colors=SUB, labelsize=8)
-    for spine in ax.spines.values():
-        spine.set_color(BORDER)
-
-
-def _save(fig: plt.Figure, path: Path) -> None:
-    fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG)
-    plt.close(fig)
-    logger.info("Saved %s", path)
-
-
-# ── Step 0: Load episodes ────────────────────────────────────────────────
-
-def load_episodes(
-    repo_id: str,
-    n_joints: int = 16,
-    max_episodes: int | None = None,
-) -> list[dict]:
-    dataset = LeRobotDataset(repo_id, download_videos=False)
-    raw = dataset.hf_dataset
-    episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
-
-    for row in raw:
-        ep = int(row["episode_index"])
-        if max_episodes is not None and ep >= max_episodes:
-            continue
-        action = np.array(row["action"], dtype=np.float32)[:n_joints]
-        episodes[ep]["actions"].append(action)
-        progress = float(row.get("next.reward", float("nan")))
-        episodes[ep]["progress"].append(progress)
-
-    result = []
-    for ep_id, d in sorted(episodes.items()):
-        result.append({
-            "episode": ep_id,
-            "actions": np.stack(d["actions"]),
-            "progress": np.array(d["progress"]),
-        })
-    return result
-
-
-# ── Step 1: Filter short episodes ────────────────────────────────────────
-
-def auto_length_threshold(
-    episodes_full: list[dict], episodes_hq: list[dict]
-) -> int:
-    all_lengths = np.array(
-        [e["actions"].shape[0] for e in episodes_full + episodes_hq]
-    )
-    kde = gaussian_kde(all_lengths, bw_method=0.25)
-    xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300)
-    return int(xs[np.argmin(kde(xs))])
-
-
-def plot_length_distribution(
-    episodes_full: list[dict],
-    episodes_hq: list[dict],
-    threshold: int,
-    out_path: Path,
-) -> None:
-    lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
-    lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
-    all_lens = np.concatenate([lens_full, lens_hq])
-
-    fig, ax = plt.subplots(figsize=(10, 5))
-    fig.patch.set_facecolor(BG)
-    _style_ax(ax)
-
-    bins = np.linspace(all_lens.min(), all_lens.max(), 50)
-    ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
-    ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")
-
-    xs = np.linspace(all_lens.min(), all_lens.max(), 300)
-    kde = gaussian_kde(all_lens, bw_method=0.25)
-    ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)")
-
ls="--", lw=1.5, label=f"Threshold = {threshold}") - ax.set_xlabel("Episode length (frames)", color=SUB) - ax.set_ylabel("Count", color=SUB) - ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13) - ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8) - _save(fig, out_path) - - -def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]: - kept = [e for e in episodes if e["actions"].shape[0] >= min_length] - logger.info("Kept %d / %d episodes (min_length=%d)", len(kept), len(episodes), min_length) - return kept - - -# ── Step 2: Extract chunks ─────────────────────────────────────────────── - -def extract_chunks( - episodes: list[dict], - chunk_size: int = 30, - chunk_stride: int = 15, -) -> list[dict]: - chunks = [] - for ep in episodes: - actions = ep["actions"] - T = len(actions) - - prog = np.clip(np.nan_to_num(ep["progress"], nan=0.0), 0.0, 1.0) - prog = np.maximum.accumulate(prog) - - for t in range(0, T - chunk_size, chunk_stride): - chunk = actions[t : t + chunk_size] - p_start = float(prog[t]) - p_end = float(prog[min(t + chunk_size, T - 1)]) - - chunks.append({ - "action_mean": chunk.mean(axis=0).astype(np.float32), - "action_flat": chunk.flatten().astype(np.float32), - "progress_start": p_start, - "progress_delta": p_end - p_start, - "episode": ep["episode"], - }) - return chunks - - -# ── Step 3: Adaptive progress bands ───────────────────────────────────── - -def fit_adaptive_bands( - chunks: list[dict], min_per_band: int = 20 -) -> list[tuple[float, float]]: - prog_vals = np.array([c["progress_start"] for c in chunks]) - fine_edges = np.linspace(0.0, 1.0, 11) - - band_edges: list[tuple[float, float]] = [] - i = 0 - while i < len(fine_edges) - 1: - lo, hi = fine_edges[i], fine_edges[i + 1] - j = i + 1 - while ( - np.sum((prog_vals >= lo) & (prog_vals < hi)) < min_per_band - and j < len(fine_edges) - 1 - ): - j += 1 - hi = fine_edges[j] - band_edges.append((lo, hi)) - i = j - return band_edges - - -def assign_bands( - chunks: list[dict], band_edges: list[tuple[float, float]] -) -> list[dict]: - n = len(band_edges) - for c in chunks: - p = c["progress_start"] - c["band"] = next( - (bi for bi, (lo, hi) in enumerate(band_edges) if p < hi), - n - 1, - ) - return chunks - - -def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]: - out: dict[int, list[dict]] = {b: [] for b in range(n_bands)} - for c in chunks: - out[c["band"]].append(c) - return out - - -# ── Step 4: Intra-band action variance ────────────────────────────────── - -def band_variance_matrix( - bands: dict[int, list[dict]], n_bands: int, n_joints: int -) -> np.ndarray: - var_mat = np.full((n_bands, n_joints), np.nan) - for b, clist in bands.items(): - if len(clist) < 3: - continue - means = np.stack([c["action_mean"] for c in clist]) - var_mat[b] = np.var(means, axis=0) - return var_mat - - -def plot_variance_heatmap( - var_full: np.ndarray, - var_hq: np.ndarray, - band_edges: list[tuple[float, float]], - out_path: Path, -) -> None: - n_bands = var_full.shape[0] - vmin = 0.0 - vmax = max(np.nanmax(var_full), np.nanmax(var_hq)) - - band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] - joint_labels = [f"J{j}" for j in range(var_full.shape[1])] - - fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]}) - fig.patch.set_facecolor(BG) - fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98) - - for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]): - ax = 
-        ax = axes[ax_idx]
-        _style_ax(ax)
-        im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax)
-        ax.set_yticks(range(n_bands))
-        ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
-        ax.set_xticks(range(var_full.shape[1]))
-        ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
-        ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11)
-        fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
-
-    ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8)
-    ax_bar = axes[2]
-    _style_ax(ax_bar)
-    colors = [
-        "#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ
-        for r in ratio
-    ]
-    ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER)
-    ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
-    ax_bar.set_xticks(range(n_bands))
-    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
-    ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
-    ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)
-
-    fig.tight_layout(rect=[0, 0, 1, 0.96])
-    _save(fig, out_path)
-
-
-# ── Step 5: Progress delta per band ──────────────────────────────────────
-
-def plot_progress_delta(
-    bands_full: dict[int, list[dict]],
-    bands_hq: dict[int, list[dict]],
-    band_edges: list[tuple[float, float]],
-    out_path: Path,
-) -> None:
-    n_bands = len(band_edges)
-    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
-    x = np.arange(n_bands)
-    w = 0.35
-
-    means_full, stds_full = [], []
-    means_hq, stds_hq = [], []
-    all_deltas_full, all_deltas_hq = [], []
-
-    for b in range(n_bands):
-        df = np.array([c["progress_delta"] for c in bands_full.get(b, [])])
-        dh = np.array([c["progress_delta"] for c in bands_hq.get(b, [])])
-        means_full.append(np.mean(df) if len(df) > 0 else 0)
-        stds_full.append(np.std(df) if len(df) > 0 else 0)
-        means_hq.append(np.mean(dh) if len(dh) > 0 else 0)
-        stds_hq.append(np.std(dh) if len(dh) > 0 else 0)
-        all_deltas_full.extend(df.tolist())
-        all_deltas_hq.extend(dh.tolist())
-
-    fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
-    fig.patch.set_facecolor(BG)
-    fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)
-
-    _style_ax(ax_bar)
-    ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER,
-               capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB})
-    ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER,
-               capsize=3, label="HQ", error_kw={"ecolor": SUB})
-    ax_bar.set_xticks(x)
-    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
-    ax_bar.set_ylabel("Mean progress Δ", color=SUB)
-    ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
-
-    _style_ax(ax_viol)
-    data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)]
-    if all(len(d) > 0 for d in data_viol):
-        parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True)
-        for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]):
-            pc.set_facecolor(c)
-            pc.set_alpha(0.7)
-        for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
-            if key in parts:
-                parts[key].set_color(SUB)
-    ax_viol.set_xticks([0, 1])
-    ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
-    ax_viol.set_ylabel("Progress Δ", color=SUB)
-    ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)
-
-    fig.tight_layout()
-    _save(fig, out_path)
-
-
-# ── Step 6: GMM + BIC per band ──────────────────────────────────────────
-
-def gmm_optimal_k(
-    band_chunks: list[dict],
-    pca_components: int = 15,
-    max_k: int = 7,
-    seed: int = 42,
-) -> int | None:
-    if len(band_chunks) < 20:
-        return None
-    X = np.stack([c["action_flat"] for c in band_chunks])
-    X = StandardScaler().fit_transform(X)
-    n = min(pca_components, X.shape[1], X.shape[0] - 1)
-    X_r = PCA(n_components=n, random_state=seed).fit_transform(X)
-    bics = []
-    for k in range(1, min(max_k + 1, len(X_r) // 6)):
-        gmm = GaussianMixture(
-            n_components=k, covariance_type="full",
-            n_init=5, max_iter=300, random_state=seed,
-        )
-        gmm.fit(X_r)
-        bics.append((k, gmm.bic(X_r)))
-    if not bics:
-        return None
-    return min(bics, key=lambda x: x[1])[0]
-
-
-def plot_gmm_bic(
-    bands_full: dict[int, list[dict]],
-    bands_hq: dict[int, list[dict]],
-    band_edges: list[tuple[float, float]],
-    seed: int,
-    out_path: Path,
-) -> tuple[list[int | None], list[int | None]]:
-    n_bands = len(band_edges)
-    ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
-    ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]
-
-    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
-
-    fig, ax = plt.subplots(figsize=(10, 5))
-    fig.patch.set_facecolor(BG)
-    _style_ax(ax)
-
-    xs = np.arange(n_bands)
-    valid_full = [(i, k) for i, k in enumerate(ks_full) if k is not None]
-    valid_hq = [(i, k) for i, k in enumerate(ks_hq) if k is not None]
-
-    if valid_full:
-        xi, yi = zip(*valid_full)
-        ax.plot(xi, yi, "o-", color=C_FULL, label="Full/Mixed", lw=2, markersize=7)
-    if valid_hq:
-        xi, yi = zip(*valid_hq)
-        ax.plot(xi, yi, "o-", color=C_HQ, label="HQ", lw=2, markersize=7)
-
-    if valid_full and valid_hq:
-        all_x = sorted(set([i for i, _ in valid_full]) & set([i for i, _ in valid_hq]))
-        if len(all_x) >= 2:
-            kf_interp = {i: k for i, k in valid_full}
-            kh_interp = {i: k for i, k in valid_hq}
-            shared_x = [i for i in all_x if i in kf_interp and i in kh_interp]
-            yf = [kf_interp[i] for i in shared_x]
-            yh = [kh_interp[i] for i in shared_x]
-            ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT)
-
-    ax.set_xticks(xs)
-    ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
-    ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
-    ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
-    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
-    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
-    fig.tight_layout()
-    _save(fig, out_path)
-    return ks_full, ks_hq
-
-
-# ── Step 7: PCA scatter per band ─────────────────────────────────────────
-
-def plot_pca_scatter(
-    bands_full: dict[int, list[dict]],
-    bands_hq: dict[int, list[dict]],
-    band_edges: list[tuple[float, float]],
-    out_path: Path,
-) -> None:
-    n_plot = min(4, len(band_edges))
-    fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
-    fig.patch.set_facecolor(BG)
-    fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)
-
-    if n_plot == 1:
-        axes = axes.reshape(2, 1)
-
-    for col, b in enumerate(range(n_plot)):
-        cf = bands_full.get(b, [])
-        ch = bands_hq.get(b, [])
-        lo, hi = band_edges[b]
-
-        for row, (clist, color, label) in enumerate([
-            (cf, C_FULL, "Full/Mixed"), (ch, C_HQ, "HQ")
-        ]):
-            ax = axes[row, col]
-            _style_ax(ax)
-            if row == 0:
-                ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
-            if col == 0:
-                ax.set_ylabel(label, color=SUB, fontsize=9)
-
-            if len(cf) < 3 or len(ch) < 3:
-                ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes,
-                        ha="center", va="center", color=SUB, fontsize=9)
-                continue
-
np.stack([c["action_flat"] for c in cf]) - X_hq_b = np.stack([c["action_flat"] for c in ch]) - X_all = np.vstack([X_full_b, X_hq_b]) - X_all = StandardScaler().fit_transform(X_all) - X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all) - - X_2d_full = X_2d[: len(cf)] - X_2d_hq = X_2d[len(cf) :] - - pts = X_2d_full if row == 0 else X_2d_hq - ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none") - - fig.tight_layout(rect=[0, 0, 1, 0.95]) - _save(fig, out_path) - - -# ── Plot 1: Chunk counts per band ─────────────────────────────────────── - -def plot_chunk_counts( - bands_full: dict[int, list[dict]], - bands_hq: dict[int, list[dict]], - band_edges: list[tuple[float, float]], - out_path: Path, -) -> None: - n_bands = len(band_edges) - band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges] - x = np.arange(n_bands) - w = 0.35 - - counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)] - counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)] - - fig, ax = plt.subplots(figsize=(10, 5)) - fig.patch.set_facecolor(BG) - _style_ax(ax) - - ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed") - ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ") - ax.set_xticks(x) - ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30) - ax.set_ylabel("Chunk count", color=SUB) - ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13) - ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8) - fig.tight_layout() - _save(fig, out_path) - - -# ── Summary figure ─────────────────────────────────────────────────────── - -def plot_summary( - var_full: np.ndarray, - var_hq: np.ndarray, - band_edges: list[tuple[float, float]], - ks_full: list[int | None], - ks_hq: list[int | None], - bands_full: dict[int, list[dict]], - bands_hq: dict[int, list[dict]], - out_path: Path, -) -> None: - ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8) - valid_ratio = ratio[~np.isnan(ratio)] - mean_ratio = float(np.mean(valid_ratio)) if len(valid_ratio) > 0 else float("nan") - peak_idx = int(np.argmax(valid_ratio)) if len(valid_ratio) > 0 else 0 - peak_ratio = float(valid_ratio[peak_idx]) if len(valid_ratio) > 0 else float("nan") - lo, hi = band_edges[peak_idx] - peak_band = f"{lo:.0%}–{hi:.0%}" - - valid_kf = [k for k in ks_full if k is not None] - valid_kh = [k for k in ks_hq if k is not None] - mean_k_full = np.mean(valid_kf) if valid_kf else float("nan") - mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan") - - n_bands = len(band_edges) - deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])] - deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])] - mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan") - mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan") - - rows = [ - ("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"), - ("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"), - ("Mean GMM K — Full", f"{mean_k_full:.1f}"), - ("Mean GMM K — HQ", f"{mean_k_hq:.1f}"), - ("Mean progress Δ — Full", f"{mean_delta_full:.4f}"), - ("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"), - ] - - fig, ax = plt.subplots(figsize=(8, 3)) - fig.patch.set_facecolor(BG) - ax.set_facecolor(CARD) - ax.axis("off") - - table = ax.table( - cellText=[[m, v] for m, v in rows], - colLabels=["Metric", "Value"], - loc="center", - cellLoc="left", - ) - 
-    table.auto_set_font_size(False)
-    table.set_fontsize(10)
-    for key, cell in table.get_celld().items():
-        cell.set_edgecolor(BORDER)
-        cell.set_facecolor(CARD)
-        cell.set_text_props(color=TEXT)
-        if key[0] == 0:
-            cell.set_text_props(color=TEXT, fontweight="bold")
-    table.scale(1, 1.6)
-    ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)
-    fig.tight_layout()
-    _save(fig, out_path)
-
-    for metric, value in rows:
-        logger.info("  %s: %s", metric, value)
-
-
-# ── Main ─────────────────────────────────────────────────────────────────
-
-def main(args: argparse.Namespace) -> None:
-    out = Path(args.output_dir)
-    out.mkdir(parents=True, exist_ok=True)
-
-    logger.info("Loading FULL dataset: %s", args.full_dataset)
-    episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
-    logger.info("Loading HQ dataset: %s", args.hq_dataset)
-    episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
-    logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))
-
-    # Step 1: length threshold + filter
-    if args.min_episode_length is not None:
-        threshold = args.min_episode_length
-    else:
-        threshold = auto_length_threshold(episodes_full, episodes_hq)
-    logger.info("Episode length threshold: %d", threshold)
-
-    plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png")
-    episodes_full = filter_episodes(episodes_full, threshold)
-    episodes_hq = filter_episodes(episodes_hq, threshold)
-
-    # Step 2: extract chunks
-    chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
-    chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
-    logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))
-
-    # Step 3: adaptive bands (fit on full, apply to both)
-    band_edges = fit_adaptive_bands(chunks_full, args.min_chunks_per_band)
-    n_bands = len(band_edges)
-    logger.info("Adaptive bands (%d): %s", n_bands,
-                [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges])
-
-    chunks_full = assign_bands(chunks_full, band_edges)
-    chunks_hq = assign_bands(chunks_hq, band_edges)
-    bands_full = split_by_band(chunks_full, n_bands)
-    bands_hq = split_by_band(chunks_hq, n_bands)
-
-    # Plot 1: chunk counts
-    plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png")
-
-    # Step 4: variance heatmap
-    var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
-    var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
-    plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png")
-
-    # Step 5: progress delta
-    plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png")
-
-    # Step 6: GMM BIC
-    ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png")
-
-    # Step 7: PCA scatter
-    plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png")
-
-    # Summary
-    plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq,
-                 bands_full, bands_hq, out / "6_summary.png")
-
-    logger.info("All figures saved to %s", out)
-
-
-if __name__ == "__main__":
-    p = argparse.ArgumentParser(
-        description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
-        formatter_class=argparse.RawDescriptionHelpFormatter,
-    )
-    p.add_argument("--full-dataset", default="lerobot-data-collection/level12_rac_2_2026-02-08_1")
-    p.add_argument("--hq-dataset", default="lerobot-data-collection/level2_final_quality3")
- p.add_argument("--output-dir", default="./chunk_analysis") - p.add_argument("--chunk-size", type=int, default=30) - p.add_argument("--chunk-stride", type=int, default=15) - p.add_argument("--min-chunks-per-band", type=int, default=20) - p.add_argument("--max-episodes", type=int, default=500) - p.add_argument("--n-joints", type=int, default=16) - p.add_argument("--min-episode-length", type=int, default=None, - help="Override auto-detected length filter threshold") - p.add_argument("--seed", type=int, default=42) - args = p.parse_args() - main(args) diff --git a/examples/dataset/visualization_tools/create_progress_videos.py b/examples/dataset/visualization_tools/create_progress_videos.py index ec352c655..da2cf1a40 100644 --- a/examples/dataset/visualization_tools/create_progress_videos.py +++ b/examples/dataset/visualization_tools/create_progress_videos.py @@ -16,8 +16,9 @@ from huggingface_hub import snapshot_download # ─────────────────────── Config ─────────────────────── DATASETS = [ - {"repo_id": "lerobot-data-collection/level1_rac3_rtc_s6_1", "episode": 0}, + {"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 1100}, ] +CAMERA_KEY = "observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top" OUTPUT_DIR = Path("/Users/pepijnkooijmans/Documents/GitHub_local/progress_videos") OUTPUT_DIR.mkdir(exist_ok=True) @@ -61,8 +62,13 @@ def load_episode_meta(local: Path, episode: int) -> dict: video_keys = [k for k, v in features.items() if v.get("dtype") == "video"] if not video_keys: raise RuntimeError("No video keys found in dataset features") - first_cam = video_keys[0] - print(f" fps={fps} first_camera='{first_cam}' all_cams={video_keys}") + if CAMERA_KEY is not None: + if CAMERA_KEY not in video_keys: + raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}") + first_cam = CAMERA_KEY + else: + first_cam = video_keys[0] + print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}") # Load all episode-meta parquet files and find our episode ep_rows = []