mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a0dc324b81 | |||
| 1d275e2021 | |||
| 24bb2cb0ff | |||
| 1d414c07e2 | |||
| e04e3399b9 | |||
| 017ff73fbf | |||
| f90db58c15 |
@@ -19,6 +19,8 @@
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
|
||||
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
# Rename Map and Empty Cameras
|
||||
|
||||
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
|
||||
|
||||
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
|
||||
|
||||
## Why observation keys don't always match
|
||||
|
||||
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
|
||||
|
||||
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
|
||||
|
||||
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
|
||||
|
||||
## Using the rename map
|
||||
|
||||
Pass the mapping as a JSON string on the command line. The convention is always:
|
||||
|
||||
```
|
||||
--rename_map='{"source_key": "policy_key", ...}'
|
||||
```
|
||||
|
||||
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
|
||||
|
||||
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
|
||||
|
||||
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
|
||||
|
||||
### Training
|
||||
|
||||
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--env.type=libero \
|
||||
... \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
lerobot-record \ # When running inference
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
|
||||
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
|
||||
|
||||
## Empty cameras: fewer views than the policy expects
|
||||
|
||||
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
|
||||
|
||||
### How it works
|
||||
|
||||
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
|
||||
|
||||
```
|
||||
observation.images.empty_camera_0
|
||||
observation.images.empty_camera_1
|
||||
...
|
||||
```
|
||||
|
||||
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
|
||||
|
||||
### Example
|
||||
|
||||
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
|
||||
|
||||
1. Set `--policy.empty_cameras=1`.
|
||||
2. The config adds a third key: `observation.images.empty_camera_0`.
|
||||
3. Use the rename map for your two real cameras as usual.
|
||||
4. The third slot is masked out — no fake images needed in your dataset.
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
@@ -1,717 +0,0 @@
|
||||
"""
|
||||
Action consistency analysis for imitation learning datasets.
|
||||
|
||||
Two parallel analyses per dataset:
|
||||
1. State-based: KNN in joint-state space → action chunk variance
|
||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||
|
||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||
agree on where the data is inconsistent — and images are what the policy
|
||||
primarily sees.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from PIL import Image
|
||||
from scipy.spatial import cKDTree
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
MAX_FRAMES = 100_000
|
||||
K_NEIGHBORS = 50
|
||||
ACTION_CHUNK_SIZE = 30
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||
ENCODE_BATCH_SIZE = 512
|
||||
SEED = 42
|
||||
DPI = 150
|
||||
|
||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
||||
)
|
||||
|
||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=[
|
||||
"meta/**",
|
||||
"data/**",
|
||||
f"videos/{camera_key}/**",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_action_chunks(
|
||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
For each frame, concatenate the next chunk_size actions from the same episode.
|
||||
Returns (action_chunks, valid_mask).
|
||||
"""
|
||||
n = len(actions)
|
||||
act_dim = actions.shape[1]
|
||||
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
|
||||
valid = np.zeros(n, dtype=bool)
|
||||
|
||||
for i in range(n):
|
||||
end = i + chunk_size
|
||||
if end > n:
|
||||
continue
|
||||
if episode_ids[i] != episode_ids[end - 1]:
|
||||
continue
|
||||
chunks[i] = actions[i:end].ravel()
|
||||
valid[i] = True
|
||||
|
||||
return chunks, valid
|
||||
|
||||
|
||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||
"""
|
||||
Load observation.state and action, build action chunks, subsample, normalize.
|
||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
n_total = len(df)
|
||||
print(f" Total frames: {n_total:,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
action_col = next((c for c in df.columns if c == "action"), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
if action_col is None:
|
||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
||||
episode_all = df[ep_col].values.astype(np.int64)
|
||||
|
||||
n_dim = state_all.shape[1]
|
||||
act_dim = action_all.shape[1]
|
||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
print(" Building action chunks …")
|
||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||
valid_idx = np.where(valid)[0]
|
||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||
|
||||
if len(valid_idx) > max_frames:
|
||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||
else:
|
||||
chosen = valid_idx
|
||||
print(f" Using {len(chosen):,} frames")
|
||||
|
||||
state_raw = state_all[chosen]
|
||||
action_raw = action_chunks[chosen]
|
||||
episode_ids = episode_all[chosen]
|
||||
|
||||
state_mean = state_raw.mean(axis=0)
|
||||
state_std = state_raw.std(axis=0)
|
||||
state_std[state_std < 1e-8] = 1.0
|
||||
state_norm = (state_raw - state_mean) / state_std
|
||||
|
||||
action_mean = action_raw.mean(axis=0)
|
||||
action_std = action_raw.std(axis=0)
|
||||
action_std[action_std < 1e-8] = 1.0
|
||||
action_norm = (action_raw - action_mean) / action_std
|
||||
|
||||
return {
|
||||
"state_raw": state_raw,
|
||||
"state_norm": state_norm,
|
||||
"action_raw": action_raw,
|
||||
"action_norm": action_norm,
|
||||
"episode_ids": episode_ids,
|
||||
"episode_all": episode_all,
|
||||
"left_joint_idx": left_idx,
|
||||
"right_joint_idx": right_idx,
|
||||
"n_total": n_total,
|
||||
"chosen_idx": chosen,
|
||||
"df": df,
|
||||
}
|
||||
|
||||
|
||||
# ── Video → frame extraction ──────────────────────────────
|
||||
|
||||
|
||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||
"""
|
||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||
file_col = f"videos/{camera_key}/file_index"
|
||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{camera_key}/chunk_index"
|
||||
file_col = f"{camera_key}/file_index"
|
||||
ts_from = f"{camera_key}/from_timestamp"
|
||||
|
||||
lookup: dict[int, dict] = {}
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||
lookup[int(row["episode_index"])] = {
|
||||
"video_path": local / video_rel,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"fps": fps,
|
||||
}
|
||||
return lookup
|
||||
|
||||
|
||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
||||
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
||||
container = av.open(video_path)
|
||||
stream = container.streams.video[0]
|
||||
stream.thread_type = "AUTO"
|
||||
decoded = []
|
||||
for frame in container.decode(stream):
|
||||
decoded.append(frame.to_ndarray(format="rgb24"))
|
||||
container.close()
|
||||
return decoded
|
||||
|
||||
|
||||
def extract_frames(
|
||||
chosen_idx: np.ndarray,
|
||||
episode_all: np.ndarray,
|
||||
video_lookup: dict,
|
||||
) -> list[np.ndarray | None]:
|
||||
"""
|
||||
Extract RGB frames for each chosen global index using PyAV.
|
||||
Returns list of (H, W, 3) RGB arrays (or None on failure).
|
||||
"""
|
||||
unique_eps = np.unique(episode_all)
|
||||
ep_start: dict[int, int] = {}
|
||||
for ep in unique_eps:
|
||||
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
||||
|
||||
# Build jobs: (output_index, video_path, local_frame_number)
|
||||
jobs: list[tuple[int, str, int]] = []
|
||||
for out_i, global_i in enumerate(chosen_idx):
|
||||
ep = int(episode_all[global_i])
|
||||
info = video_lookup.get(ep)
|
||||
if info is None:
|
||||
continue
|
||||
local_frame = global_i - ep_start[ep]
|
||||
jobs.append((out_i, str(info["video_path"]), local_frame))
|
||||
|
||||
# Group by video file, decode each video once
|
||||
from collections import defaultdict
|
||||
|
||||
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
|
||||
for out_i, vpath, local_frame in jobs:
|
||||
video_jobs[vpath].append((out_i, local_frame))
|
||||
|
||||
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
|
||||
extracted = 0
|
||||
n_videos = len(video_jobs)
|
||||
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
|
||||
if not Path(vpath).exists():
|
||||
continue
|
||||
try:
|
||||
decoded = _decode_video_frames(vpath)
|
||||
except Exception as exc:
|
||||
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
|
||||
continue
|
||||
for out_i, local_frame in frame_requests:
|
||||
if 0 <= local_frame < len(decoded):
|
||||
frames[out_i] = decoded[local_frame]
|
||||
extracted += 1
|
||||
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
|
||||
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
|
||||
del decoded
|
||||
|
||||
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
||||
return frames
|
||||
|
||||
|
||||
# ── SigLIP encoding ─────────────────────────────────────
|
||||
|
||||
|
||||
def encode_frames_siglip(
|
||||
frames: list[np.ndarray | None],
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode RGB frames through SigLIP vision encoder.
|
||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||
"""
|
||||
print(f" Loading SigLIP model: {model_name} …")
|
||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
embed_dim = model.config.vision_config.hidden_size
|
||||
|
||||
n = len(frames)
|
||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||
|
||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||
|
||||
for batch_start in range(0, len(valid_indices), batch_size):
|
||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||
|
||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
image_features = model.get_image_features(**inputs)
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||
|
||||
done = min(batch_start + batch_size, len(valid_indices))
|
||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||
|
||||
del model, processor
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
|
||||
# ── KNN consistency ─────────────────────────────────────
|
||||
|
||||
|
||||
def compute_consistency(
|
||||
features: np.ndarray,
|
||||
action_norm: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
k: int,
|
||||
label: str = "",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
For each frame, find K nearest neighbors in feature space from other episodes.
|
||||
Return per-frame action variance (mean across action dims).
|
||||
"""
|
||||
n = len(features)
|
||||
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
||||
tree = cKDTree(features)
|
||||
|
||||
k_query = min(k * 3, n - 1)
|
||||
print(f" Querying {k_query} neighbors per frame …")
|
||||
_dists, indices = tree.query(features, k=k_query + 1)
|
||||
indices = indices[:, 1:]
|
||||
|
||||
print(f" Computing cross-episode action variance ({label}) …")
|
||||
variance = np.zeros(n)
|
||||
for i in range(n):
|
||||
ep_i = episode_ids[i]
|
||||
neighbors = indices[i]
|
||||
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
|
||||
if len(cross_ep) < 2:
|
||||
variance[i] = 0.0
|
||||
continue
|
||||
neighbor_actions = action_norm[cross_ep]
|
||||
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor("#0d1117")
|
||||
ax.tick_params(colors="#555", labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
|
||||
|
||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||
_style_ax(ax)
|
||||
median_var = np.median(variance)
|
||||
mean_var = np.mean(variance)
|
||||
nonzero = variance[variance > 0]
|
||||
if len(nonzero) > 0:
|
||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||
ax.set_xscale("log")
|
||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_episode_curves(
|
||||
ax: plt.Axes,
|
||||
var_state: np.ndarray,
|
||||
var_image: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
title: str,
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
unique_eps = np.unique(episode_ids)
|
||||
|
||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||
|
||||
sorted_s = np.sort(ep_means_s)[::-1]
|
||||
sorted_i = np.sort(ep_means_i)[::-1]
|
||||
ep_x = np.arange(len(unique_eps))
|
||||
|
||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||
|
||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_heatmap(
|
||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
order = np.argsort(variance)
|
||||
pts = tcp_xz[order]
|
||||
var_sorted = variance[order]
|
||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||
sc = ax.scatter(
|
||||
pts[:, 0],
|
||||
pts[:, 1],
|
||||
c=var_sorted,
|
||||
cmap=CONSISTENCY_CMAP,
|
||||
s=0.5,
|
||||
alpha=0.6,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True,
|
||||
)
|
||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.set_aspect("equal")
|
||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
4-row x N-column figure:
|
||||
Row 0: State-based variance histogram
|
||||
Row 1: Image-based variance histogram
|
||||
Row 2: Per-episode curves (both overlaid)
|
||||
Row 3: Spatial heatmap (image-based variance)
|
||||
"""
|
||||
n_ds = len(results)
|
||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[:, np.newaxis]
|
||||
|
||||
headline_parts = []
|
||||
for col, r in enumerate(results):
|
||||
label = r["label"]
|
||||
var_s = r["var_state"]
|
||||
var_i = r["var_image"]
|
||||
tcp_xz = r["tcp_xz"]
|
||||
episode_ids = r["episode_ids"]
|
||||
|
||||
med_s = np.median(var_s)
|
||||
med_i = np.median(var_i)
|
||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||
|
||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||
_plot_histogram(
|
||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||
)
|
||||
_plot_episode_curves(
|
||||
axes[2, col],
|
||||
var_s,
|
||||
var_i,
|
||||
episode_ids,
|
||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||
)
|
||||
_plot_heatmap(
|
||||
axes[3, col],
|
||||
fig,
|
||||
tcp_xz,
|
||||
var_i,
|
||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||
+ " | ".join(headline_parts),
|
||||
color="white",
|
||||
fontsize=15,
|
||||
y=0.99,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
rng = np.random.default_rng(SEED)
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id, CAMERA_KEY)
|
||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||
|
||||
# --- State-based KNN ---
|
||||
var_state = compute_consistency(
|
||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||
)
|
||||
print(
|
||||
f" State variance: median={np.median(var_state):.4f} "
|
||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||
)
|
||||
|
||||
# --- Image-based KNN ---
|
||||
print("\n Preparing image embeddings …")
|
||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||
del frames # free memory
|
||||
|
||||
var_image = compute_consistency(
|
||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||
)
|
||||
print(
|
||||
f" Image variance: median={np.median(var_image):.4f} "
|
||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||
)
|
||||
|
||||
# FK for spatial heatmap
|
||||
print(" Computing FK for spatial heatmap …")
|
||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||
left_rad = _detect_and_convert(left_raw)
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
||||
tcp_xz = left_tcp[:, [0, 2]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"var_state": var_state,
|
||||
"var_image": var_image,
|
||||
"episode_ids": data["episode_ids"],
|
||||
"tcp_xz": tcp_xz,
|
||||
"n_total": data["n_total"],
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||
render(results, out)
|
||||
|
||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||
worst_summary = {}
|
||||
for r in results:
|
||||
unique_eps = np.unique(r["episode_ids"])
|
||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||
print(f"✓ Saved worst episodes: {worst_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,178 +0,0 @@
|
||||
"""
|
||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
||||
decodes them, and tiles into a single image.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
GRID_COLS = 15
|
||||
GRID_ROWS = 10
|
||||
THUMB_WIDTH = 160
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
SEED = 1
|
||||
|
||||
|
||||
def download_metadata(repo_id: str) -> Path:
|
||||
"""Download only metadata (no videos yet)."""
|
||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
cam = CAMERA_KEY
|
||||
else:
|
||||
cam = video_keys[0]
|
||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
chunk_col = f"videos/{cam}/chunk_index"
|
||||
file_col = f"videos/{cam}/file_index"
|
||||
ts_from = f"videos/{cam}/from_timestamp"
|
||||
ts_to = f"videos/{cam}/to_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{cam}/chunk_index"
|
||||
file_col = f"{cam}/file_index"
|
||||
ts_from = f"{cam}/from_timestamp"
|
||||
ts_to = f"{cam}/to_timestamp"
|
||||
|
||||
episodes = []
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
episodes.append(
|
||||
{
|
||||
"episode_index": int(row["episode_index"]),
|
||||
"chunk_index": ci,
|
||||
"file_index": fi,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"to_ts": float(row[ts_to]),
|
||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
||||
}
|
||||
)
|
||||
return cam, episodes, fps
|
||||
|
||||
|
||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
||||
picks = []
|
||||
for _ in range(n):
|
||||
ep = rng.choice(episodes)
|
||||
duration = ep["to_ts"] - ep["from_ts"]
|
||||
if duration <= 0:
|
||||
continue
|
||||
t = ep["from_ts"] + rng.random() * duration
|
||||
picks.append({**ep, "seek_ts": t})
|
||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
||||
return picks
|
||||
|
||||
|
||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
||||
"""Download only the video files we need."""
|
||||
needed = sorted({p["video_rel"] for p in picks})
|
||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=needed,
|
||||
)
|
||||
|
||||
|
||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
||||
"""Decode a single frame at the given timestamp."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
return frame if ret else None
|
||||
|
||||
|
||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
||||
if not frames:
|
||||
raise RuntimeError("No frames decoded")
|
||||
|
||||
h0, w0 = frames[0].shape[:2]
|
||||
thumb_h = int(thumb_w * h0 / w0)
|
||||
|
||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
||||
|
||||
rows = []
|
||||
for i in range(0, len(thumbs), cols):
|
||||
row_thumbs = thumbs[i : i + cols]
|
||||
while len(row_thumbs) < cols:
|
||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
||||
rows.append(np.hstack(row_thumbs))
|
||||
return np.vstack(rows)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rng = random.Random(SEED)
|
||||
n_frames = GRID_COLS * GRID_ROWS
|
||||
|
||||
local = download_metadata(REPO_ID)
|
||||
cam, episodes, fps = load_video_info(local)
|
||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
||||
download_video_files(REPO_ID, local, picks)
|
||||
|
||||
print(f"[3/3] Decoding {n_frames} frames …")
|
||||
frames: list[np.ndarray] = []
|
||||
for p in picks:
|
||||
vp = local / p["video_rel"]
|
||||
if not vp.exists():
|
||||
print(f" SKIP: {p['video_rel']} not found")
|
||||
continue
|
||||
frame = extract_frame(vp, p["seek_ts"])
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
||||
|
||||
safe_name = REPO_ID.replace("/", "_")
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,526 +0,0 @@
|
||||
"""
|
||||
Create MP4 videos with sarm_progress overlay for specified episodes.
|
||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||
and draws the progress line directly on each frame (no panel, no axes).
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
||||
]
|
||||
CAMERA_KEY = (
|
||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
||||
)
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Progress line spans the full video height
|
||||
GRAPH_Y_TOP_FRAC = 0.01
|
||||
GRAPH_Y_BOT_FRAC = 0.99
|
||||
LINE_THICKNESS = 3
|
||||
SHADOW_THICKNESS = 6 # white edge thickness
|
||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
||||
SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the files needed for this episode."""
|
||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
||||
# We'll download meta + sarm first, then figure out chunks.
|
||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
||||
local = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
return local
|
||||
|
||||
|
||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
# Find video keys (keys whose dtype=="video")
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
first_cam = CAMERA_KEY
|
||||
else:
|
||||
first_cam = video_keys[0]
|
||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
||||
|
||||
# Load all episode-meta parquet files and find our episode
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
df = pd.read_parquet(pq)
|
||||
ep_rows.append(df)
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
row = ep_df[ep_df["episode_index"] == episode]
|
||||
if row.empty:
|
||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
||||
row = row.iloc[0]
|
||||
|
||||
# Extract video chunk/file index for first camera
|
||||
# Try both dot and slash variants of the key
|
||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
||||
file_col = f"videos/{first_cam}/file_index"
|
||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
||||
to_col = f"videos/{first_cam}/to_timestamp"
|
||||
|
||||
# Some datasets use different column naming
|
||||
if chunk_col not in row.index:
|
||||
# Try without the 'videos/' prefix
|
||||
chunk_col = f"{first_cam}/chunk_index"
|
||||
file_col = f"{first_cam}/file_index"
|
||||
ts_col = f"{first_cam}/from_timestamp"
|
||||
to_col = f"{first_cam}/to_timestamp"
|
||||
if chunk_col not in row.index:
|
||||
raise RuntimeError(
|
||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
||||
)
|
||||
|
||||
chunk_idx = int(row[chunk_col])
|
||||
file_idx = int(row[file_col])
|
||||
from_ts = float(row[ts_col])
|
||||
to_ts = float(row[to_col])
|
||||
|
||||
video_template = info.get(
|
||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
||||
)
|
||||
video_rel = video_template.format(
|
||||
video_key=first_cam,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# Load task name for this episode
|
||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
||||
task_name = ""
|
||||
try:
|
||||
# Prefer the 'tasks' list directly on the episode row
|
||||
if "tasks" in row.index and row["tasks"] is not None:
|
||||
tasks_val = row["tasks"]
|
||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
||||
task_name = str(tasks_val[0])
|
||||
else:
|
||||
task_name = str(tasks_val).strip("[]'")
|
||||
else:
|
||||
tasks_pq = local / "meta" / "tasks.parquet"
|
||||
if tasks_pq.exists():
|
||||
tasks_df = pd.read_parquet(tasks_pq)
|
||||
# Row index is the task string; task_index column is the int
|
||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
||||
if not match.empty:
|
||||
task_name = str(match.index[0])
|
||||
print(f" Task name: '{task_name}'")
|
||||
except Exception as e:
|
||||
print(f" WARNING: could not load task name: {e}")
|
||||
|
||||
return {
|
||||
"fps": fps,
|
||||
"first_cam": first_cam,
|
||||
"video_rel": video_rel,
|
||||
"chunk_index": chunk_idx,
|
||||
"file_index": file_idx,
|
||||
"from_ts": from_ts,
|
||||
"to_ts": to_ts,
|
||||
"task_name": task_name,
|
||||
}
|
||||
|
||||
|
||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
||||
"""Download the specific video file if not already present."""
|
||||
video_path = local / video_rel
|
||||
if video_path.exists():
|
||||
print(f" Video already cached: {video_path}")
|
||||
return video_path
|
||||
print(f"[2/5] Downloading video file {video_rel} …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=[video_rel],
|
||||
)
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||
pq_path = local / "sarm_progress.parquet"
|
||||
if not pq_path.exists():
|
||||
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
|
||||
return None
|
||||
df = pd.read_parquet(pq_path)
|
||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||
ep_df = df[df["episode_index"] == episode].copy()
|
||||
if ep_df.empty:
|
||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||
return None
|
||||
ep_df = ep_df.sort_values("frame_index")
|
||||
|
||||
# Prefer dense, fall back to sparse
|
||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||
prog_col = "progress_dense"
|
||||
elif "progress_sparse" in ep_df.columns:
|
||||
prog_col = "progress_sparse"
|
||||
else:
|
||||
# Last resort: any column with 'progress' in the name
|
||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||
if not prog_cols:
|
||||
return None
|
||||
prog_col = prog_cols[0]
|
||||
|
||||
print(f" Using progress column: '{prog_col}'")
|
||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||
|
||||
|
||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
||||
duration = to_ts - from_ts
|
||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-ss",
|
||||
str(from_ts),
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"fast",
|
||||
"-crf",
|
||||
"18",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
||||
return out_path
|
||||
|
||||
|
||||
def precompute_pixels(
|
||||
progress_data: np.ndarray,
|
||||
n_frames: int,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Map each progress sample to pixel coordinates.
|
||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
||||
x spans full video width; y maps progress [0,1] to graph band.
|
||||
"""
|
||||
frame_indices = progress_data[:, 0].astype(float)
|
||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
||||
|
||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
graph_h = y_bot - y_top
|
||||
|
||||
xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int)
|
||||
# progress=1 → y_top, progress=0 → y_bot
|
||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
||||
|
||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
||||
|
||||
|
||||
def progress_color(t: float) -> tuple[int, int, int]:
|
||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
||||
r = int(255 * (1.0 - t))
|
||||
g = int(255 * t)
|
||||
return (0, g, r) # BGR
|
||||
|
||||
|
||||
def prerender_fill(
|
||||
pixels: np.ndarray,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
poly = np.concatenate(
|
||||
[
|
||||
pixels,
|
||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.int32)
|
||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
||||
return fill_img
|
||||
|
||||
|
||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
||||
if x_max <= 0:
|
||||
return
|
||||
roi_b = base[:, :x_max]
|
||||
roi_o = overlay_bgra[:, :x_max]
|
||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
||||
roi_b[:] = np.clip(
|
||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
||||
0,
|
||||
255,
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def draw_text_outlined(
|
||||
frame: np.ndarray,
|
||||
text: str,
|
||||
pos: tuple[int, int],
|
||||
font_scale: float,
|
||||
thickness: int = 1,
|
||||
) -> None:
|
||||
"""Draw text with a dark outline for readability on any background."""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
|
||||
def composite_video(
|
||||
clip_path: Path,
|
||||
progress_data: np.ndarray,
|
||||
out_path: Path,
|
||||
fps: float,
|
||||
frame_h: int,
|
||||
frame_w: int,
|
||||
task_name: str = "",
|
||||
) -> Path:
|
||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
||||
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
||||
|
||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
|
||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
||||
|
||||
# 1.0 reference line overlay (full width, drawn once)
|
||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
||||
|
||||
frame_indices = progress_data[:, 0].astype(int)
|
||||
progress_vals = progress_data[:, 1].astype(float)
|
||||
|
||||
print(f"[4/4] Compositing {n_total} frames …")
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
||||
|
||||
fi = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
||||
|
||||
# 1. reference line (full width, always)
|
||||
alpha_composite(frame, ref_img, frame_w)
|
||||
|
||||
# 2. grey fill under curve up to current x
|
||||
alpha_composite(frame, fill_img, x_cur)
|
||||
|
||||
# 3. progress line — single color that transitions red→green over time
|
||||
if n_drawn >= 2:
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
line_col = progress_color(t_cur)
|
||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
||||
cv2.polylines(
|
||||
frame,
|
||||
[pts],
|
||||
isClosed=False,
|
||||
color=(255, 255, 255),
|
||||
thickness=SHADOW_THICKNESS,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
cv2.polylines(
|
||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 4. score — bottom right
|
||||
if n_drawn > 0:
|
||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
||||
score_text = f"{score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
||||
sx = frame_w - tw - 12
|
||||
sy = frame_h - 12
|
||||
# coloured score matching current gradient position
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
score_col = progress_color(t_cur)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
(0, 0, 0),
|
||||
4,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
score_col,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
# 5. task name — top centre
|
||||
if task_name:
|
||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
||||
tx = max((frame_w - tw) // 2, 4)
|
||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
||||
|
||||
writer.write(frame)
|
||||
fi += 1
|
||||
if fi % 100 == 0:
|
||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
||||
|
||||
cap.release()
|
||||
writer.release()
|
||||
print()
|
||||
|
||||
# Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB)
|
||||
gif_path = out_path.with_suffix(".gif")
|
||||
palette = out_path.parent / "_palette.png"
|
||||
r1 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-vf",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
||||
"-update",
|
||||
"1",
|
||||
str(palette),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r1.returncode != 0:
|
||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
||||
r2 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-i",
|
||||
str(palette),
|
||||
"-filter_complex",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
||||
str(gif_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r2.returncode != 0:
|
||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
palette.unlink(missing_ok=True)
|
||||
return gif_path
|
||||
|
||||
|
||||
def process_dataset(repo_id: str, episode: int):
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Processing: {repo_id} | episode {episode}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Download metadata
|
||||
local = download_episode(repo_id, episode)
|
||||
print(f" Local cache: {local}")
|
||||
|
||||
# 2. Read episode metadata
|
||||
ep_meta = load_episode_meta(local, episode)
|
||||
print(f" Episode meta: {ep_meta}")
|
||||
|
||||
# 3. Download video file
|
||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
||||
|
||||
# 4. Extract clip
|
||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
||||
|
||||
# 5. Load progress data
|
||||
progress_data = load_progress(local, episode)
|
||||
if progress_data is None:
|
||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
||||
return
|
||||
|
||||
n_progress = len(progress_data)
|
||||
print(f" Progress frames: {n_progress}")
|
||||
|
||||
# 6. Get clip dimensions
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
||||
cap.release()
|
||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
||||
|
||||
# 7. Composite (draw line directly on frames)
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
||||
final = composite_video(
|
||||
clip_path,
|
||||
progress_data,
|
||||
out_path,
|
||||
actual_fps,
|
||||
frame_h,
|
||||
frame_w,
|
||||
task_name=ep_meta.get("task_name", ""),
|
||||
)
|
||||
clip_path.unlink(missing_ok=True)
|
||||
print(f"\n✓ Done: {final}")
|
||||
return final
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
for cfg in DATASETS:
|
||||
try:
|
||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
||||
if out:
|
||||
results.append(out)
|
||||
except Exception as e:
|
||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Output files:")
|
||||
for r in results:
|
||||
print(f" {r}")
|
||||
@@ -1,496 +0,0 @@
|
||||
"""
|
||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
||||
kinematics per episode, clusters trajectories with K-means, and renders
|
||||
2D projections comparing dataset coverage and multimodality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
N_CLUSTERS = 10
|
||||
WAYPOINTS = 50
|
||||
SEED = 42
|
||||
DPI = 180
|
||||
|
||||
CLUSTER_COLORS = [
|
||||
"#e6194b",
|
||||
"#3cb44b",
|
||||
"#4363d8",
|
||||
"#f58231",
|
||||
"#911eb4",
|
||||
"#42d4f4",
|
||||
"#f032e6",
|
||||
"#bfef45",
|
||||
"#fabed4",
|
||||
"#dcbeff",
|
||||
"#9a6324",
|
||||
"#fffac8",
|
||||
"#800000",
|
||||
"#aaffc3",
|
||||
"#808000",
|
||||
"#ffd8b1",
|
||||
"#000075",
|
||||
"#a9a9a9",
|
||||
]
|
||||
|
||||
# FK chains extracted from OpenArm bimanual URDF.
|
||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
"""Try to find left/right joint indices from info.json feature names."""
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet only) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "data/**"],
|
||||
ignore_patterns=["*.mp4", "videos/**"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
|
||||
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
|
||||
f = traj.shape[0]
|
||||
if f == n_waypoints:
|
||||
return traj
|
||||
old_t = np.linspace(0, 1, f)
|
||||
new_t = np.linspace(0, 1, n_waypoints)
|
||||
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
||||
|
||||
|
||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
||||
"""
|
||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
||||
Uses all episodes in the dataset for a fair comparison.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
print(f" Total frames: {len(df):,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
|
||||
first = df[state_col].iloc[0]
|
||||
if not hasattr(first, "__len__"):
|
||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
||||
|
||||
state = np.stack(df[state_col].values).astype(np.float64)
|
||||
n_dim = state.shape[1]
|
||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
episode_ids = df[ep_col].values
|
||||
unique_eps = np.unique(episode_ids)
|
||||
print(f" Episodes: {len(unique_eps):,}")
|
||||
|
||||
left_raw = state[:, left_idx]
|
||||
right_raw = state[:, right_idx]
|
||||
left_all = _detect_and_convert(left_raw)
|
||||
right_all = _detect_and_convert(right_raw)
|
||||
|
||||
print(" Computing FK per episode …")
|
||||
trajectories = []
|
||||
for ep_id in unique_eps:
|
||||
mask = episode_ids == ep_id
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
||||
if len(left_tcp) < 3:
|
||||
continue
|
||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
||||
|
||||
print(f" Valid trajectories: {len(trajectories):,}")
|
||||
return trajectories
|
||||
|
||||
|
||||
# ── Clustering ──────────────────────────────────────────
|
||||
|
||||
|
||||
def cluster_trajectories(
|
||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
K-means on resampled trajectory features.
|
||||
Combines left+right TCP into a single feature vector per episode.
|
||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
||||
"""
|
||||
feat_vecs = []
|
||||
for t in trajectories:
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
||||
feat_matrix = np.array(feat_vecs)
|
||||
|
||||
k = min(n_clusters, len(feat_vecs))
|
||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
||||
labels = km.fit_predict(feat_matrix)
|
||||
|
||||
centroids_flat = km.cluster_centers_
|
||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
||||
for ci in range(k):
|
||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
||||
|
||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
||||
spread = np.zeros(k)
|
||||
for ci in range(k):
|
||||
members = np.where(labels == ci)[0]
|
||||
if len(members) == 0:
|
||||
continue
|
||||
centroid_left = centroid_trajs[ci, :, :3]
|
||||
centroid_right = centroid_trajs[ci, :, 3:]
|
||||
dists = []
|
||||
for mi in members:
|
||||
t = trajectories[mi]
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
||||
dists.append((d_left + d_right) / 2)
|
||||
spread[ci] = np.mean(dists)
|
||||
|
||||
return labels, centroid_trajs, spread
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
PROJ_VIEWS = [
|
||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
||||
]
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
2-row × 3-col grid per dataset (3 projections × 2 datasets).
|
||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
||||
"""
|
||||
n_ds = len(results)
|
||||
n_proj = len(PROJ_VIEWS)
|
||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[np.newaxis, :]
|
||||
|
||||
for row, r in enumerate(results):
|
||||
trajectories = r["trajectories"]
|
||||
labels = r["labels"]
|
||||
centroids = r["centroids"]
|
||||
k = centroids.shape[0]
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=k)
|
||||
size_order = np.argsort(-cluster_sizes)
|
||||
pcts = cluster_sizes / len(labels) * 100
|
||||
spread = r["spread"]
|
||||
|
||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
||||
ax = axes[row, col]
|
||||
ax.set_facecolor("#0d1117")
|
||||
|
||||
for ti, traj in enumerate(trajectories):
|
||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
||||
for tcp_key in ("left_tcp", "right_tcp"):
|
||||
pts = traj[tcp_key]
|
||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
||||
|
||||
for ci in range(k):
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
left_c = centroids[ci, :, :3]
|
||||
right_c = centroids[ci, :, 3:]
|
||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
||||
for c_pts in (left_c, right_c):
|
||||
ax.plot(
|
||||
c_pts[:, dim_a],
|
||||
c_pts[:, dim_b],
|
||||
color=color,
|
||||
linewidth=lw,
|
||||
alpha=0.95,
|
||||
zorder=10,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[0, dim_a],
|
||||
c_pts[0, dim_b],
|
||||
"o",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[-1, dim_a],
|
||||
c_pts[-1, dim_b],
|
||||
"s",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
|
||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
||||
ax.tick_params(colors="#555", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
||||
if col == 0:
|
||||
ax.set_title(
|
||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
||||
color="white",
|
||||
fontsize=11,
|
||||
pad=10,
|
||||
)
|
||||
else:
|
||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
||||
|
||||
# Cluster size + spread legend on the rightmost panel
|
||||
legend_ax = axes[row, -1]
|
||||
for ci in size_order:
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
spread_cm = spread[ci] * 100
|
||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
||||
legend_ax.legend(
|
||||
loc="upper right",
|
||||
fontsize=7,
|
||||
frameon=True,
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#333",
|
||||
labelcolor="white",
|
||||
handlelength=1.5,
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
||||
color="white",
|
||||
fontsize=16,
|
||||
y=0.98,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id)
|
||||
trajectories = load_episode_trajectories(local)
|
||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
||||
for ci in np.argsort(-cluster_sizes):
|
||||
print(
|
||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"trajectories": trajectories,
|
||||
"labels": labels,
|
||||
"centroids": centroids,
|
||||
"spread": spread,
|
||||
"n_episodes": len(trajectories),
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
||||
render(results, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -131,6 +131,15 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -149,6 +158,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -198,6 +208,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -209,6 +220,8 @@ class _NormalizationMixin:
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
self.stats = {}
|
||||
|
||||
@@ -62,6 +62,7 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
@@ -258,6 +259,11 @@ def act_with_policy(
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
@@ -289,7 +295,9 @@ def act_with_policy(
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# action = postprocessor.process_action(action)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
@@ -66,6 +66,7 @@ from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
@@ -313,6 +314,11 @@ def add_actor_information_and_train(
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
config=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
@@ -408,6 +414,9 @@ def add_actor_information_and_train(
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observations = preprocessor.process_observation(observations)
|
||||
next_observations = preprocessor.process_observation(next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
@@ -467,6 +476,9 @@ def add_actor_information_and_train(
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observations = preprocessor.process_observation(observations)
|
||||
next_observations = preprocessor.process_observation(next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
@@ -23,65 +23,46 @@ class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.episode_end_status = None
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
self.episode_end_status = None
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "stay"
|
||||
elif self.open_gripper_command:
|
||||
@@ -102,14 +83,14 @@ class KeyboardController(InputController):
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
"intervention": False,
|
||||
"rerecord": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
@@ -126,16 +107,21 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = True
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = True
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif key == keyboard.Key.space:
|
||||
self.key_states["intervention"] = not self.key_states["intervention"]
|
||||
elif hasattr(key, "char") and key.char == "r":
|
||||
self.key_states["rerecord"] = True
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -153,10 +139,10 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = False
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -165,18 +151,18 @@ class KeyboardController(InputController):
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Shift / Shift_R: Move in Z axis")
|
||||
print(" Ctrl_R / Ctrl_L: Open / Close gripper")
|
||||
print(" Space: Toggle intervention")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
print(" Esc: End episode with FAILURE")
|
||||
print(" R: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
@@ -194,18 +180,58 @@ class KeyboardController(InputController):
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_intervene(self):
|
||||
return self.key_states["intervention"]
|
||||
|
||||
def reset(self):
|
||||
for key in self.key_states:
|
||||
self.key_states[key] = False
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
"""Generate motion deltas from gamepad input using pygame.
|
||||
|
||||
Matches gym-hil button/axis conventions for Linux gamepads, including
|
||||
Xbox mappings.
|
||||
"""
|
||||
|
||||
# Face buttons (same across most controllers on Linux)
|
||||
BUTTON_A = 0
|
||||
BUTTON_B = 1
|
||||
BUTTON_X = 2
|
||||
BUTTON_Y = 3
|
||||
BUTTON_LB = 4
|
||||
BUTTON_RB = 5
|
||||
# Stick axes
|
||||
AXIS_LEFT_X = 0
|
||||
AXIS_LEFT_Y = 1
|
||||
AXIS_RIGHT_X = 2
|
||||
AXIS_RIGHT_Y = 3
|
||||
|
||||
# Default trigger buttons
|
||||
BUTTON_LT = 6
|
||||
BUTTON_RT = 7
|
||||
|
||||
# Xbox (gym-hil mapping on Linux)
|
||||
XBOX_BUTTON_LT = 9
|
||||
XBOX_BUTTON_RT = 10
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
self.is_xbox = False
|
||||
self._xbox360_profile = False
|
||||
self._invert_left_x = False
|
||||
self._invert_left_y = True
|
||||
self._invert_right_y = True
|
||||
|
||||
def _detect_xbox(self, name):
|
||||
name_lower = name.lower()
|
||||
return any(tag in name_lower for tag in ["xbox", "microsoft", "x-box"])
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
@@ -218,18 +244,35 @@ class GamepadController(InputController):
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
joystick_name = self.joystick.get_name()
|
||||
self.is_xbox = self._detect_xbox(joystick_name)
|
||||
self._xbox360_profile = joystick_name == "Xbox 360 Controller"
|
||||
if self._xbox360_profile:
|
||||
# gym-hil "Xbox 360 Controller" profile
|
||||
self.AXIS_RIGHT_X = 3
|
||||
self.AXIS_RIGHT_Y = 4
|
||||
self.BUTTON_LT = self.XBOX_BUTTON_LT
|
||||
self.BUTTON_RT = self.XBOX_BUTTON_RT
|
||||
self._invert_left_x = True
|
||||
else:
|
||||
# gym-hil default profile
|
||||
self.AXIS_RIGHT_X = 2
|
||||
self.AXIS_RIGHT_Y = 3
|
||||
self.BUTTON_LT = 6
|
||||
self.BUTTON_RT = 7
|
||||
self._invert_left_x = False
|
||||
logging.info(f"Initialized gamepad: {joystick_name} (xbox={self.is_xbox})")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
print(" RB: Intervention toggle")
|
||||
print(" LT / RT: Close / Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" A: End episode with FAILURE")
|
||||
print(" X: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
@@ -239,67 +282,56 @@ class GamepadController(InputController):
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
if event.button == self.BUTTON_Y:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
elif event.button == self.BUTTON_A:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
elif event.button == self.BUTTON_X:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
elif event.button == self.BUTTON_LT:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
elif event.button == self.BUTTON_RT:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
if event.button in [self.BUTTON_Y, self.BUTTON_A, self.BUTTON_X]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
elif event.button == self.BUTTON_LT:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
elif event.button == self.BUTTON_RT:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
if self.joystick.get_button(self.BUTTON_RB):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
|
||||
x_input = self.joystick.get_axis(1) # Left/Right
|
||||
x_input = self.joystick.get_axis(self.AXIS_LEFT_X)
|
||||
y_input = self.joystick.get_axis(self.AXIS_LEFT_Y)
|
||||
z_input = self.joystick.get_axis(self.AXIS_RIGHT_Y)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -x_input * self.x_step_size # Forward/backward
|
||||
delta_y = -y_input * self.y_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
if self._invert_left_x:
|
||||
x_input = -x_input
|
||||
if self._invert_left_y:
|
||||
y_input = -y_input
|
||||
if self._invert_right_y:
|
||||
z_input = -z_input
|
||||
|
||||
delta_x = y_input * self.y_step_size
|
||||
delta_y = x_input * self.x_step_size
|
||||
delta_z = z_input * self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
@@ -309,7 +341,15 @@ class GamepadController(InputController):
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
"""Generate motion deltas from gamepad input using HIDAPI.
|
||||
|
||||
Supports auto-detection of controller type for correct HID report parsing.
|
||||
Currently supported: Logitech RumblePad 2, 8BitDo Ultimate 2C Wireless.
|
||||
"""
|
||||
|
||||
CONTROLLER_LOGITECH = "logitech"
|
||||
CONTROLLER_8BITDO = "8bitdo"
|
||||
CONTROLLER_UNKNOWN = "unknown"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -318,36 +358,26 @@ class GamepadControllerHID(InputController):
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
self.controller_type = self.CONTROLLER_UNKNOWN
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
device_name = device["product_string"]
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5", "8BitDo"]):
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
@@ -355,8 +385,15 @@ class GamepadControllerHID(InputController):
|
||||
)
|
||||
return None
|
||||
|
||||
def _detect_controller_type(self, product_string):
|
||||
product = product_string.lower() if product_string else ""
|
||||
if "8bitdo" in product:
|
||||
return self.CONTROLLER_8BITDO
|
||||
elif "logitech" in product:
|
||||
return self.CONTROLLER_LOGITECH
|
||||
return self.CONTROLLER_UNKNOWN
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
@@ -374,12 +411,22 @@ class GamepadControllerHID(InputController):
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
self.controller_type = self._detect_controller_type(product)
|
||||
logging.info(f"Detected controller type: {self.controller_type}")
|
||||
|
||||
print("Gamepad controls (HID mode):")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick: Move in Z axis (vertical)")
|
||||
print(" RB: Intervention toggle")
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
print(" L3 (left stick click): Close gripper")
|
||||
print(" R3 (right stick click): Open gripper")
|
||||
else:
|
||||
print(" LT: Close gripper")
|
||||
print(" RT: Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" X: End episode with FAILURE")
|
||||
print(" A: Rerecord episode")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
@@ -387,74 +434,124 @@ class GamepadControllerHID(InputController):
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
"""Read the device several times to drain the HID buffer and get a stable reading."""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_y = (data[1] - 128) / 128.0
|
||||
self.left_x = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Apply deadzone
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
self._parse_8bitdo(data)
|
||||
else:
|
||||
self._parse_logitech(data)
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def _apply_deadzone(self):
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
def _parse_8bitdo(self, data):
|
||||
"""Parse HID report from 8BitDo Ultimate 2C Wireless (Bluetooth on macOS).
|
||||
|
||||
11-byte report layout:
|
||||
byte[0]: Report ID (0x01)
|
||||
byte[1]: D-pad hat switch (0=N, 2=E, 5=S, 6=W, 15=neutral)
|
||||
byte[2]: Left Stick X (0=left, 127=center, 255=right)
|
||||
byte[3]: Left Stick Y (0=up, 127=center, 255=down)
|
||||
byte[4]: Right Stick X (inverted: 255=left, 0=right)
|
||||
byte[5]: Right Stick Y (0=up, 127=center, 255=down)
|
||||
byte[6]: RT analog trigger (0-255)
|
||||
byte[7]: LT analog trigger (0-255)
|
||||
byte[8]: Buttons -- bit0=A, bit1=B, bit3=X, bit4=Y, bit6=LB, bit7=RB
|
||||
byte[9]: System -- bit0=LT(digital), bit1=RT(digital), bit3=Select,
|
||||
bit4=Start, bit5=L3, bit6=R3
|
||||
byte[10]: Unused
|
||||
"""
|
||||
if len(data) < 11:
|
||||
return
|
||||
|
||||
self.left_x = (data[2] - 127) / 128.0
|
||||
self.left_y = (data[3] - 127) / 128.0
|
||||
self.right_x = -(data[4] - 127) / 128.0
|
||||
self.right_y = (data[5] - 127) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[8]
|
||||
|
||||
# RB (bit 7) = intervention
|
||||
self.intervention_flag = bool(buttons & 0x80)
|
||||
|
||||
# Stick clicks for gripper: R3 (byte[9] bit6) = open, L3 (byte[9] bit5) = close
|
||||
system = data[9]
|
||||
self.open_gripper_command = bool(system & 0x40) # R3
|
||||
self.close_gripper_command = bool(system & 0x20) # L3
|
||||
|
||||
# Y (bit 4) = success, X (bit 3) = failure, A (bit 0) = rerecord
|
||||
if buttons & 0x10:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 0x08:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 0x01:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def _parse_logitech(self, data):
|
||||
"""Parse HID report from Logitech RumblePad 2 (and similar Logitech gamepads).
|
||||
|
||||
Report layout (8+ bytes):
|
||||
byte[1]: Left Stick X (0-255, center=128)
|
||||
byte[2]: Left Stick Y (0-255, center=128)
|
||||
byte[3]: Right Stick X (0-255, center=128)
|
||||
byte[4]: Right Stick Y (0-255, center=128)
|
||||
byte[5]: Face buttons bitmask
|
||||
byte[6]: Shoulder/trigger buttons bitmask
|
||||
"""
|
||||
if len(data) < 8:
|
||||
return
|
||||
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[5]
|
||||
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_x * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_y * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
delta_x = -self.left_y * self.x_step_size
|
||||
delta_y = -self.left_x * self.y_step_size
|
||||
delta_z = -self.right_y * self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
Reference in New Issue
Block a user