mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
Add create reward visualization and multimodal analysis tool
This commit is contained in:
@@ -0,0 +1,659 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
|
||||||
|
|
||||||
|
Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
|
||||||
|
atomic unit, tagged by the SARM progress score at its start frame. For each
|
||||||
|
progress band, compares the full vs HQ dataset on:
|
||||||
|
|
||||||
|
1. Intra-band action variance
|
||||||
|
2. Progress delta per chunk
|
||||||
|
3. GMM + BIC optimal K (number of distinct strategies)
|
||||||
|
4. PCA embedding (visual cluster inspection)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python chunk_multimodality_analysis.py \\
|
||||||
|
--full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
|
||||||
|
--hq-dataset lerobot-data-collection/level2_final_quality3 \\
|
||||||
|
--output-dir ./chunk_analysis
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from scipy.stats import gaussian_kde
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.mixture import GaussianMixture
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
# Configure logging once at import time so every pipeline step is reported
# with a timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ── Visual style ──────────────────────────────────────────────────────────

# Shared dark-theme palette used by every figure in this module.
BG = "#0e1117"  # figure background (also the facecolor used when saving)
CARD = "#1a1d27"  # axes / legend / table-cell background
BORDER = "#2a2d3a"  # spines, bar edges, legend and table borders
SUB = "#8b8fa8"  # tick labels and axis labels
TEXT = "#e8eaf0"  # titles and legend text
C_FULL = "#f7934f"  # series colour for the full/mixed dataset
C_HQ = "#4dc98a"  # series colour for the HQ dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _style_ax(ax: plt.Axes) -> None:
    """Apply the module's dark theme (background, tick and spine colours) to ``ax``."""
    ax.set_facecolor(CARD)
    ax.tick_params(colors=SUB, labelsize=8)
    # Recolour every spine of the axes frame.
    for side in ax.spines:
        ax.spines[side].set_color(BORDER)
|
||||||
|
|
||||||
|
|
||||||
|
def _save(fig: plt.Figure, path: Path) -> None:
    """Write ``fig`` to ``path`` on the dark background, close it, and log the destination."""
    save_options = {"dpi": 150, "bbox_inches": "tight", "facecolor": BG}
    fig.savefig(path, **save_options)
    # Close explicitly so figures do not accumulate across the pipeline.
    plt.close(fig)
    logger.info("Saved %s", path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 0: Load episodes ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def load_episodes(
    repo_id: str,
    n_joints: int = 16,
    max_episodes: int | None = None,
) -> list[dict]:
    """Load per-episode action and progress arrays from a LeRobot dataset.

    Args:
        repo_id: Dataset identifier passed to ``LeRobotDataset``.
        n_joints: Number of leading action dimensions kept for each frame.
        max_episodes: When given, frames whose ``episode_index`` is greater
            than or equal to this value are skipped.

    Returns:
        One dict per episode, sorted by episode index, with keys
        ``"episode"`` (int), ``"actions"`` (stacked float32 per-frame actions)
        and ``"progress"`` (array of ``next.reward`` values; NaN for frames
        whose row has no such key).
    """
    dataset = LeRobotDataset(repo_id, download_videos=False)
    raw = dataset.hf_dataset
    episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
    frames_without_progress = 0

    for row in raw:
        ep = int(row["episode_index"])
        if max_episodes is not None and ep >= max_episodes:
            continue
        action = np.array(row["action"], dtype=np.float32)[:n_joints]
        episodes[ep]["actions"].append(action)
        # NOTE(review): "next.reward" is presumably where the SARM progress
        # score mentioned in the module docstring is stored; confirm.
        if "next.reward" in row:
            progress = float(row["next.reward"])
        else:
            progress = float("nan")
            frames_without_progress += 1
        episodes[ep]["progress"].append(progress)

    if frames_without_progress:
        # Missing progress is not fatal, but make it visible: NaN progress is
        # mapped to 0.0 when chunks are extracted, which would silently put
        # every chunk into the lowest progress band.
        logger.warning(
            "%d frames in %s have no 'next.reward' value; their progress is NaN",
            frames_without_progress,
            repo_id,
        )

    result = []
    for ep_id, d in sorted(episodes.items()):
        result.append(
            {
                "episode": ep_id,
                "actions": np.stack(d["actions"]),
                "progress": np.array(d["progress"]),
            }
        )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: Filter short episodes ────────────────────────────────────────
|
||||||
|
|
||||||
|
def auto_length_threshold(
    episodes_full: list[dict], episodes_hq: list[dict]
) -> int:
    """Pick a minimum episode length at the density valley of short episodes.

    A Gaussian KDE is fitted to the combined episode lengths of both datasets
    and the length with the lowest density between the minimum length and the
    40th percentile is returned.

    Args:
        episodes_full: Episode dicts whose ``"actions"`` array has one row per frame.
        episodes_hq: Same structure, for the second dataset.

    Returns:
        The threshold length in frames (truncated to int).

    Raises:
        ValueError: If both episode lists are empty.
    """
    all_lengths = np.array(
        [e["actions"].shape[0] for e in episodes_full + episodes_hq]
    )
    if all_lengths.size == 0:
        raise ValueError("Cannot derive a length threshold from zero episodes")
    if np.unique(all_lengths).size < 2:
        # gaussian_kde cannot be fitted to fewer than two distinct values
        # (singular covariance). With a single length there is nothing to
        # separate, so return that length and keep every episode.
        return int(all_lengths.min())

    kde = gaussian_kde(all_lengths, bw_method=0.25)
    # Only search the short end of the distribution for the valley.
    xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300)
    return int(xs[np.argmin(kde(xs))])
|
||||||
|
|
||||||
|
|
||||||
|
def plot_length_distribution(
    episodes_full: list[dict],
    episodes_hq: list[dict],
    threshold: int,
    out_path: Path,
) -> None:
    """Plot overlaid episode-length histograms with the KDE and filter threshold.

    Args:
        episodes_full: Episode dicts of the full/mixed dataset (``"actions"``
            has one row per frame).
        episodes_hq: Episode dicts of the HQ dataset.
        threshold: Minimum episode length, drawn as a vertical dashed line.
        out_path: Destination image path.
    """
    lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
    lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
    all_lens = np.concatenate([lens_full, lens_hq])

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    lo, hi = all_lens.min(), all_lens.max()
    single_length = lo == hi
    if single_length:
        # All episodes have the same length: widen the range by half a frame
        # on each side so the histogram bins have non-zero width.
        lo, hi = lo - 0.5, hi + 0.5
    bins = np.linspace(lo, hi, 50)
    ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
    ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")

    if single_length:
        # gaussian_kde raises on data without at least two distinct values,
        # so the density curve is omitted instead of aborting the pipeline.
        logger.warning("All episodes have the same length; skipping KDE overlay")
    else:
        xs = np.linspace(all_lens.min(), all_lens.max(), 300)
        kde = gaussian_kde(all_lens, bw_method=0.25)
        # Scale the density by sample count and bin width so it is comparable
        # with the histogram counts.
        ax.plot(
            xs,
            kde(xs) * len(all_lens) * (bins[1] - bins[0]),
            color=TEXT,
            lw=1.5,
            label="KDE (combined)",
        )

    ax.axvline(threshold, color="#ff4b4b", ls="--", lw=1.5, label=f"Threshold = {threshold}")
    ax.set_xlabel("Episode length (frames)", color=SUB)
    ax.set_ylabel("Count", color=SUB)
    ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
    _save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]:
    """Return the episodes with at least ``min_length`` frames and log how many remain."""
    survivors = []
    for episode in episodes:
        n_frames = episode["actions"].shape[0]
        if n_frames < min_length:
            continue
        survivors.append(episode)
    logger.info("Kept %d / %d episodes (min_length=%d)", len(survivors), len(episodes), min_length)
    return survivors
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 2: Extract chunks ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def extract_chunks(
    episodes: list[dict],
    chunk_size: int = 30,
    chunk_stride: int = 15,
) -> list[dict]:
    """Slice every episode into fixed-size action windows tagged with progress.

    Progress is cleaned per episode before use: NaNs become 0, values are
    clipped to [0, 1], and the series is made non-decreasing with a running
    maximum.

    Args:
        episodes: Episode dicts with ``"episode"``, ``"actions"`` (one row per
            frame) and ``"progress"`` (one value per frame).
        chunk_size: Number of consecutive frames per chunk.
        chunk_stride: Step in frames between consecutive chunk starts.

    Returns:
        One dict per chunk with keys ``"action_mean"`` (per-dimension mean over
        the window), ``"action_flat"`` (flattened window), ``"progress_start"``,
        ``"progress_delta"`` and ``"episode"``. Episodes shorter than
        ``chunk_size`` contribute no chunks.
    """
    chunks = []
    for ep in episodes:
        actions = ep["actions"]
        n_frames = len(actions)

        prog = np.clip(np.nan_to_num(ep["progress"], nan=0.0), 0.0, 1.0)
        prog = np.maximum.accumulate(prog)

        # The last valid start is n_frames - chunk_size, hence the +1: without
        # it the window ending exactly on the final frame is dropped.
        for t in range(0, n_frames - chunk_size + 1, chunk_stride):
            chunk = actions[t : t + chunk_size]
            p_start = float(prog[t])
            # Progress just after the window, clamped to the last frame for
            # the window that reaches the end of the episode.
            p_end = float(prog[min(t + chunk_size, n_frames - 1)])

            chunks.append(
                {
                    "action_mean": chunk.mean(axis=0).astype(np.float32),
                    "action_flat": chunk.flatten().astype(np.float32),
                    "progress_start": p_start,
                    "progress_delta": p_end - p_start,
                    "episode": ep["episode"],
                }
            )
    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 3: Adaptive progress bands ─────────────────────────────────────
|
||||||
|
|
||||||
|
def fit_adaptive_bands(
    chunks: list[dict], min_per_band: int = 20
) -> list[tuple[float, float]]:
    """Merge ten equal-width progress bins into bands holding enough chunks.

    Starting from the edges 0.0, 0.1, ..., 1.0, each band is widened one fine
    bin at a time until it contains at least ``min_per_band`` chunks (counted
    on the half-open interval ``[lo, hi)``) or reaches 1.0.

    Args:
        chunks: Chunk dicts carrying a ``"progress_start"`` value.
        min_per_band: Target minimum number of chunks per band.

    Returns:
        Contiguous ``(lo, hi)`` band edges covering 0.0 to 1.0. The last band
        may hold fewer than ``min_per_band`` chunks.
    """
    starts = np.array([c["progress_start"] for c in chunks])
    fine_edges = np.linspace(0.0, 1.0, 11)
    last_edge = len(fine_edges) - 1

    def population(lower, upper):
        # Number of chunks whose start progress lies in [lower, upper).
        return np.sum((starts >= lower) & (starts < upper))

    band_edges: list[tuple[float, float]] = []
    left = 0
    while left < last_edge:
        right = left + 1
        # Grow the band until it is populated enough or hits the final edge.
        while right < last_edge and population(fine_edges[left], fine_edges[right]) < min_per_band:
            right += 1
        band_edges.append((fine_edges[left], fine_edges[right]))
        left = right
    return band_edges
|
||||||
|
|
||||||
|
|
||||||
|
def assign_bands(
    chunks: list[dict], band_edges: list[tuple[float, float]]
) -> list[dict]:
    """Tag each chunk in place with the index of its progress band.

    A chunk goes to the first band whose upper edge is strictly greater than
    its ``"progress_start"``; chunks at or beyond every upper edge go to the
    last band.

    Args:
        chunks: Chunk dicts carrying ``"progress_start"``; a ``"band"`` key is
            written to each.
        band_edges: Ordered ``(lo, hi)`` band boundaries.

    Returns:
        The same ``chunks`` list, for chaining.
    """
    fallback = len(band_edges) - 1
    for chunk in chunks:
        start = chunk["progress_start"]
        band = fallback
        for idx, (_lower, upper) in enumerate(band_edges):
            if start < upper:
                band = idx
                break
        chunk["band"] = band
    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]:
    """Group chunks by their ``"band"`` index.

    Every band index in ``range(n_bands)`` is present in the result, even when
    no chunk falls into it.
    """
    grouped: dict[int, list[dict]] = {}
    for band in range(n_bands):
        grouped[band] = []
    for chunk in chunks:
        grouped[chunk["band"]].append(chunk)
    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 4: Intra-band action variance ──────────────────────────────────
|
||||||
|
|
||||||
|
def band_variance_matrix(
    bands: dict[int, list[dict]], n_bands: int, n_joints: int
) -> np.ndarray:
    """Compute per-band, per-joint variance of the chunk mean actions.

    Args:
        bands: Mapping from band index to the chunks in that band; each chunk
            carries an ``"action_mean"`` vector.
        n_bands: Number of rows in the result.
        n_joints: Number of columns in the result.

    Returns:
        Array of shape ``(n_bands, n_joints)``. Rows for bands with fewer than
        three chunks are left as NaN.
    """
    variances = np.full((n_bands, n_joints), np.nan)
    for band_idx in bands:
        members = bands[band_idx]
        if len(members) >= 3:
            stacked = np.stack([m["action_mean"] for m in members])
            variances[band_idx] = stacked.var(axis=0)
    return variances
|
||||||
|
|
||||||
|
|
||||||
|
def plot_variance_heatmap(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Draw both variance heatmaps on a shared colour scale plus a ratio bar chart.

    Args:
        var_full: Band-by-joint variance matrix of the full/mixed dataset.
        var_hq: Band-by-joint variance matrix of the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries used for the tick labels.
        out_path: Destination image path.
    """
    n_bands, n_joints = var_full.shape[0], var_full.shape[1]
    # Both heatmaps share one colour range so they are directly comparable.
    color_lo = 0.0
    color_hi = max(np.nanmax(var_full), np.nanmax(var_hq))

    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    joint_labels = [f"J{j}" for j in range(n_joints)]

    fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98)

    panels = (("A", var_full, "Full/Mixed"), ("B", var_hq, "HQ"))
    for ax, (letter, matrix, label) in zip(axes[:2], panels):
        _style_ax(ax)
        image = ax.imshow(matrix, aspect="auto", cmap="YlOrRd", vmin=color_lo, vmax=color_hi)
        ax.set_yticks(range(n_bands))
        ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
        ax.set_xticks(range(n_joints))
        ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
        ax.set_title(f"Panel {letter}: {label}", color=TEXT, fontsize=11)
        fig.colorbar(image, ax=ax, fraction=0.02, pad=0.02)

    def ratio_color(value):
        # Red above 2x, amber above 1.2x, otherwise the HQ colour.
        if value > 2.0:
            return "#ff4b4b"
        if value > 1.2:
            return "#ffaa33"
        return C_HQ

    # The small epsilon keeps the division finite when the HQ variance is zero.
    ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8)
    ax_bar = axes[2]
    _style_ax(ax_bar)
    ax_bar.bar(range(n_bands), ratio, color=[ratio_color(r) for r in ratio], edgecolor=BORDER)
    ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
    ax_bar.set_xticks(range(n_bands))
    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
    ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
    ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    _save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 5: Progress delta per band ──────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_progress_delta(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Plot per-band mean progress delta (with std) and the overall distributions.

    Args:
        bands_full: Band index to chunks for the full/mixed dataset.
        bands_hq: Band index to chunks for the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries used for the tick labels.
        out_path: Destination image path.
    """
    n_bands = len(band_edges)
    tick_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    centers = np.arange(n_bands)
    bar_w = 0.35

    def summarize(band_map):
        # Per-band mean and std of progress_delta, plus all deltas pooled.
        means, stds, pooled = [], [], []
        for b in range(n_bands):
            deltas = np.array([c["progress_delta"] for c in band_map.get(b, [])])
            if len(deltas) > 0:
                means.append(np.mean(deltas))
                stds.append(np.std(deltas))
            else:
                # Empty bands are drawn as zero-height bars.
                means.append(0)
                stds.append(0)
            pooled.extend(deltas.tolist())
        return means, stds, pooled

    means_full, stds_full, pooled_full = summarize(bands_full)
    means_hq, stds_hq, pooled_hq = summarize(bands_hq)

    fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)

    _style_ax(ax_bar)
    series = (
        (-bar_w / 2, means_full, stds_full, C_FULL, "Full/Mixed"),
        (bar_w / 2, means_hq, stds_hq, C_HQ, "HQ"),
    )
    for offset, means, stds, color, label in series:
        ax_bar.bar(
            centers + offset,
            means,
            bar_w,
            yerr=stds,
            color=color,
            edgecolor=BORDER,
            capsize=3,
            label=label,
            error_kw={"ecolor": SUB},
        )
    ax_bar.set_xticks(centers)
    ax_bar.set_xticklabels(tick_labels, fontsize=7, color=SUB, rotation=30)
    ax_bar.set_ylabel("Mean progress Δ", color=SUB)
    ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)

    _style_ax(ax_viol)
    violin_data = [np.array(pooled_full), np.array(pooled_hq)]
    # The violins are only drawn when both datasets contributed chunks.
    if all(len(d) > 0 for d in violin_data):
        parts = ax_viol.violinplot(violin_data, positions=[0, 1], showmeans=True, showmedians=True)
        for body, color in zip(parts["bodies"], [C_FULL, C_HQ]):
            body.set_facecolor(color)
            body.set_alpha(0.7)
        for name in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
            if name in parts:
                parts[name].set_color(SUB)
    ax_viol.set_xticks([0, 1])
    ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
    ax_viol.set_ylabel("Progress Δ", color=SUB)
    ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)

    fig.tight_layout()
    _save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 6: GMM + BIC per band ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def gmm_optimal_k(
    band_chunks: list[dict],
    pca_components: int = 15,
    max_k: int = 7,
    seed: int = 42,
) -> int | None:
    """Return the GMM component count with the lowest BIC for one band.

    Flattened chunks are standardized, reduced with PCA, and fitted with
    full-covariance Gaussian mixtures for increasing ``k``.

    Args:
        band_chunks: Chunks of one band, each carrying ``"action_flat"``.
        pca_components: Upper bound on the PCA dimensionality.
        max_k: Largest component count considered.
        seed: Random state for PCA and the mixtures.

    Returns:
        The best ``k``, or ``None`` when the band has fewer than 20 chunks or
        no candidate ``k`` is available.
    """
    if len(band_chunks) < 20:
        return None

    features = np.stack([c["action_flat"] for c in band_chunks])
    features = StandardScaler().fit_transform(features)
    # PCA cannot keep more components than features or than samples minus one.
    n_dims = min(pca_components, features.shape[1], features.shape[0] - 1)
    reduced = PCA(n_components=n_dims, random_state=seed).fit_transform(features)

    best_k = None
    best_bic = None
    # The second bound caps the candidate k by the sample count (len // 6).
    for k in range(1, min(max_k + 1, len(reduced) // 6)):
        mixture = GaussianMixture(
            n_components=k,
            covariance_type="full",
            n_init=5,
            max_iter=300,
            random_state=seed,
        )
        mixture.fit(reduced)
        bic = mixture.bic(reduced)
        # Strict comparison keeps the smallest k on ties.
        if best_bic is None or bic < best_bic:
            best_k, best_bic = k, bic
    return best_k
|
||||||
|
|
||||||
|
|
||||||
|
def plot_gmm_bic(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    seed: int,
    out_path: Path,
) -> tuple[list[int | None], list[int | None]]:
    """Compute and plot the BIC-optimal GMM component count per band.

    Args:
        bands_full: Band index to chunks for the full/mixed dataset.
        bands_hq: Band index to chunks for the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries used for the tick labels.
        seed: Random state forwarded to ``gmm_optimal_k``.
        out_path: Destination image path.

    Returns:
        ``(ks_full, ks_hq)``: one entry per band, ``None`` where no K could be
        estimated.
    """
    n_bands = len(band_edges)
    ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
    ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]

    tick_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    # Band index -> K, restricted to bands where a K was found.
    fitted_full = {i: k for i, k in enumerate(ks_full) if k is not None}
    fitted_hq = {i: k for i, k in enumerate(ks_hq) if k is not None}

    for fitted, color, label in (
        (fitted_full, C_FULL, "Full/Mixed"),
        (fitted_hq, C_HQ, "HQ"),
    ):
        if fitted:
            ax.plot(list(fitted), list(fitted.values()), "o-", color=color, label=label, lw=2, markersize=7)

    # Shade the gap between the curves on bands where both have an estimate.
    common = sorted(fitted_full.keys() & fitted_hq.keys())
    if len(common) >= 2:
        ax.fill_between(
            common,
            [fitted_full[i] for i in common],
            [fitted_hq[i] for i in common],
            alpha=0.15,
            color=TEXT,
        )

    ax.set_xticks(np.arange(n_bands))
    ax.set_xticklabels(tick_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
    ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    fig.tight_layout()
    _save(fig, out_path)
    return ks_full, ks_hq
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 7: PCA scatter per band ────────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_pca_scatter(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Scatter the first two principal components of the chunks per band.

    Only the first four bands are plotted, one column each. For each band the
    chunks of both datasets are standardized and projected together, then
    shown in separate rows (full/mixed on top, HQ below).

    Args:
        bands_full: Band index to chunks for the full/mixed dataset.
        bands_hq: Band index to chunks for the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries used for the column titles.
        out_path: Destination image path.
    """
    n_plot = min(4, len(band_edges))
    fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
    fig.patch.set_facecolor(BG)
    fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)

    if n_plot == 1:
        # With a single column the axes come back 1-D; restore (row, col) indexing.
        axes = axes.reshape(2, 1)

    for col in range(n_plot):
        cf = bands_full.get(col, [])
        ch = bands_hq.get(col, [])
        lo, hi = band_edges[col]

        # Fit the joint projection once per band; it is shared by both rows.
        projected = None
        if len(cf) >= 3 and len(ch) >= 3:
            X_all = np.vstack(
                [
                    np.stack([c["action_flat"] for c in cf]),
                    np.stack([c["action_flat"] for c in ch]),
                ]
            )
            X_all = StandardScaler().fit_transform(X_all)
            X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all)
            # Rows were stacked full-first, so split at len(cf).
            projected = (X_2d[: len(cf)], X_2d[len(cf) :])

        for row, (color, label) in enumerate([(C_FULL, "Full/Mixed"), (C_HQ, "HQ")]):
            ax = axes[row, col]
            _style_ax(ax)
            if row == 0:
                ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
            if col == 0:
                ax.set_ylabel(label, color=SUB, fontsize=9)

            if projected is None:
                ax.text(
                    0.5,
                    0.5,
                    "Too few\nchunks",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                    color=SUB,
                    fontsize=9,
                )
                continue

            pts = projected[row]
            ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none")

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    _save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plot 1: Chunk counts per band ───────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_chunk_counts(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    """Draw a grouped bar chart of how many chunks each dataset has per band.

    Args:
        bands_full: Band index to chunks for the full/mixed dataset.
        bands_hq: Band index to chunks for the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries used for the tick labels.
        out_path: Destination image path.
    """
    n_bands = len(band_edges)
    tick_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    centers = np.arange(n_bands)
    bar_w = 0.35

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    series = (
        (-bar_w / 2, bands_full, C_FULL, "Full/Mixed"),
        (bar_w / 2, bands_hq, C_HQ, "HQ"),
    )
    for offset, band_map, color, label in series:
        # Bands missing from the mapping count as zero chunks.
        counts = [len(band_map.get(b, [])) for b in range(n_bands)]
        ax.bar(centers + offset, counts, bar_w, color=color, edgecolor=BORDER, label=label)

    ax.set_xticks(centers)
    ax.set_xticklabels(tick_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Chunk count", color=SUB)
    ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
    fig.tight_layout()
    _save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Summary figure ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_summary(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    ks_full: list[int | None],
    ks_hq: list[int | None],
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    out_path: Path,
) -> None:
    """Render the headline metrics as a table image and log them.

    Args:
        var_full: Band-by-joint variance matrix of the full/mixed dataset.
        var_hq: Band-by-joint variance matrix of the HQ dataset.
        band_edges: ``(lo, hi)`` band boundaries.
        ks_full: Per-band optimal GMM K for the full/mixed dataset (``None``
            where unavailable).
        ks_hq: Per-band optimal GMM K for the HQ dataset.
        bands_full: Band index to chunks for the full/mixed dataset.
        bands_hq: Band index to chunks for the HQ dataset.
        out_path: Destination image path.
    """
    ratio = np.nanmean(var_full, axis=1) / (np.nanmean(var_hq, axis=1) + 1e-8)
    has_value = ~np.isnan(ratio)
    if has_value.any():
        mean_ratio = float(np.mean(ratio[has_value]))
        # nanargmax indexes the unfiltered array, so peak_idx is a real band
        # index even when earlier bands have a NaN ratio.
        peak_idx = int(np.nanargmax(ratio))
        peak_ratio = float(ratio[peak_idx])
        lo, hi = band_edges[peak_idx]
        peak_band = f"{lo:.0%}–{hi:.0%}"
    else:
        # No band has a defined ratio: report that instead of naming band 0.
        mean_ratio = float("nan")
        peak_ratio = float("nan")
        peak_band = "n/a"

    valid_kf = [k for k in ks_full if k is not None]
    valid_kh = [k for k in ks_hq if k is not None]
    mean_k_full = np.mean(valid_kf) if valid_kf else float("nan")
    mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan")

    n_bands = len(band_edges)
    deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])]
    deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])]
    mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan")
    mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan")

    rows = [
        ("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"),
        ("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"),
        ("Mean GMM K — Full", f"{mean_k_full:.1f}"),
        ("Mean GMM K — HQ", f"{mean_k_hq:.1f}"),
        ("Mean progress Δ — Full", f"{mean_delta_full:.4f}"),
        ("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"),
    ]

    fig, ax = plt.subplots(figsize=(8, 3))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(CARD)
    ax.axis("off")

    table = ax.table(
        cellText=[[m, v] for m, v in rows],
        colLabels=["Metric", "Value"],
        loc="center",
        cellLoc="left",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    for key, cell in table.get_celld().items():
        cell.set_edgecolor(BORDER)
        cell.set_facecolor(CARD)
        cell.set_text_props(color=TEXT)
        if key[0] == 0:
            # Row 0 holds the column labels.
            cell.set_text_props(color=TEXT, fontweight="bold")
    table.scale(1, 1.6)
    ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)
    fig.tight_layout()
    _save(fig, out_path)

    for metric, value in rows:
        logger.info("  %s: %s", metric, value)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace) -> None:
    """Run the full analysis pipeline and write all figures to ``args.output_dir``.

    Steps: load both datasets, filter short episodes, extract chunks, fit
    progress bands on the full dataset and apply them to both, then produce
    the count, variance, progress-delta, GMM-BIC, PCA and summary figures.
    """
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Loading FULL dataset: %s", args.full_dataset)
    episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
    logger.info("Loading HQ dataset: %s", args.hq_dataset)
    episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
    logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))

    # Step 1: length threshold + filter (an explicit CLI value wins over auto-detection)
    threshold = (
        args.min_episode_length
        if args.min_episode_length is not None
        else auto_length_threshold(episodes_full, episodes_hq)
    )
    logger.info("Episode length threshold: %d", threshold)

    plot_length_distribution(episodes_full, episodes_hq, threshold, out_dir / "0_length_distribution.png")
    episodes_full = filter_episodes(episodes_full, threshold)
    episodes_hq = filter_episodes(episodes_hq, threshold)

    # Step 2: extract chunks
    chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
    chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
    logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))

    # Step 3: adaptive bands (fit on full, apply to both)
    band_edges = fit_adaptive_bands(chunks_full, args.min_chunks_per_band)
    n_bands = len(band_edges)
    band_names = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    logger.info("Adaptive bands (%d): %s", n_bands, band_names)

    bands_full = split_by_band(assign_bands(chunks_full, band_edges), n_bands)
    bands_hq = split_by_band(assign_bands(chunks_hq, band_edges), n_bands)

    # Plot 1: chunk counts
    plot_chunk_counts(bands_full, bands_hq, band_edges, out_dir / "1_chunk_counts_per_band.png")

    # Step 4: variance heatmap
    var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
    var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
    plot_variance_heatmap(var_full, var_hq, band_edges, out_dir / "2_variance_heatmap.png")

    # Step 5: progress delta
    plot_progress_delta(bands_full, bands_hq, band_edges, out_dir / "3_progress_delta_per_band.png")

    # Step 6: GMM BIC
    ks_full, ks_hq = plot_gmm_bic(
        bands_full, bands_hq, band_edges, args.seed, out_dir / "4_gmm_bic_per_band.png"
    )

    # Step 7: PCA scatter
    plot_pca_scatter(bands_full, bands_hq, band_edges, out_dir / "5_pca_per_band.png")

    # Summary
    plot_summary(
        var_full,
        var_hq,
        band_edges,
        ks_full,
        ks_hq,
        bands_full,
        bands_hq,
        out_dir / "6_summary.png",
    )

    logger.info("All figures saved to %s", out_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script can
    # be run without arguments.
    parser = argparse.ArgumentParser(
        description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # String-valued options.
    for flag, default in (
        ("--full-dataset", "lerobot-data-collection/level12_rac_2_2026-02-08_1"),
        ("--hq-dataset", "lerobot-data-collection/level2_final_quality3"),
        ("--output-dir", "./chunk_analysis"),
    ):
        parser.add_argument(flag, default=default)
    # Integer-valued options.
    for flag, default in (
        ("--chunk-size", 30),
        ("--chunk-stride", 15),
        ("--min-chunks-per-band", 20),
        ("--max-episodes", 500),
        ("--n-joints", 16),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument(
        "--min-episode-length",
        type=int,
        default=None,
        help="Override auto-detected length filter threshold",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    main(args)
|
||||||
@@ -0,0 +1,471 @@
|
|||||||
|
"""
|
||||||
|
Create MP4 videos with sarm_progress overlay for specified episodes.
|
||||||
|
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||||
|
and draws the progress line directly on each frame (no panel, no axes).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────── Config ───────────────────────
|
||||||
|
# Episodes to render: one entry per (dataset repo, episode index) pair.
DATASETS = [
    {"repo_id": "lerobot-data-collection/level1_rac3_rtc_s6_1", "episode": 0},
]
# NOTE(review): machine-specific absolute path; consider making this
# configurable before running the script on another machine.
OUTPUT_DIR = Path("/Users/pepijnkooijmans/Documents/GitHub_local/progress_videos")
# parents=True so a missing intermediate directory does not abort at import time.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Progress line spans the full video height
GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6  # white edge thickness
REF_ALPHA = 0.45  # opacity of the 1.0 reference line
FILL_ALPHA = 0.55  # opacity of the grey fill under the line
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────── Helpers ──────────────────────
|
||||||
|
|
||||||
|
def download_episode(repo_id: str, episode: int) -> Path:
    """Download the dataset metadata and progress table and return the local path.

    Only ``meta/**`` and ``sarm_progress.parquet`` are fetched; video files
    are excluded.

    Args:
        repo_id: HuggingFace dataset repository to download from.
        episode: Episode index. It is not used to narrow this download; the
            parameter is kept so existing callers keep working.

    Returns:
        Path to the local snapshot directory.
    """
    # We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
    # Meta + sarm are downloaded here; the chunks are worked out from them later.
    print(f"\n[1/5] Downloading metadata for {repo_id} …")
    local = Path(
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            allow_patterns=["meta/**", "sarm_progress.parquet"],
            ignore_patterns=["*.mp4"],
        )
    )
    return local
|
||||||
|
|
||||||
|
|
||||||
|
def load_episode_meta(local: Path, episode: int) -> dict:
    """Collect fps, video location, time span and task name for one episode.

    Reads ``meta/info.json`` and every parquet file below ``meta/episodes``
    inside ``local``. The first feature whose dtype is ``"video"`` is used as
    the camera.

    Args:
        local: Root directory of the downloaded dataset snapshot.
        episode: Value to look up in the ``episode_index`` column.

    Returns:
        A dict with the keys ``fps``, ``first_cam``, ``video_rel``,
        ``chunk_index``, ``file_index``, ``from_ts``, ``to_ts`` and
        ``task_name`` (empty string when no task name could be resolved).

    Raises:
        RuntimeError: If the dataset has no video feature, no episode
            metadata files exist, the episode is missing, or the per-camera
            video columns cannot be found.
    """
    info = json.loads((local / "meta" / "info.json").read_text())
    fps = info["fps"]
    features = info["features"]

    # Find video keys (keys whose dtype=="video")
    video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
    if not video_keys:
        raise RuntimeError("No video keys found in dataset features")
    first_cam = video_keys[0]
    print(f" fps={fps} first_camera='{first_cam}' all_cams={video_keys}")

    # Load all episode-meta parquet files and find our episode.
    episodes_dir = local / "meta" / "episodes"
    episode_files = sorted(episodes_dir.glob("**/*.parquet"))
    if not episode_files:
        # pd.concat([]) would otherwise fail with an unhelpful ValueError.
        raise RuntimeError(f"No episode metadata parquet files found under {episodes_dir}")
    ep_df = pd.concat([pd.read_parquet(pq) for pq in episode_files], ignore_index=True)
    row = ep_df[ep_df["episode_index"] == episode]
    if row.empty:
        raise RuntimeError(f"Episode {episode} not found in episode metadata")
    row = row.iloc[0]

    # Column names are either 'videos/<cam>/<field>' or '<cam>/<field>'.
    for prefix in (f"videos/{first_cam}", first_cam):
        if f"{prefix}/chunk_index" in row.index:
            break
    else:
        raise RuntimeError(f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}")

    chunk_idx = int(row[f"{prefix}/chunk_index"])
    file_idx = int(row[f"{prefix}/file_index"])
    from_ts = float(row[f"{prefix}/from_timestamp"])
    to_ts = float(row[f"{prefix}/to_timestamp"])

    video_template = info.get("video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4")
    video_rel = video_template.format(
        video_key=first_cam,
        chunk_index=chunk_idx,
        file_index=file_idx,
    )

    # Task name is best-effort: a failure here only costs the on-frame label.
    task_name = ""
    try:
        # Prefer the 'tasks' list directly on the episode row
        if "tasks" in row.index and row["tasks"] is not None:
            tasks_val = row["tasks"]
            if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
                task_name = str(tasks_val[0])
            else:
                task_name = str(tasks_val).strip("[]'")
        else:
            tasks_pq = local / "meta" / "tasks.parquet"
            if tasks_pq.exists():
                tasks_df = pd.read_parquet(tasks_pq)
                # Row index is the task string; task_index column is the int
                task_idx = int(row["task_index"]) if "task_index" in row.index else 0
                match = tasks_df[tasks_df["task_index"] == task_idx]
                if not match.empty:
                    task_name = str(match.index[0])
        print(f" Task name: '{task_name}'")
    except Exception as e:
        print(f" WARNING: could not load task name: {e}")

    return {
        "fps": fps,
        "first_cam": first_cam,
        "video_rel": video_rel,
        "chunk_index": chunk_idx,
        "file_index": file_idx,
        "from_ts": from_ts,
        "to_ts": to_ts,
        "task_name": task_name,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
    """Return the local path of ``video_rel``, fetching it when it is missing.

    Args:
        repo_id: Hub dataset repository id to download from.
        local: Directory the relative video path is resolved against; also
            passed to ``snapshot_download`` as ``local_dir``.
        video_rel: Path of the video file relative to ``local``.

    Returns:
        ``local / video_rel``.

    Raises:
        RuntimeError: If the file still does not exist after the download.
    """
    target = local / video_rel
    if not target.exists():
        print(f"[2/5] Downloading video file {video_rel} …")
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=str(local),
            allow_patterns=[video_rel],
        )
        if target.exists():
            return target
        raise RuntimeError(f"Video not found after download: {target}")

    print(f" Video already cached: {target}")
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||||
|
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||||
|
pq_path = local / "sarm_progress.parquet"
|
||||||
|
if not pq_path.exists():
|
||||||
|
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
|
||||||
|
return None
|
||||||
|
df = pd.read_parquet(pq_path)
|
||||||
|
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||||
|
ep_df = df[df["episode_index"] == episode].copy()
|
||||||
|
if ep_df.empty:
|
||||||
|
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||||
|
return None
|
||||||
|
ep_df = ep_df.sort_values("frame_index")
|
||||||
|
|
||||||
|
# Prefer dense, fall back to sparse
|
||||||
|
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||||
|
prog_col = "progress_dense"
|
||||||
|
elif "progress_sparse" in ep_df.columns:
|
||||||
|
prog_col = "progress_sparse"
|
||||||
|
else:
|
||||||
|
# Last resort: any column with 'progress' in the name
|
||||||
|
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||||
|
if not prog_cols:
|
||||||
|
return None
|
||||||
|
prog_col = prog_cols[0]
|
||||||
|
|
||||||
|
print(f" Using progress column: '{prog_col}'")
|
||||||
|
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||||
|
|
||||||
|
|
||||||
|
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
    """Cut ``[from_ts, to_ts]`` out of ``video_path`` with ffmpeg.

    The segment is re-encoded with libx264 (preset ``fast``, CRF 18) and audio
    is dropped. An existing ``out_path`` is overwritten.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status; the message
            carries ffmpeg's stderr.
    """
    span = to_ts - from_ts
    print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({span:.2f}s) …")

    ffmpeg_args = ["ffmpeg", "-y"]
    ffmpeg_args += ["-ss", str(from_ts), "-i", str(video_path), "-t", str(span)]
    ffmpeg_args += ["-c:v", "libx264", "-preset", "fast", "-crf", "18", "-an"]
    ffmpeg_args.append(str(out_path))

    proc = subprocess.run(ffmpeg_args, capture_output=True, text=True)
    if proc.returncode == 0:
        return out_path
    raise RuntimeError(f"ffmpeg clip extraction failed:\n{proc.stderr}")
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_pixels(
    progress_data: np.ndarray,
    n_frames: int,
    frame_w: int,
    frame_h: int,
) -> np.ndarray:
    """Map each progress sample to integer pixel coordinates.

    Args:
        progress_data: Array whose column 0 is the frame index and column 1
            the progress value; progress is clipped to [0, 1].
        n_frames: Number of frames in the clip; frame ``n_frames - 1`` maps to
            the right edge of the frame.
        frame_w: Frame width in pixels.
        frame_h: Frame height in pixels.

    Returns:
        Array of shape (N, 2) holding ``(x, y)`` per sample. ``x`` spans the
        full width; ``y`` maps progress 1 to the top of the graph band and
        progress 0 to its bottom.
    """
    frame_indices = progress_data[:, 0].astype(float)
    progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)

    y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
    y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
    graph_h = y_bot - y_top

    # Guard the denominator: a one-frame clip (or an unreported frame count
    # of 0) would otherwise divide by zero / a negative number.
    last_frame = max(n_frames - 1, 1)
    xs = (frame_indices / last_frame * (frame_w - 1)).astype(int)
    # progress=1 → y_top, progress=0 → y_bot
    ys = (y_bot - progress_vals * graph_h).astype(int)

    return np.stack([xs, ys], axis=1)  # (N, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def progress_color(t: float) -> tuple[int, int, int]:
    """Interpolate a BGR colour from red (t=0) to green (t=1).

    ``t`` is clamped to [0, 1] so the returned channels always stay within
    0..255.
    """
    t = min(max(t, 0.0), 1.0)
    r = int(255 * (1.0 - t))
    g = int(255 * t)
    return (0, g, r)  # BGR
|
||||||
|
|
||||||
|
|
||||||
|
def prerender_fill(
    pixels: np.ndarray,
    frame_w: int,
    frame_h: int,
) -> np.ndarray:
    """Render the grey area under the whole curve into a BGRA image.

    The curve points are closed into a polygon by dropping from the last and
    first point down to the bottom of the graph band. The returned image has
    shape ``(frame_h, frame_w, 4)`` and is zero outside the polygon.
    """
    baseline_y = int(frame_h * GRAPH_Y_BOT_FRAC)
    first_x = pixels[0][0]
    last_x = pixels[-1][0]
    closing_edge = [[last_x, baseline_y], [first_x, baseline_y]]
    outline = np.concatenate([pixels, closing_edge], axis=0).astype(np.int32)

    canvas = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
    grey_with_alpha = (128, 128, 128, int(255 * FILL_ALPHA))
    cv2.fillPoly(canvas, [outline], color=grey_with_alpha)
    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
    """Alpha-blend the columns left of ``x_max`` of the overlay into ``base``.

    ``base`` is modified in place; nothing happens when ``x_max`` is zero or
    negative. The overlay's fourth channel is used as the per-pixel opacity.
    """
    if x_max > 0:
        target = base[:, :x_max]
        patch = overlay_bgra[:, :x_max]
        weight = patch[:, :, 3:4].astype(np.float32) / 255.0
        foreground = patch[:, :, :3].astype(np.float32) * weight
        background = target.astype(np.float32) * (1.0 - weight)
        target[:] = np.clip(foreground + background, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_text_outlined(
    frame: np.ndarray,
    text: str,
    pos: tuple[int, int],
    font_scale: float,
    thickness: int = 1,
) -> None:
    """Write white text with a black outline onto ``frame`` in place.

    The outline is the same text drawn first in black, two units thicker than
    the white pass on top of it.
    """
    passes = (
        ((0, 0, 0), thickness + 2),  # dark outline underneath
        ((255, 255, 255), thickness),  # white glyphs on top
    )
    for colour, stroke in passes:
        cv2.putText(frame, text, pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, colour, stroke, cv2.LINE_AA)
|
||||||
|
|
||||||
|
|
||||||
|
def composite_video(
    clip_path: Path,
    progress_data: np.ndarray,
    out_path: Path,
    fps: float,
    frame_h: int,
    frame_w: int,
    task_name: str = "",
) -> Path:
    """Overlay the progress curve on every clip frame and export a GIF.

    Each frame gets, in order: the 1.0 reference line, the grey fill under the
    curve up to the current sample, the curve itself (white edge plus a colour
    moving from red to green over time), the current score at the bottom right
    and the task name at the top centre. The frames are written to a temporary
    mp4 which ffmpeg then converts to a GIF with a generated palette.

    Args:
        clip_path: Episode clip to read frames from.
        progress_data: Array whose column 0 is the frame index (ascending) and
            column 1 the progress value.
        out_path: Output path; its directory and stem name the temporary mp4,
            and the GIF is written to ``out_path`` with a ``.gif`` suffix.
        fps: Frame rate of the temporary mp4.
        frame_h: Frame height in pixels.
        frame_w: Frame width in pixels.
        task_name: Optional label drawn at the top of each frame.

    Returns:
        Path of the GIF. ffmpeg failures are reported as warnings only, so the
        file may be missing if the conversion failed.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
    gif_path = out_path.with_suffix(".gif")
    palette = out_path.parent / "_palette.png"

    # Sample-derived arrays do not depend on the clip, prepare them up front.
    frame_indices = progress_data[:, 0].astype(int)
    progress_vals = progress_data[:, 1].astype(float)
    t_denom = max(len(progress_vals) - 1, 1)

    # The task label never changes, so measure it once instead of per frame.
    task_x = 0
    if task_name:
        (task_w, _), _ = cv2.getTextSize(task_name, font, TASK_FONT_SCALE, 1)
        task_x = max((frame_w - task_w) // 2, 4)

    try:
        # A single capture serves both the frame count and the frame loop, and
        # is always released (the original leaked the counting capture).
        cap = cv2.VideoCapture(str(clip_path))
        writer = None
        try:
            n_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)

            # Pre-render fill polygon (line is drawn per-frame with live color)
            fill_img = prerender_fill(pixels, frame_w, frame_h)

            # 1.0 reference line overlay (full width, drawn once)
            y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
            ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
            cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref),
                     (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)

            print(f"[4/4] Compositing {n_total} frames …")
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))

            fi = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Number of samples whose frame index is <= the current frame.
                n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
                x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0

                # 1. reference line (full width, always)
                alpha_composite(frame, ref_img, frame_w)

                # 2. grey fill under curve up to current x
                alpha_composite(frame, fill_img, x_cur)

                if n_drawn > 0:
                    live_col = progress_color((n_drawn - 1) / t_denom)

                    # 3. progress line — single color that transitions red→green over time
                    if n_drawn >= 2:
                        pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
                        cv2.polylines(frame, [pts], isClosed=False,
                                      color=(255, 255, 255), thickness=SHADOW_THICKNESS,
                                      lineType=cv2.LINE_AA)
                        cv2.polylines(frame, [pts], isClosed=False,
                                      color=live_col, thickness=LINE_THICKNESS,
                                      lineType=cv2.LINE_AA)

                    # 4. score — bottom right, coloured like the line
                    score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
                    score_text = f"{score:.2f}"
                    (tw, _), _ = cv2.getTextSize(score_text, font, SCORE_FONT_SCALE, 2)
                    sx = frame_w - tw - 12
                    sy = frame_h - 12
                    cv2.putText(frame, score_text, (sx, sy), font,
                                SCORE_FONT_SCALE, (0, 0, 0), 4, cv2.LINE_AA)
                    cv2.putText(frame, score_text, (sx, sy), font,
                                SCORE_FONT_SCALE, live_col, 2, cv2.LINE_AA)

                # 5. task name — top centre
                if task_name:
                    draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE)

                writer.write(frame)
                fi += 1
                if fi % 100 == 0:
                    print(f" Frame {fi}/{n_total} …", end="\r")
        finally:
            cap.release()
            if writer is not None:
                writer.release()
        print()

        # Convert to GIF: full resolution, 10 fps, 128-color diff palette
        r1 = subprocess.run([
            "ffmpeg", "-y", "-i", str(tmp_path),
            "-vf", f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
            "-update", "1",
            str(palette),
        ], capture_output=True, text=True)
        if r1.returncode != 0:
            print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
        r2 = subprocess.run([
            "ffmpeg", "-y",
            "-i", str(tmp_path), "-i", str(palette),
            "-filter_complex",
            f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
            str(gif_path),
        ], capture_output=True, text=True)
        if r2.returncode != 0:
            print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
    finally:
        # Intermediate files are removed even when compositing or ffmpeg raises.
        tmp_path.unlink(missing_ok=True)
        palette.unlink(missing_ok=True)
    return gif_path
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────── Main ──────────────────────────
|
||||||
|
|
||||||
|
def process_dataset(repo_id: str, episode: int):
    """Render the progress-overlay GIF for one episode of one dataset.

    Steps: download metadata, read the episode metadata, load the progress
    curve, download the video file, cut the episode clip, and composite the
    overlay. The intermediate clip is always removed.

    Args:
        repo_id: Hub dataset repository id.
        episode: Episode index inside the dataset.

    Returns:
        Path of the rendered GIF, or ``None`` when no progress data is
        available for the episode.

    Raises:
        RuntimeError: Propagated from the helper steps, or raised here when
            the extracted clip cannot be opened.
    """
    safe_name = repo_id.replace("/", "_")
    print(f"\n{'='*60}")
    print(f"Processing: {repo_id} | episode {episode}")
    print(f"{'='*60}")

    # 1. Download metadata
    local = download_episode(repo_id, episode)
    print(f" Local cache: {local}")

    # 2. Read episode metadata
    ep_meta = load_episode_meta(local, episode)
    print(f" Episode meta: {ep_meta}")

    # 3. Load progress data first: without it there is nothing to draw, so the
    #    video download and clip extraction would be wasted work.
    progress_data = load_progress(local, episode)
    if progress_data is None:
        print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
        return None
    print(f" Progress frames: {len(progress_data)}")

    # 4. Download video file
    video_path = download_video(repo_id, local, ep_meta["video_rel"])

    clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
    try:
        # 5. Extract clip
        extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)

        # 6. Get clip dimensions
        cap = cv2.VideoCapture(str(clip_path))
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Could not open extracted clip: {clip_path}")
            frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Fall back to the dataset fps when the container reports 0.
            actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
        finally:
            cap.release()
        print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")

        # 7. Composite (draw line directly on frames)
        out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
        final = composite_video(clip_path, progress_data, out_path, actual_fps, frame_h, frame_w,
                                task_name=ep_meta.get("task_name", ""))
    finally:
        # The clip is only an intermediate; drop it on success and on failure.
        clip_path.unlink(missing_ok=True)

    print(f"\n✓ Done: {final}")
    return final
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import traceback

    # Process every configured episode; one failure must not stop the rest.
    results = []
    for cfg in DATASETS:
        try:
            out = process_dataset(cfg["repo_id"], cfg["episode"])
        except Exception as e:
            print(f"\nERROR processing {cfg['repo_id']}: {e}")
            traceback.print_exc()
        else:
            if out:
                results.append(out)

    # Summary of everything that was rendered successfully.
    print("\n" + "="*60)
    print("Output files:")
    for r in results:
        print(f" {r}")
|
||||||
Reference in New Issue
Block a user