mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 22:59:50 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 026e4c937d | |||
| efe8c09fca | |||
| 58eecad8a4 | |||
| c7fd1f47d1 | |||
| 6370949e5c | |||
| 46b97da168 | |||
| e69be57a66 | |||
| c3a6ddb668 | |||
| dad661012d | |||
| 219c08ccb8 | |||
| 06385902df |
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Action consistency analysis for imitation learning datasets.
|
||||
|
||||
Two parallel analyses per dataset:
|
||||
1. State-based: KNN in joint-state space → action chunk variance
|
||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||
|
||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||
agree on where the data is inconsistent — and images are what the policy
|
||||
primarily sees.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from PIL import Image
|
||||
from scipy.spatial import cKDTree
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
MAX_FRAMES = 100_000
|
||||
K_NEIGHBORS = 50
|
||||
ACTION_CHUNK_SIZE = 30
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||
ENCODE_BATCH_SIZE = 512
|
||||
SEED = 42
|
||||
DPI = 150
|
||||
|
||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
||||
)
|
||||
|
||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=[
|
||||
"meta/**",
|
||||
"data/**",
|
||||
f"videos/{camera_key}/**",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_action_chunks(
|
||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
For each frame, concatenate the next chunk_size actions from the same episode.
|
||||
Returns (action_chunks, valid_mask).
|
||||
"""
|
||||
n = len(actions)
|
||||
act_dim = actions.shape[1]
|
||||
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
|
||||
valid = np.zeros(n, dtype=bool)
|
||||
|
||||
for i in range(n):
|
||||
end = i + chunk_size
|
||||
if end > n:
|
||||
continue
|
||||
if episode_ids[i] != episode_ids[end - 1]:
|
||||
continue
|
||||
chunks[i] = actions[i:end].ravel()
|
||||
valid[i] = True
|
||||
|
||||
return chunks, valid
|
||||
|
||||
|
||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||
"""
|
||||
Load observation.state and action, build action chunks, subsample, normalize.
|
||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
n_total = len(df)
|
||||
print(f" Total frames: {n_total:,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
action_col = next((c for c in df.columns if c == "action"), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
if action_col is None:
|
||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
||||
episode_all = df[ep_col].values.astype(np.int64)
|
||||
|
||||
n_dim = state_all.shape[1]
|
||||
act_dim = action_all.shape[1]
|
||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
print(" Building action chunks …")
|
||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||
valid_idx = np.where(valid)[0]
|
||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||
|
||||
if len(valid_idx) > max_frames:
|
||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||
else:
|
||||
chosen = valid_idx
|
||||
print(f" Using {len(chosen):,} frames")
|
||||
|
||||
state_raw = state_all[chosen]
|
||||
action_raw = action_chunks[chosen]
|
||||
episode_ids = episode_all[chosen]
|
||||
|
||||
state_mean = state_raw.mean(axis=0)
|
||||
state_std = state_raw.std(axis=0)
|
||||
state_std[state_std < 1e-8] = 1.0
|
||||
state_norm = (state_raw - state_mean) / state_std
|
||||
|
||||
action_mean = action_raw.mean(axis=0)
|
||||
action_std = action_raw.std(axis=0)
|
||||
action_std[action_std < 1e-8] = 1.0
|
||||
action_norm = (action_raw - action_mean) / action_std
|
||||
|
||||
return {
|
||||
"state_raw": state_raw,
|
||||
"state_norm": state_norm,
|
||||
"action_raw": action_raw,
|
||||
"action_norm": action_norm,
|
||||
"episode_ids": episode_ids,
|
||||
"episode_all": episode_all,
|
||||
"left_joint_idx": left_idx,
|
||||
"right_joint_idx": right_idx,
|
||||
"n_total": n_total,
|
||||
"chosen_idx": chosen,
|
||||
"df": df,
|
||||
}
|
||||
|
||||
|
||||
# ── Video → frame extraction ──────────────────────────────
|
||||
|
||||
|
||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||
"""
|
||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||
file_col = f"videos/{camera_key}/file_index"
|
||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{camera_key}/chunk_index"
|
||||
file_col = f"{camera_key}/file_index"
|
||||
ts_from = f"{camera_key}/from_timestamp"
|
||||
|
||||
lookup: dict[int, dict] = {}
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||
lookup[int(row["episode_index"])] = {
|
||||
"video_path": local / video_rel,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"fps": fps,
|
||||
}
|
||||
return lookup
|
||||
|
||||
|
||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
||||
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
||||
container = av.open(video_path)
|
||||
stream = container.streams.video[0]
|
||||
stream.thread_type = "AUTO"
|
||||
decoded = []
|
||||
for frame in container.decode(stream):
|
||||
decoded.append(frame.to_ndarray(format="rgb24"))
|
||||
container.close()
|
||||
return decoded
|
||||
|
||||
|
||||
def extract_frames(
|
||||
chosen_idx: np.ndarray,
|
||||
episode_all: np.ndarray,
|
||||
video_lookup: dict,
|
||||
) -> list[np.ndarray | None]:
|
||||
"""
|
||||
Extract RGB frames for each chosen global index using PyAV.
|
||||
Returns list of (H, W, 3) RGB arrays (or None on failure).
|
||||
"""
|
||||
unique_eps = np.unique(episode_all)
|
||||
ep_start: dict[int, int] = {}
|
||||
for ep in unique_eps:
|
||||
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
||||
|
||||
# Build jobs: (output_index, video_path, local_frame_number)
|
||||
jobs: list[tuple[int, str, int]] = []
|
||||
for out_i, global_i in enumerate(chosen_idx):
|
||||
ep = int(episode_all[global_i])
|
||||
info = video_lookup.get(ep)
|
||||
if info is None:
|
||||
continue
|
||||
local_frame = global_i - ep_start[ep]
|
||||
jobs.append((out_i, str(info["video_path"]), local_frame))
|
||||
|
||||
# Group by video file, decode each video once
|
||||
from collections import defaultdict
|
||||
|
||||
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
|
||||
for out_i, vpath, local_frame in jobs:
|
||||
video_jobs[vpath].append((out_i, local_frame))
|
||||
|
||||
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
|
||||
extracted = 0
|
||||
n_videos = len(video_jobs)
|
||||
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
|
||||
if not Path(vpath).exists():
|
||||
continue
|
||||
try:
|
||||
decoded = _decode_video_frames(vpath)
|
||||
except Exception as exc:
|
||||
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
|
||||
continue
|
||||
for out_i, local_frame in frame_requests:
|
||||
if 0 <= local_frame < len(decoded):
|
||||
frames[out_i] = decoded[local_frame]
|
||||
extracted += 1
|
||||
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
|
||||
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
|
||||
del decoded
|
||||
|
||||
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
||||
return frames
|
||||
|
||||
|
||||
# ── SigLIP encoding ─────────────────────────────────────
|
||||
|
||||
|
||||
def encode_frames_siglip(
|
||||
frames: list[np.ndarray | None],
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode RGB frames through SigLIP vision encoder.
|
||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||
"""
|
||||
print(f" Loading SigLIP model: {model_name} …")
|
||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
embed_dim = model.config.vision_config.hidden_size
|
||||
|
||||
n = len(frames)
|
||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||
|
||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||
|
||||
for batch_start in range(0, len(valid_indices), batch_size):
|
||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||
|
||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
image_features = model.get_image_features(**inputs)
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||
|
||||
done = min(batch_start + batch_size, len(valid_indices))
|
||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||
|
||||
del model, processor
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
|
||||
# ── KNN consistency ─────────────────────────────────────
|
||||
|
||||
|
||||
def compute_consistency(
|
||||
features: np.ndarray,
|
||||
action_norm: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
k: int,
|
||||
label: str = "",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
For each frame, find K nearest neighbors in feature space from other episodes.
|
||||
Return per-frame action variance (mean across action dims).
|
||||
"""
|
||||
n = len(features)
|
||||
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
||||
tree = cKDTree(features)
|
||||
|
||||
k_query = min(k * 3, n - 1)
|
||||
print(f" Querying {k_query} neighbors per frame …")
|
||||
_dists, indices = tree.query(features, k=k_query + 1)
|
||||
indices = indices[:, 1:]
|
||||
|
||||
print(f" Computing cross-episode action variance ({label}) …")
|
||||
variance = np.zeros(n)
|
||||
for i in range(n):
|
||||
ep_i = episode_ids[i]
|
||||
neighbors = indices[i]
|
||||
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
|
||||
if len(cross_ep) < 2:
|
||||
variance[i] = 0.0
|
||||
continue
|
||||
neighbor_actions = action_norm[cross_ep]
|
||||
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor("#0d1117")
|
||||
ax.tick_params(colors="#555", labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
|
||||
|
||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||
_style_ax(ax)
|
||||
median_var = np.median(variance)
|
||||
mean_var = np.mean(variance)
|
||||
nonzero = variance[variance > 0]
|
||||
if len(nonzero) > 0:
|
||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||
ax.set_xscale("log")
|
||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_episode_curves(
|
||||
ax: plt.Axes,
|
||||
var_state: np.ndarray,
|
||||
var_image: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
title: str,
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
unique_eps = np.unique(episode_ids)
|
||||
|
||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||
|
||||
sorted_s = np.sort(ep_means_s)[::-1]
|
||||
sorted_i = np.sort(ep_means_i)[::-1]
|
||||
ep_x = np.arange(len(unique_eps))
|
||||
|
||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||
|
||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_heatmap(
|
||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
order = np.argsort(variance)
|
||||
pts = tcp_xz[order]
|
||||
var_sorted = variance[order]
|
||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||
sc = ax.scatter(
|
||||
pts[:, 0],
|
||||
pts[:, 1],
|
||||
c=var_sorted,
|
||||
cmap=CONSISTENCY_CMAP,
|
||||
s=0.5,
|
||||
alpha=0.6,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True,
|
||||
)
|
||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.set_aspect("equal")
|
||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
4-row x N-column figure:
|
||||
Row 0: State-based variance histogram
|
||||
Row 1: Image-based variance histogram
|
||||
Row 2: Per-episode curves (both overlaid)
|
||||
Row 3: Spatial heatmap (image-based variance)
|
||||
"""
|
||||
n_ds = len(results)
|
||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[:, np.newaxis]
|
||||
|
||||
headline_parts = []
|
||||
for col, r in enumerate(results):
|
||||
label = r["label"]
|
||||
var_s = r["var_state"]
|
||||
var_i = r["var_image"]
|
||||
tcp_xz = r["tcp_xz"]
|
||||
episode_ids = r["episode_ids"]
|
||||
|
||||
med_s = np.median(var_s)
|
||||
med_i = np.median(var_i)
|
||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||
|
||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||
_plot_histogram(
|
||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||
)
|
||||
_plot_episode_curves(
|
||||
axes[2, col],
|
||||
var_s,
|
||||
var_i,
|
||||
episode_ids,
|
||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||
)
|
||||
_plot_heatmap(
|
||||
axes[3, col],
|
||||
fig,
|
||||
tcp_xz,
|
||||
var_i,
|
||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||
+ " | ".join(headline_parts),
|
||||
color="white",
|
||||
fontsize=15,
|
||||
y=0.99,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
rng = np.random.default_rng(SEED)
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id, CAMERA_KEY)
|
||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||
|
||||
# --- State-based KNN ---
|
||||
var_state = compute_consistency(
|
||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||
)
|
||||
print(
|
||||
f" State variance: median={np.median(var_state):.4f} "
|
||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||
)
|
||||
|
||||
# --- Image-based KNN ---
|
||||
print("\n Preparing image embeddings …")
|
||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||
del frames # free memory
|
||||
|
||||
var_image = compute_consistency(
|
||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||
)
|
||||
print(
|
||||
f" Image variance: median={np.median(var_image):.4f} "
|
||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||
)
|
||||
|
||||
# FK for spatial heatmap
|
||||
print(" Computing FK for spatial heatmap …")
|
||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||
left_rad = _detect_and_convert(left_raw)
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
||||
tcp_xz = left_tcp[:, [0, 2]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"var_state": var_state,
|
||||
"var_image": var_image,
|
||||
"episode_ids": data["episode_ids"],
|
||||
"tcp_xz": tcp_xz,
|
||||
"n_total": data["n_total"],
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||
render(results, out)
|
||||
|
||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||
worst_summary = {}
|
||||
for r in results:
|
||||
unique_eps = np.unique(r["episode_ids"])
|
||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||
print(f"✓ Saved worst episodes: {worst_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
||||
decodes them, and tiles into a single image.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
GRID_COLS = 15
|
||||
GRID_ROWS = 10
|
||||
THUMB_WIDTH = 160
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
SEED = 1
|
||||
|
||||
|
||||
def download_metadata(repo_id: str) -> Path:
|
||||
"""Download only metadata (no videos yet)."""
|
||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
cam = CAMERA_KEY
|
||||
else:
|
||||
cam = video_keys[0]
|
||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
chunk_col = f"videos/{cam}/chunk_index"
|
||||
file_col = f"videos/{cam}/file_index"
|
||||
ts_from = f"videos/{cam}/from_timestamp"
|
||||
ts_to = f"videos/{cam}/to_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{cam}/chunk_index"
|
||||
file_col = f"{cam}/file_index"
|
||||
ts_from = f"{cam}/from_timestamp"
|
||||
ts_to = f"{cam}/to_timestamp"
|
||||
|
||||
episodes = []
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
episodes.append(
|
||||
{
|
||||
"episode_index": int(row["episode_index"]),
|
||||
"chunk_index": ci,
|
||||
"file_index": fi,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"to_ts": float(row[ts_to]),
|
||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
||||
}
|
||||
)
|
||||
return cam, episodes, fps
|
||||
|
||||
|
||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
||||
picks = []
|
||||
for _ in range(n):
|
||||
ep = rng.choice(episodes)
|
||||
duration = ep["to_ts"] - ep["from_ts"]
|
||||
if duration <= 0:
|
||||
continue
|
||||
t = ep["from_ts"] + rng.random() * duration
|
||||
picks.append({**ep, "seek_ts": t})
|
||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
||||
return picks
|
||||
|
||||
|
||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
||||
"""Download only the video files we need."""
|
||||
needed = sorted({p["video_rel"] for p in picks})
|
||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=needed,
|
||||
)
|
||||
|
||||
|
||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
||||
"""Decode a single frame at the given timestamp."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
return frame if ret else None
|
||||
|
||||
|
||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
||||
if not frames:
|
||||
raise RuntimeError("No frames decoded")
|
||||
|
||||
h0, w0 = frames[0].shape[:2]
|
||||
thumb_h = int(thumb_w * h0 / w0)
|
||||
|
||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
||||
|
||||
rows = []
|
||||
for i in range(0, len(thumbs), cols):
|
||||
row_thumbs = thumbs[i : i + cols]
|
||||
while len(row_thumbs) < cols:
|
||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
||||
rows.append(np.hstack(row_thumbs))
|
||||
return np.vstack(rows)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rng = random.Random(SEED)
|
||||
n_frames = GRID_COLS * GRID_ROWS
|
||||
|
||||
local = download_metadata(REPO_ID)
|
||||
cam, episodes, fps = load_video_info(local)
|
||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
||||
download_video_files(REPO_ID, local, picks)
|
||||
|
||||
print(f"[3/3] Decoding {n_frames} frames …")
|
||||
frames: list[np.ndarray] = []
|
||||
for p in picks:
|
||||
vp = local / p["video_rel"]
|
||||
if not vp.exists():
|
||||
print(f" SKIP: {p['video_rel']} not found")
|
||||
continue
|
||||
frame = extract_frame(vp, p["seek_ts"])
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
||||
|
||||
safe_name = REPO_ID.replace("/", "_")
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
Create MP4 videos with sarm_progress overlay for specified episodes.
|
||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||
and draws the progress line directly on each frame (no panel, no axes).
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
||||
]
|
||||
CAMERA_KEY = (
|
||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
||||
)
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Progress line spans the full video height
|
||||
GRAPH_Y_TOP_FRAC = 0.01
|
||||
GRAPH_Y_BOT_FRAC = 0.99
|
||||
LINE_THICKNESS = 3
|
||||
SHADOW_THICKNESS = 6 # white edge thickness
|
||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
||||
SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the files needed for this episode."""
|
||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
||||
# We'll download meta + sarm first, then figure out chunks.
|
||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
||||
local = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
return local
|
||||
|
||||
|
||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
# Find video keys (keys whose dtype=="video")
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
first_cam = CAMERA_KEY
|
||||
else:
|
||||
first_cam = video_keys[0]
|
||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
||||
|
||||
# Load all episode-meta parquet files and find our episode
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
df = pd.read_parquet(pq)
|
||||
ep_rows.append(df)
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
row = ep_df[ep_df["episode_index"] == episode]
|
||||
if row.empty:
|
||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
||||
row = row.iloc[0]
|
||||
|
||||
# Extract video chunk/file index for first camera
|
||||
# Try both dot and slash variants of the key
|
||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
||||
file_col = f"videos/{first_cam}/file_index"
|
||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
||||
to_col = f"videos/{first_cam}/to_timestamp"
|
||||
|
||||
# Some datasets use different column naming
|
||||
if chunk_col not in row.index:
|
||||
# Try without the 'videos/' prefix
|
||||
chunk_col = f"{first_cam}/chunk_index"
|
||||
file_col = f"{first_cam}/file_index"
|
||||
ts_col = f"{first_cam}/from_timestamp"
|
||||
to_col = f"{first_cam}/to_timestamp"
|
||||
if chunk_col not in row.index:
|
||||
raise RuntimeError(
|
||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
||||
)
|
||||
|
||||
chunk_idx = int(row[chunk_col])
|
||||
file_idx = int(row[file_col])
|
||||
from_ts = float(row[ts_col])
|
||||
to_ts = float(row[to_col])
|
||||
|
||||
video_template = info.get(
|
||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
||||
)
|
||||
video_rel = video_template.format(
|
||||
video_key=first_cam,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# Load task name for this episode
|
||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
||||
task_name = ""
|
||||
try:
|
||||
# Prefer the 'tasks' list directly on the episode row
|
||||
if "tasks" in row.index and row["tasks"] is not None:
|
||||
tasks_val = row["tasks"]
|
||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
||||
task_name = str(tasks_val[0])
|
||||
else:
|
||||
task_name = str(tasks_val).strip("[]'")
|
||||
else:
|
||||
tasks_pq = local / "meta" / "tasks.parquet"
|
||||
if tasks_pq.exists():
|
||||
tasks_df = pd.read_parquet(tasks_pq)
|
||||
# Row index is the task string; task_index column is the int
|
||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
||||
if not match.empty:
|
||||
task_name = str(match.index[0])
|
||||
print(f" Task name: '{task_name}'")
|
||||
except Exception as e:
|
||||
print(f" WARNING: could not load task name: {e}")
|
||||
|
||||
return {
|
||||
"fps": fps,
|
||||
"first_cam": first_cam,
|
||||
"video_rel": video_rel,
|
||||
"chunk_index": chunk_idx,
|
||||
"file_index": file_idx,
|
||||
"from_ts": from_ts,
|
||||
"to_ts": to_ts,
|
||||
"task_name": task_name,
|
||||
}
|
||||
|
||||
|
||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
||||
"""Download the specific video file if not already present."""
|
||||
video_path = local / video_rel
|
||||
if video_path.exists():
|
||||
print(f" Video already cached: {video_path}")
|
||||
return video_path
|
||||
print(f"[2/5] Downloading video file {video_rel} …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=[video_rel],
|
||||
)
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||
pq_path = local / "sarm_progress.parquet"
|
||||
if not pq_path.exists():
|
||||
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
|
||||
return None
|
||||
df = pd.read_parquet(pq_path)
|
||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||
ep_df = df[df["episode_index"] == episode].copy()
|
||||
if ep_df.empty:
|
||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||
return None
|
||||
ep_df = ep_df.sort_values("frame_index")
|
||||
|
||||
# Prefer dense, fall back to sparse
|
||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||
prog_col = "progress_dense"
|
||||
elif "progress_sparse" in ep_df.columns:
|
||||
prog_col = "progress_sparse"
|
||||
else:
|
||||
# Last resort: any column with 'progress' in the name
|
||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||
if not prog_cols:
|
||||
return None
|
||||
prog_col = prog_cols[0]
|
||||
|
||||
print(f" Using progress column: '{prog_col}'")
|
||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||
|
||||
|
||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
||||
duration = to_ts - from_ts
|
||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-ss",
|
||||
str(from_ts),
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"fast",
|
||||
"-crf",
|
||||
"18",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
||||
return out_path
|
||||
|
||||
|
||||
def precompute_pixels(
|
||||
progress_data: np.ndarray,
|
||||
n_frames: int,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Map each progress sample to pixel coordinates.
|
||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
||||
x spans full video width; y maps progress [0,1] to graph band.
|
||||
"""
|
||||
frame_indices = progress_data[:, 0].astype(float)
|
||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
||||
|
||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
graph_h = y_bot - y_top
|
||||
|
||||
xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int)
|
||||
# progress=1 → y_top, progress=0 → y_bot
|
||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
||||
|
||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
||||
|
||||
|
||||
def progress_color(t: float) -> tuple[int, int, int]:
|
||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
||||
r = int(255 * (1.0 - t))
|
||||
g = int(255 * t)
|
||||
return (0, g, r) # BGR
|
||||
|
||||
|
||||
def prerender_fill(
|
||||
pixels: np.ndarray,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
poly = np.concatenate(
|
||||
[
|
||||
pixels,
|
||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.int32)
|
||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
||||
return fill_img
|
||||
|
||||
|
||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
||||
if x_max <= 0:
|
||||
return
|
||||
roi_b = base[:, :x_max]
|
||||
roi_o = overlay_bgra[:, :x_max]
|
||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
||||
roi_b[:] = np.clip(
|
||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
||||
0,
|
||||
255,
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def draw_text_outlined(
|
||||
frame: np.ndarray,
|
||||
text: str,
|
||||
pos: tuple[int, int],
|
||||
font_scale: float,
|
||||
thickness: int = 1,
|
||||
) -> None:
|
||||
"""Draw text with a dark outline for readability on any background."""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
|
||||
def composite_video(
|
||||
clip_path: Path,
|
||||
progress_data: np.ndarray,
|
||||
out_path: Path,
|
||||
fps: float,
|
||||
frame_h: int,
|
||||
frame_w: int,
|
||||
task_name: str = "",
|
||||
) -> Path:
|
||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
||||
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
||||
|
||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
|
||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
||||
|
||||
# 1.0 reference line overlay (full width, drawn once)
|
||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
||||
|
||||
frame_indices = progress_data[:, 0].astype(int)
|
||||
progress_vals = progress_data[:, 1].astype(float)
|
||||
|
||||
print(f"[4/4] Compositing {n_total} frames …")
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
||||
|
||||
fi = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
||||
|
||||
# 1. reference line (full width, always)
|
||||
alpha_composite(frame, ref_img, frame_w)
|
||||
|
||||
# 2. grey fill under curve up to current x
|
||||
alpha_composite(frame, fill_img, x_cur)
|
||||
|
||||
# 3. progress line — single color that transitions red→green over time
|
||||
if n_drawn >= 2:
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
line_col = progress_color(t_cur)
|
||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
||||
cv2.polylines(
|
||||
frame,
|
||||
[pts],
|
||||
isClosed=False,
|
||||
color=(255, 255, 255),
|
||||
thickness=SHADOW_THICKNESS,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
cv2.polylines(
|
||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 4. score — bottom right
|
||||
if n_drawn > 0:
|
||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
||||
score_text = f"{score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
||||
sx = frame_w - tw - 12
|
||||
sy = frame_h - 12
|
||||
# coloured score matching current gradient position
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
score_col = progress_color(t_cur)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
(0, 0, 0),
|
||||
4,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
score_col,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
# 5. task name — top centre
|
||||
if task_name:
|
||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
||||
tx = max((frame_w - tw) // 2, 4)
|
||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
||||
|
||||
writer.write(frame)
|
||||
fi += 1
|
||||
if fi % 100 == 0:
|
||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
||||
|
||||
cap.release()
|
||||
writer.release()
|
||||
print()
|
||||
|
||||
# Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB)
|
||||
gif_path = out_path.with_suffix(".gif")
|
||||
palette = out_path.parent / "_palette.png"
|
||||
r1 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-vf",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
||||
"-update",
|
||||
"1",
|
||||
str(palette),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r1.returncode != 0:
|
||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
||||
r2 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-i",
|
||||
str(palette),
|
||||
"-filter_complex",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
||||
str(gif_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r2.returncode != 0:
|
||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
palette.unlink(missing_ok=True)
|
||||
return gif_path
|
||||
|
||||
|
||||
def process_dataset(repo_id: str, episode: int):
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Processing: {repo_id} | episode {episode}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Download metadata
|
||||
local = download_episode(repo_id, episode)
|
||||
print(f" Local cache: {local}")
|
||||
|
||||
# 2. Read episode metadata
|
||||
ep_meta = load_episode_meta(local, episode)
|
||||
print(f" Episode meta: {ep_meta}")
|
||||
|
||||
# 3. Download video file
|
||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
||||
|
||||
# 4. Extract clip
|
||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
||||
|
||||
# 5. Load progress data
|
||||
progress_data = load_progress(local, episode)
|
||||
if progress_data is None:
|
||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
||||
return
|
||||
|
||||
n_progress = len(progress_data)
|
||||
print(f" Progress frames: {n_progress}")
|
||||
|
||||
# 6. Get clip dimensions
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
||||
cap.release()
|
||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
||||
|
||||
# 7. Composite (draw line directly on frames)
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
||||
final = composite_video(
|
||||
clip_path,
|
||||
progress_data,
|
||||
out_path,
|
||||
actual_fps,
|
||||
frame_h,
|
||||
frame_w,
|
||||
task_name=ep_meta.get("task_name", ""),
|
||||
)
|
||||
clip_path.unlink(missing_ok=True)
|
||||
print(f"\n✓ Done: {final}")
|
||||
return final
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
for cfg in DATASETS:
|
||||
try:
|
||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
||||
if out:
|
||||
results.append(out)
|
||||
except Exception as e:
|
||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Output files:")
|
||||
for r in results:
|
||||
print(f" {r}")
|
||||
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
||||
kinematics per episode, clusters trajectories with K-means, and renders
|
||||
2D projections comparing dataset coverage and multimodality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
N_CLUSTERS = 10
|
||||
WAYPOINTS = 50
|
||||
SEED = 42
|
||||
DPI = 180
|
||||
|
||||
CLUSTER_COLORS = [
|
||||
"#e6194b",
|
||||
"#3cb44b",
|
||||
"#4363d8",
|
||||
"#f58231",
|
||||
"#911eb4",
|
||||
"#42d4f4",
|
||||
"#f032e6",
|
||||
"#bfef45",
|
||||
"#fabed4",
|
||||
"#dcbeff",
|
||||
"#9a6324",
|
||||
"#fffac8",
|
||||
"#800000",
|
||||
"#aaffc3",
|
||||
"#808000",
|
||||
"#ffd8b1",
|
||||
"#000075",
|
||||
"#a9a9a9",
|
||||
]
|
||||
|
||||
# FK chains extracted from OpenArm bimanual URDF.
|
||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
"""Try to find left/right joint indices from info.json feature names."""
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet only) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "data/**"],
|
||||
ignore_patterns=["*.mp4", "videos/**"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
|
||||
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
|
||||
f = traj.shape[0]
|
||||
if f == n_waypoints:
|
||||
return traj
|
||||
old_t = np.linspace(0, 1, f)
|
||||
new_t = np.linspace(0, 1, n_waypoints)
|
||||
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
||||
|
||||
|
||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
||||
"""
|
||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
||||
Uses all episodes in the dataset for a fair comparison.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
print(f" Total frames: {len(df):,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
|
||||
first = df[state_col].iloc[0]
|
||||
if not hasattr(first, "__len__"):
|
||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
||||
|
||||
state = np.stack(df[state_col].values).astype(np.float64)
|
||||
n_dim = state.shape[1]
|
||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
episode_ids = df[ep_col].values
|
||||
unique_eps = np.unique(episode_ids)
|
||||
print(f" Episodes: {len(unique_eps):,}")
|
||||
|
||||
left_raw = state[:, left_idx]
|
||||
right_raw = state[:, right_idx]
|
||||
left_all = _detect_and_convert(left_raw)
|
||||
right_all = _detect_and_convert(right_raw)
|
||||
|
||||
print(" Computing FK per episode …")
|
||||
trajectories = []
|
||||
for ep_id in unique_eps:
|
||||
mask = episode_ids == ep_id
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
||||
if len(left_tcp) < 3:
|
||||
continue
|
||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
||||
|
||||
print(f" Valid trajectories: {len(trajectories):,}")
|
||||
return trajectories
|
||||
|
||||
|
||||
# ── Clustering ──────────────────────────────────────────
|
||||
|
||||
|
||||
def cluster_trajectories(
|
||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
K-means on resampled trajectory features.
|
||||
Combines left+right TCP into a single feature vector per episode.
|
||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
||||
"""
|
||||
feat_vecs = []
|
||||
for t in trajectories:
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
||||
feat_matrix = np.array(feat_vecs)
|
||||
|
||||
k = min(n_clusters, len(feat_vecs))
|
||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
||||
labels = km.fit_predict(feat_matrix)
|
||||
|
||||
centroids_flat = km.cluster_centers_
|
||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
||||
for ci in range(k):
|
||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
||||
|
||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
||||
spread = np.zeros(k)
|
||||
for ci in range(k):
|
||||
members = np.where(labels == ci)[0]
|
||||
if len(members) == 0:
|
||||
continue
|
||||
centroid_left = centroid_trajs[ci, :, :3]
|
||||
centroid_right = centroid_trajs[ci, :, 3:]
|
||||
dists = []
|
||||
for mi in members:
|
||||
t = trajectories[mi]
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
||||
dists.append((d_left + d_right) / 2)
|
||||
spread[ci] = np.mean(dists)
|
||||
|
||||
return labels, centroid_trajs, spread
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
PROJ_VIEWS = [
|
||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
||||
]
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
2-row × 3-col grid per dataset (3 projections × 2 datasets).
|
||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
||||
"""
|
||||
n_ds = len(results)
|
||||
n_proj = len(PROJ_VIEWS)
|
||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[np.newaxis, :]
|
||||
|
||||
for row, r in enumerate(results):
|
||||
trajectories = r["trajectories"]
|
||||
labels = r["labels"]
|
||||
centroids = r["centroids"]
|
||||
k = centroids.shape[0]
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=k)
|
||||
size_order = np.argsort(-cluster_sizes)
|
||||
pcts = cluster_sizes / len(labels) * 100
|
||||
spread = r["spread"]
|
||||
|
||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
||||
ax = axes[row, col]
|
||||
ax.set_facecolor("#0d1117")
|
||||
|
||||
for ti, traj in enumerate(trajectories):
|
||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
||||
for tcp_key in ("left_tcp", "right_tcp"):
|
||||
pts = traj[tcp_key]
|
||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
||||
|
||||
for ci in range(k):
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
left_c = centroids[ci, :, :3]
|
||||
right_c = centroids[ci, :, 3:]
|
||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
||||
for c_pts in (left_c, right_c):
|
||||
ax.plot(
|
||||
c_pts[:, dim_a],
|
||||
c_pts[:, dim_b],
|
||||
color=color,
|
||||
linewidth=lw,
|
||||
alpha=0.95,
|
||||
zorder=10,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[0, dim_a],
|
||||
c_pts[0, dim_b],
|
||||
"o",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[-1, dim_a],
|
||||
c_pts[-1, dim_b],
|
||||
"s",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
|
||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
||||
ax.tick_params(colors="#555", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
||||
if col == 0:
|
||||
ax.set_title(
|
||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
||||
color="white",
|
||||
fontsize=11,
|
||||
pad=10,
|
||||
)
|
||||
else:
|
||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
||||
|
||||
# Cluster size + spread legend on the rightmost panel
|
||||
legend_ax = axes[row, -1]
|
||||
for ci in size_order:
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
spread_cm = spread[ci] * 100
|
||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
||||
legend_ax.legend(
|
||||
loc="upper right",
|
||||
fontsize=7,
|
||||
frameon=True,
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#333",
|
||||
labelcolor="white",
|
||||
handlelength=1.5,
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
||||
color="white",
|
||||
fontsize=16,
|
||||
y=0.98,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id)
|
||||
trajectories = load_episode_trajectories(local)
|
||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
||||
for ci in np.argsort(-cluster_sizes):
|
||||
print(
|
||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"trajectories": trajectories,
|
||||
"labels": labels,
|
||||
"centroids": centroids,
|
||||
"spread": spread,
|
||||
"n_episodes": len(trajectories),
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
||||
render(results, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user