mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
feat: add benchmark orchestration, LIBERO-plus install parity, and eval hardening
- Add lerobot-benchmark CLI for multi-benchmark train/eval workflows - Add benchmark_training.mdx documentation - Add libero-plus pip extra alias with EGL probe deps matching standard libero - Harden libero.py: wand mock, init-state fallback, renderer EGL→OSMesa fallback - Add multimodal_analysis.py script and SLURM training template Made-with: Cursor
This commit is contained in:
@@ -19,6 +19,8 @@
|
|||||||
title: Multi GPU training
|
title: Multi GPU training
|
||||||
- local: peft_training
|
- local: peft_training
|
||||||
title: Training with PEFT (e.g., LoRA)
|
title: Training with PEFT (e.g., LoRA)
|
||||||
|
- local: benchmark_training
|
||||||
|
title: Benchmark Training & Evaluation
|
||||||
title: "Tutorials"
|
title: "Tutorials"
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
|
|||||||
@@ -0,0 +1,260 @@
|
|||||||
|
# Benchmark Training & Evaluation
|
||||||
|
|
||||||
|
This guide explains how to train and evaluate policies on the simulation benchmarks
|
||||||
|
integrated in LeRobot: **LIBERO**, **LIBERO-plus**, **MetaWorld**, **RoboCasa**, and **RoboMME**.
|
||||||
|
|
||||||
|
The workflow is:
|
||||||
|
|
||||||
|
1. Pick one or more benchmarks.
|
||||||
|
2. For each benchmark, train a policy on its combined dataset (multi-GPU).
|
||||||
|
3. Upload the trained policy to the Hugging Face Hub.
|
||||||
|
4. Evaluate the policy on every task suite within that benchmark.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Install the benchmark-specific dependencies for the environments you want to evaluate on:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LIBERO (original)
|
||||||
|
pip install -e ".[libero]"
|
||||||
|
|
||||||
|
# LIBERO-plus
|
||||||
|
pip install -e ".[libero_plus]"
|
||||||
|
|
||||||
|
# MetaWorld
|
||||||
|
pip install -e ".[metaworld]"
|
||||||
|
|
||||||
|
# RoboCasa
|
||||||
|
pip install -e ".[robocasa]"
|
||||||
|
|
||||||
|
# RoboMME
|
||||||
|
pip install -e ".[robomme]"
|
||||||
|
```
|
||||||
|
|
||||||
|
`libero_plus` includes the same EGL probe dependencies as `libero` so headless
|
||||||
|
renderer setup is consistent between both installs.
|
||||||
|
|
||||||
|
If your environment has CMake build-isolation issues, use the same fallback as
|
||||||
|
standard LIBERO installs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PATH=/usr/bin:/bin:$PATH pip install --no-build-isolation -e ".[libero-plus]"
|
||||||
|
```
|
||||||
|
|
||||||
|
For multi-GPU training you also need [Accelerate](https://huggingface.co/docs/accelerate):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install accelerate
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick start — single benchmark
|
||||||
|
|
||||||
|
Train SmolVLA on LIBERO-plus with 4 GPUs for 50 000 steps:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-benchmark train \
|
||||||
|
--benchmarks libero_plus \
|
||||||
|
--policy-path lerobot/smolvla_base \
|
||||||
|
--hub-user $HF_USER \
|
||||||
|
--num-gpus 4 \
|
||||||
|
--steps 50000 \
|
||||||
|
--batch-size 32 \
|
||||||
|
--wandb
|
||||||
|
```
|
||||||
|
|
||||||
|
This trains on the combined LIBERO-plus dataset and pushes the checkpoint to
|
||||||
|
`$HF_USER/smolvla_libero_plus` on the Hub.
|
||||||
|
|
||||||
|
Then evaluate on **all four** LIBERO suites (spatial, object, goal, 10):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-benchmark eval \
|
||||||
|
--benchmarks libero_plus \
|
||||||
|
--hub-user $HF_USER \
|
||||||
|
--n-episodes 50
|
||||||
|
```
|
||||||
|
|
||||||
|
This automatically runs a separate `lerobot-eval` for each suite.
|
||||||
|
|
||||||
|
## Full sweep — multiple benchmarks
|
||||||
|
|
||||||
|
Run training **and** evaluation across all benchmarks:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-benchmark all \
|
||||||
|
--benchmarks libero,libero_plus,metaworld,robocasa,robomme \
|
||||||
|
--policy-path lerobot/smolvla_base \
|
||||||
|
--hub-user $HF_USER \
|
||||||
|
--num-gpus 4 \
|
||||||
|
--steps 50000 \
|
||||||
|
--batch-size 32 \
|
||||||
|
--wandb \
|
||||||
|
--push-eval-to-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
For each benchmark the runner:
|
||||||
|
1. Trains a policy on its dataset.
|
||||||
|
2. Evaluates on every eval task in the benchmark (e.g. 4 suites for LIBERO).
|
||||||
|
3. Uploads eval results + videos to the Hub.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Use `--dry-run` to print the exact `lerobot-train` / `lerobot-eval` commands without executing them, so you can inspect or modify them before running.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
## Using the CLI directly (without the benchmark runner)
|
||||||
|
|
||||||
|
You can also compose the commands yourself. The benchmark runner is a thin wrapper; here is what it does under the hood.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch \
|
||||||
|
--multi_gpu \
|
||||||
|
--num_processes=4 \
|
||||||
|
$(which lerobot-train) \
|
||||||
|
--policy.path=lerobot/smolvla_base \
|
||||||
|
--dataset.repo_id=$HF_USER/libero_plus \
|
||||||
|
--policy.repo_id=$HF_USER/smolvla_libero_plus \
|
||||||
|
--env.type=libero_plus \
|
||||||
|
--env.task=libero_spatial \
|
||||||
|
--steps=50000 \
|
||||||
|
--batch_size=32 \
|
||||||
|
--eval_freq=10000 \
|
||||||
|
--save_freq=10000 \
|
||||||
|
--output_dir=outputs/train/smolvla_libero_plus \
|
||||||
|
--job_name=smolvla_libero_plus \
|
||||||
|
--policy.push_to_hub=true \
|
||||||
|
--wandb.enable=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation (run once per suite)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
for SUITE in libero_spatial libero_object libero_goal libero_10; do
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=$HF_USER/smolvla_libero_plus \
|
||||||
|
--env.type=libero_plus \
|
||||||
|
--env.task=$SUITE \
|
||||||
|
--eval.n_episodes=50 \
|
||||||
|
--eval.batch_size=10 \
|
||||||
|
--output_dir=outputs/eval/smolvla_libero_plus/$SUITE \
|
||||||
|
--policy.device=cuda
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available benchmarks
|
||||||
|
|
||||||
|
| Benchmark | Env type | Dataset | Eval tasks | Action dim |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| `libero` | `libero` | `{hub_user}/libero` | spatial, object, goal, 10 | 7 |
|
||||||
|
| `libero_plus` | `libero_plus` | `{hub_user}/libero_plus` | spatial, object, goal, 10 | 7 |
|
||||||
|
| `metaworld` | `metaworld` | `{hub_user}/metaworld` | push-v2 | 4 |
|
||||||
|
| `robocasa` | `robocasa` | `{hub_user}/robocasa` | PickPlaceCounterToCabinet | 12 |
|
||||||
|
| `robomme` | `robomme` | `{hub_user}/robomme` | PickXtimes | 8 |
|
||||||
|
|
||||||
|
Run `lerobot-benchmark list` to see the full registry with all eval tasks.
|
||||||
|
|
||||||
|
## Policy naming convention
|
||||||
|
|
||||||
|
The benchmark runner stores trained policies under:
|
||||||
|
|
||||||
|
```
|
||||||
|
{hub_user}/{policy_name}_{benchmark}
|
||||||
|
```
|
||||||
|
|
||||||
|
The default `--policy-name` is `smolvla`. So training on `libero_plus` as user `alice` produces `alice/smolvla_libero_plus`.
|
||||||
|
|
||||||
|
You can override this, e.g. `--policy-name pi05` if training π₀.₅ instead.
|
||||||
|
|
||||||
|
## Multi-GPU considerations
|
||||||
|
|
||||||
|
The effective batch size is `batch_size × num_gpus`. With `--batch-size=32` and
|
||||||
|
`--num-gpus=4`, you train with an effective batch of 128 per step. LeRobot does **not**
|
||||||
|
auto-scale the learning rate; see the [Multi-GPU Training guide](./multi_gpu_training) for
|
||||||
|
details on when and how to adjust it.
|
||||||
|
|
||||||
|
## Custom benchmarks
|
||||||
|
|
||||||
|
To add a new benchmark, edit the `BENCHMARK_REGISTRY` in
|
||||||
|
`src/lerobot/scripts/lerobot_benchmark.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.scripts.lerobot_benchmark import BenchmarkEntry, BENCHMARK_REGISTRY
|
||||||
|
|
||||||
|
BENCHMARK_REGISTRY["my_benchmark"] = BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/my_dataset",
|
||||||
|
env_type="my_env",
|
||||||
|
env_task="MyDefaultTask",
|
||||||
|
eval_tasks=["TaskA", "TaskB", "TaskC"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use `--benchmarks my_benchmark` as usual. The runner will train once and
|
||||||
|
evaluate separately on TaskA, TaskB, and TaskC.
|
||||||
|
|
||||||
|
## Outputs
|
||||||
|
|
||||||
|
After training and evaluation, your outputs directory looks like:
|
||||||
|
|
||||||
|
```
|
||||||
|
outputs/
|
||||||
|
├── train/
|
||||||
|
│ ├── smolvla_libero/
|
||||||
|
│ │ ├── checkpoints/
|
||||||
|
│ │ └── ...
|
||||||
|
│ ├── smolvla_libero_plus/
|
||||||
|
│ ├── smolvla_robocasa/
|
||||||
|
│ └── smolvla_robomme/
|
||||||
|
└── eval/
|
||||||
|
├── smolvla_libero/
|
||||||
|
│ ├── libero_spatial/
|
||||||
|
│ │ ├── eval_info.json
|
||||||
|
│ │ └── videos/
|
||||||
|
│ ├── libero_object/
|
||||||
|
│ ├── libero_goal/
|
||||||
|
│ └── libero_10/
|
||||||
|
├── smolvla_libero_plus/
|
||||||
|
│ ├── libero_spatial/
|
||||||
|
│ ├── libero_object/
|
||||||
|
│ ├── libero_goal/
|
||||||
|
│ └── libero_10/
|
||||||
|
├── smolvla_robocasa/
|
||||||
|
└── smolvla_robomme/
|
||||||
|
```
|
||||||
|
|
||||||
|
Each `eval_info.json` contains per-episode rewards, success rates, and aggregate metrics.
|
||||||
|
|
||||||
|
## Uploading eval results to the Hub
|
||||||
|
|
||||||
|
Add `--push-eval-to-hub` to upload evaluation metrics and videos to the policy's
|
||||||
|
Hub repo after each eval run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-benchmark eval \
|
||||||
|
--benchmarks libero_plus,robocasa \
|
||||||
|
--hub-user $HF_USER \
|
||||||
|
--push-eval-to-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
For LIBERO-plus, each suite's results are uploaded to `eval/libero_spatial/`,
|
||||||
|
`eval/libero_object/`, etc. inside the `$HF_USER/smolvla_libero_plus` model repo.
|
||||||
|
|
||||||
|
This also works with the `all` subcommand — pass `--push-eval-to-hub` and results
|
||||||
|
are automatically uploaded after each eval run.
|
||||||
|
|
||||||
|
## Passing extra arguments
|
||||||
|
|
||||||
|
Any arguments after the recognized flags are forwarded to `lerobot-train` or
|
||||||
|
`lerobot-eval`. For example, to use PEFT/LoRA during training:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-benchmark train \
|
||||||
|
--benchmarks libero_plus \
|
||||||
|
--policy-path lerobot/smolvla_base \
|
||||||
|
--hub-user $HF_USER \
|
||||||
|
--num-gpus 4 \
|
||||||
|
--steps 50000 \
|
||||||
|
--peft.method_type=LORA --peft.r=16
|
||||||
|
```
|
||||||
@@ -177,9 +177,12 @@ pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk v
|
|||||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||||
libero_plus = [
|
libero_plus = [
|
||||||
"lerobot[transformers-dep]",
|
"lerobot[transformers-dep]",
|
||||||
|
"hf-egl-probe>=1.0.1; sys_platform == 'linux'",
|
||||||
|
"egl_probe>=1.0.1; sys_platform == 'linux'",
|
||||||
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
|
||||||
"lerobot[scipy-dep]",
|
"lerobot[scipy-dep]",
|
||||||
]
|
]
|
||||||
|
libero-plus = ["lerobot[libero_plus]"]
|
||||||
robomme = [
|
robomme = [
|
||||||
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
|
||||||
]
|
]
|
||||||
@@ -236,6 +239,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
|||||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
lerobot-benchmark="lerobot.scripts.lerobot_benchmark:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
|
|||||||
@@ -0,0 +1,689 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
|
||||||
|
|
||||||
|
Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
|
||||||
|
atomic unit, tagged by the SARM progress score at its start frame. For each
|
||||||
|
progress band, compares the full vs HQ dataset on:
|
||||||
|
|
||||||
|
1. Intra-band action variance
|
||||||
|
2. Progress delta per chunk
|
||||||
|
3. GMM + BIC optimal K (number of distinct strategies)
|
||||||
|
4. PCA embedding (visual cluster inspection)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python chunk_multimodality_analysis.py \\
|
||||||
|
--full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
|
||||||
|
--hq-dataset lerobot-data-collection/level2_final_quality3 \\
|
||||||
|
--output-dir ./chunk_analysis
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from scipy.stats import gaussian_kde
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.mixture import GaussianMixture
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Visual style ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
BG = "#0e1117"
|
||||||
|
CARD = "#1a1d27"
|
||||||
|
BORDER = "#2a2d3a"
|
||||||
|
SUB = "#8b8fa8"
|
||||||
|
TEXT = "#e8eaf0"
|
||||||
|
C_FULL = "#f7934f"
|
||||||
|
C_HQ = "#4dc98a"
|
||||||
|
|
||||||
|
|
||||||
|
def _style_ax(ax: plt.Axes) -> None:
|
||||||
|
ax.set_facecolor(CARD)
|
||||||
|
ax.tick_params(colors=SUB, labelsize=8)
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_color(BORDER)
|
||||||
|
|
||||||
|
|
||||||
|
def _save(fig: plt.Figure, path: Path) -> None:
|
||||||
|
fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG)
|
||||||
|
plt.close(fig)
|
||||||
|
logger.info("Saved %s", path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 0: Load episodes ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_sarm_progress(repo_id: str) -> pd.DataFrame | None:
|
||||||
|
"""Try to download sarm_progress.parquet from the Hub."""
|
||||||
|
try:
|
||||||
|
path = hf_hub_download(
|
||||||
|
repo_id=repo_id, filename="sarm_progress.parquet",
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
|
df = pd.read_parquet(path)
|
||||||
|
col = "progress_sparse" if "progress_sparse" in df.columns else "progress_dense"
|
||||||
|
if col not in df.columns:
|
||||||
|
logger.warning("sarm_progress.parquet has no progress columns — ignoring")
|
||||||
|
return None
|
||||||
|
logger.info("Loaded SARM progress (%s) for %s (%d rows)", col, repo_id, len(df))
|
||||||
|
return df.rename(columns={col: "progress"})[["episode_index", "frame_index", "progress"]]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not load sarm_progress.parquet for %s: %s", repo_id, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_episodes(
|
||||||
|
repo_id: str,
|
||||||
|
n_joints: int = 16,
|
||||||
|
max_episodes: int | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
dataset = LeRobotDataset(repo_id, download_videos=False)
|
||||||
|
raw = dataset.hf_dataset
|
||||||
|
|
||||||
|
sarm_df = _load_sarm_progress(repo_id)
|
||||||
|
# Build per-episode progress arrays from SARM parquet (indexed by frame_index)
|
||||||
|
sarm_by_ep: dict[int, dict[int, float]] = {}
|
||||||
|
if sarm_df is not None:
|
||||||
|
if max_episodes is not None:
|
||||||
|
sarm_df = sarm_df[sarm_df["episode_index"] < max_episodes]
|
||||||
|
for ep_id, grp in sarm_df.groupby("episode_index"):
|
||||||
|
sarm_by_ep[int(ep_id)] = dict(
|
||||||
|
zip(grp["frame_index"].astype(int), grp["progress"].astype(float))
|
||||||
|
)
|
||||||
|
|
||||||
|
episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
|
||||||
|
for row in raw:
|
||||||
|
ep = int(row["episode_index"])
|
||||||
|
if max_episodes is not None and ep >= max_episodes:
|
||||||
|
continue
|
||||||
|
action = np.array(row["action"], dtype=np.float32)[:n_joints]
|
||||||
|
episodes[ep]["actions"].append(action)
|
||||||
|
fi = int(row["frame_index"])
|
||||||
|
ep_prog = sarm_by_ep.get(ep, {})
|
||||||
|
episodes[ep]["progress"].append(ep_prog.get(fi, float("nan")))
|
||||||
|
|
||||||
|
has_sarm = len(sarm_lookup) > 0
|
||||||
|
result = []
|
||||||
|
for ep_id, d in sorted(episodes.items()):
|
||||||
|
actions = np.stack(d["actions"])
|
||||||
|
T = len(actions)
|
||||||
|
if has_sarm:
|
||||||
|
prog = np.array(d["progress"], dtype=np.float32)
|
||||||
|
prog = np.clip(np.nan_to_num(prog, nan=0.0), 0.0, 1.0)
|
||||||
|
prog = np.maximum.accumulate(prog)
|
||||||
|
else:
|
||||||
|
prog = np.linspace(0.0, 1.0, T, dtype=np.float32)
|
||||||
|
result.append({"episode": ep_id, "actions": actions, "progress": prog})
|
||||||
|
|
||||||
|
src = "SARM" if has_sarm else "time-based"
|
||||||
|
logger.info("Progress source: %s", src)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: Filter short episodes ────────────────────────────────────────
|
||||||
|
|
||||||
|
def auto_length_threshold(
|
||||||
|
episodes_full: list[dict], episodes_hq: list[dict]
|
||||||
|
) -> int:
|
||||||
|
all_lengths = np.array(
|
||||||
|
[e["actions"].shape[0] for e in episodes_full + episodes_hq]
|
||||||
|
)
|
||||||
|
kde = gaussian_kde(all_lengths, bw_method=0.25)
|
||||||
|
xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300)
|
||||||
|
return int(xs[np.argmin(kde(xs))])
|
||||||
|
|
||||||
|
|
||||||
|
def plot_length_distribution(
|
||||||
|
episodes_full: list[dict],
|
||||||
|
episodes_hq: list[dict],
|
||||||
|
threshold: int,
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
|
||||||
|
lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
|
||||||
|
all_lens = np.concatenate([lens_full, lens_hq])
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 5))
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
_style_ax(ax)
|
||||||
|
|
||||||
|
bins = np.linspace(all_lens.min(), all_lens.max(), 50)
|
||||||
|
ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
|
||||||
|
ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")
|
||||||
|
|
||||||
|
xs = np.linspace(all_lens.min(), all_lens.max(), 300)
|
||||||
|
kde = gaussian_kde(all_lens, bw_method=0.25)
|
||||||
|
ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)")
|
||||||
|
|
||||||
|
ax.axvline(threshold, color="#ff4b4b", ls="--", lw=1.5, label=f"Threshold = {threshold}")
|
||||||
|
ax.set_xlabel("Episode length (frames)", color=SUB)
|
||||||
|
ax.set_ylabel("Count", color=SUB)
|
||||||
|
ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13)
|
||||||
|
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]:
|
||||||
|
kept = [e for e in episodes if e["actions"].shape[0] >= min_length]
|
||||||
|
logger.info("Kept %d / %d episodes (min_length=%d)", len(kept), len(episodes), min_length)
|
||||||
|
return kept
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 2: Extract chunks ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def extract_chunks(
|
||||||
|
episodes: list[dict],
|
||||||
|
chunk_size: int = 30,
|
||||||
|
chunk_stride: int = 15,
|
||||||
|
) -> list[dict]:
|
||||||
|
chunks = []
|
||||||
|
for ep in episodes:
|
||||||
|
actions = ep["actions"]
|
||||||
|
T = len(actions)
|
||||||
|
prog = ep["progress"]
|
||||||
|
|
||||||
|
for t in range(0, T - chunk_size, chunk_stride):
|
||||||
|
chunk = actions[t : t + chunk_size]
|
||||||
|
p_start = float(prog[t])
|
||||||
|
p_end = float(prog[min(t + chunk_size, T - 1)])
|
||||||
|
|
||||||
|
chunks.append({
|
||||||
|
"action_mean": chunk.mean(axis=0).astype(np.float32),
|
||||||
|
"action_flat": chunk.flatten().astype(np.float32),
|
||||||
|
"progress_start": p_start,
|
||||||
|
"progress_delta": p_end - p_start,
|
||||||
|
"episode": ep["episode"],
|
||||||
|
})
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 3: Adaptive progress bands ─────────────────────────────────────
|
||||||
|
|
||||||
|
def make_bands(n_bands: int = 5) -> list[tuple[float, float]]:
|
||||||
|
edges = np.linspace(0.0, 1.0, n_bands + 1)
|
||||||
|
return [(float(edges[i]), float(edges[i + 1])) for i in range(n_bands)]
|
||||||
|
|
||||||
|
|
||||||
|
def assign_bands(
|
||||||
|
chunks: list[dict], band_edges: list[tuple[float, float]]
|
||||||
|
) -> list[dict]:
|
||||||
|
n = len(band_edges)
|
||||||
|
for c in chunks:
|
||||||
|
p = c["progress_start"]
|
||||||
|
c["band"] = next(
|
||||||
|
(bi for bi, (lo, hi) in enumerate(band_edges) if p < hi),
|
||||||
|
n - 1,
|
||||||
|
)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]:
|
||||||
|
out: dict[int, list[dict]] = {b: [] for b in range(n_bands)}
|
||||||
|
for c in chunks:
|
||||||
|
out[c["band"]].append(c)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 4: Intra-band action variance ──────────────────────────────────
|
||||||
|
|
||||||
|
def band_variance_matrix(
|
||||||
|
bands: dict[int, list[dict]], n_bands: int, n_joints: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
var_mat = np.full((n_bands, n_joints), np.nan)
|
||||||
|
for b, clist in bands.items():
|
||||||
|
if len(clist) < 3:
|
||||||
|
continue
|
||||||
|
means = np.stack([c["action_mean"] for c in clist])
|
||||||
|
var_mat[b] = np.var(means, axis=0)
|
||||||
|
return var_mat
|
||||||
|
|
||||||
|
|
||||||
|
def plot_variance_heatmap(
|
||||||
|
var_full: np.ndarray,
|
||||||
|
var_hq: np.ndarray,
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
n_bands = var_full.shape[0]
|
||||||
|
vmin = 0.0
|
||||||
|
vmax = max(np.nanmax(var_full), np.nanmax(var_hq))
|
||||||
|
|
||||||
|
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||||
|
joint_labels = [f"J{j}" for j in range(var_full.shape[1])]
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]})
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98)
|
||||||
|
|
||||||
|
for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]):
|
||||||
|
ax = axes[ax_idx]
|
||||||
|
_style_ax(ax)
|
||||||
|
im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax)
|
||||||
|
ax.set_yticks(range(n_bands))
|
||||||
|
ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
|
||||||
|
ax.set_xticks(range(var_full.shape[1]))
|
||||||
|
ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
|
||||||
|
ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11)
|
||||||
|
fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
|
||||||
|
|
||||||
|
with np.errstate(invalid="ignore"):
|
||||||
|
mean_full = np.nanmean(var_full, axis=1)
|
||||||
|
mean_hq = np.nanmean(var_hq, axis=1)
|
||||||
|
ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
|
||||||
|
mean_full / (mean_hq + 1e-8))
|
||||||
|
ax_bar = axes[2]
|
||||||
|
_style_ax(ax_bar)
|
||||||
|
colors = [
|
||||||
|
"#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ
|
||||||
|
for r in ratio
|
||||||
|
]
|
||||||
|
ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER)
|
||||||
|
ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
|
||||||
|
ax_bar.set_xticks(range(n_bands))
|
||||||
|
ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
|
||||||
|
ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
|
||||||
|
ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)
|
||||||
|
|
||||||
|
fig.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 5: Progress delta per band ──────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_progress_delta(
|
||||||
|
bands_full: dict[int, list[dict]],
|
||||||
|
bands_hq: dict[int, list[dict]],
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
n_bands = len(band_edges)
|
||||||
|
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||||
|
x = np.arange(n_bands)
|
||||||
|
w = 0.35
|
||||||
|
|
||||||
|
means_full, stds_full = [], []
|
||||||
|
means_hq, stds_hq = [], []
|
||||||
|
all_deltas_full, all_deltas_hq = [], []
|
||||||
|
|
||||||
|
for b in range(n_bands):
|
||||||
|
df = np.array([c["progress_delta"] for c in bands_full.get(b, [])])
|
||||||
|
dh = np.array([c["progress_delta"] for c in bands_hq.get(b, [])])
|
||||||
|
means_full.append(np.mean(df) if len(df) > 0 else 0)
|
||||||
|
stds_full.append(np.std(df) if len(df) > 0 else 0)
|
||||||
|
means_hq.append(np.mean(dh) if len(dh) > 0 else 0)
|
||||||
|
stds_hq.append(np.std(dh) if len(dh) > 0 else 0)
|
||||||
|
all_deltas_full.extend(df.tolist())
|
||||||
|
all_deltas_hq.extend(dh.tolist())
|
||||||
|
|
||||||
|
fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)
|
||||||
|
|
||||||
|
_style_ax(ax_bar)
|
||||||
|
ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER,
|
||||||
|
capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB})
|
||||||
|
ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER,
|
||||||
|
capsize=3, label="HQ", error_kw={"ecolor": SUB})
|
||||||
|
ax_bar.set_xticks(x)
|
||||||
|
ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||||
|
ax_bar.set_ylabel("Mean progress Δ", color=SUB)
|
||||||
|
ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||||
|
|
||||||
|
_style_ax(ax_viol)
|
||||||
|
data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)]
|
||||||
|
if all(len(d) > 0 for d in data_viol):
|
||||||
|
parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True)
|
||||||
|
for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]):
|
||||||
|
pc.set_facecolor(c)
|
||||||
|
pc.set_alpha(0.7)
|
||||||
|
for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
|
||||||
|
if key in parts:
|
||||||
|
parts[key].set_color(SUB)
|
||||||
|
ax_viol.set_xticks([0, 1])
|
||||||
|
ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
|
||||||
|
ax_viol.set_ylabel("Progress Δ", color=SUB)
|
||||||
|
ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 6: GMM + BIC per band ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def gmm_optimal_k(
|
||||||
|
band_chunks: list[dict],
|
||||||
|
pca_components: int = 15,
|
||||||
|
max_k: int = 12,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> int | None:
|
||||||
|
if len(band_chunks) < 20:
|
||||||
|
return None
|
||||||
|
X = np.stack([c["action_flat"] for c in band_chunks])
|
||||||
|
X = StandardScaler().fit_transform(X)
|
||||||
|
n = min(pca_components, X.shape[1], X.shape[0] - 1)
|
||||||
|
X_r = PCA(n_components=n, random_state=seed).fit_transform(X)
|
||||||
|
bics = []
|
||||||
|
for k in range(1, min(max_k + 1, len(X_r) // 6)):
|
||||||
|
gmm = GaussianMixture(
|
||||||
|
n_components=k, covariance_type="full",
|
||||||
|
n_init=5, max_iter=300, random_state=seed,
|
||||||
|
)
|
||||||
|
gmm.fit(X_r)
|
||||||
|
bics.append((k, gmm.bic(X_r)))
|
||||||
|
if not bics:
|
||||||
|
return None
|
||||||
|
return min(bics, key=lambda x: x[1])[0]
|
||||||
|
|
||||||
|
|
||||||
|
def plot_gmm_bic(
|
||||||
|
bands_full: dict[int, list[dict]],
|
||||||
|
bands_hq: dict[int, list[dict]],
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
seed: int,
|
||||||
|
out_path: Path,
|
||||||
|
) -> tuple[list[int | None], list[int | None]]:
|
||||||
|
n_bands = len(band_edges)
|
||||||
|
ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
|
||||||
|
ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]
|
||||||
|
|
||||||
|
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 5))
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
_style_ax(ax)
|
||||||
|
|
||||||
|
xs = np.arange(n_bands)
|
||||||
|
valid_full = [(i, k) for i, k in enumerate(ks_full) if k is not None]
|
||||||
|
valid_hq = [(i, k) for i, k in enumerate(ks_hq) if k is not None]
|
||||||
|
|
||||||
|
if valid_full:
|
||||||
|
xi, yi = zip(*valid_full)
|
||||||
|
ax.plot(xi, yi, "o-", color=C_FULL, label="Full/Mixed", lw=2, markersize=7)
|
||||||
|
if valid_hq:
|
||||||
|
xi, yi = zip(*valid_hq)
|
||||||
|
ax.plot(xi, yi, "o-", color=C_HQ, label="HQ", lw=2, markersize=7)
|
||||||
|
|
||||||
|
if valid_full and valid_hq:
|
||||||
|
all_x = sorted(set([i for i, _ in valid_full]) & set([i for i, _ in valid_hq]))
|
||||||
|
if len(all_x) >= 2:
|
||||||
|
kf_interp = {i: k for i, k in valid_full}
|
||||||
|
kh_interp = {i: k for i, k in valid_hq}
|
||||||
|
shared_x = [i for i in all_x if i in kf_interp and i in kh_interp]
|
||||||
|
yf = [kf_interp[i] for i in shared_x]
|
||||||
|
yh = [kh_interp[i] for i in shared_x]
|
||||||
|
ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT)
|
||||||
|
|
||||||
|
ax.set_xticks(xs)
|
||||||
|
ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||||
|
ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
|
||||||
|
ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
|
||||||
|
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
|
||||||
|
ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
|
||||||
|
fig.tight_layout()
|
||||||
|
_save(fig, out_path)
|
||||||
|
return ks_full, ks_hq
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 7: PCA scatter per band ────────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_pca_scatter(
|
||||||
|
bands_full: dict[int, list[dict]],
|
||||||
|
bands_hq: dict[int, list[dict]],
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
n_plot = min(4, len(band_edges))
|
||||||
|
fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)
|
||||||
|
|
||||||
|
if n_plot == 1:
|
||||||
|
axes = axes.reshape(2, 1)
|
||||||
|
|
||||||
|
for col, b in enumerate(range(n_plot)):
|
||||||
|
cf = bands_full.get(b, [])
|
||||||
|
ch = bands_hq.get(b, [])
|
||||||
|
lo, hi = band_edges[b]
|
||||||
|
|
||||||
|
for row, (clist, color, label) in enumerate([
|
||||||
|
(cf, C_FULL, "Full/Mixed"), (ch, C_HQ, "HQ")
|
||||||
|
]):
|
||||||
|
ax = axes[row, col]
|
||||||
|
_style_ax(ax)
|
||||||
|
if row == 0:
|
||||||
|
ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
|
||||||
|
if col == 0:
|
||||||
|
ax.set_ylabel(label, color=SUB, fontsize=9)
|
||||||
|
|
||||||
|
if len(cf) < 3 or len(ch) < 3:
|
||||||
|
ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes,
|
||||||
|
ha="center", va="center", color=SUB, fontsize=9)
|
||||||
|
continue
|
||||||
|
|
||||||
|
X_full_b = np.stack([c["action_flat"] for c in cf])
|
||||||
|
X_hq_b = np.stack([c["action_flat"] for c in ch])
|
||||||
|
X_all = np.vstack([X_full_b, X_hq_b])
|
||||||
|
X_all = StandardScaler().fit_transform(X_all)
|
||||||
|
X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all)
|
||||||
|
|
||||||
|
X_2d_full = X_2d[: len(cf)]
|
||||||
|
X_2d_hq = X_2d[len(cf) :]
|
||||||
|
|
||||||
|
pts = X_2d_full if row == 0 else X_2d_hq
|
||||||
|
ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none")
|
||||||
|
|
||||||
|
fig.tight_layout(rect=[0, 0, 1, 0.95])
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plot 1: Chunk counts per band ───────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_chunk_counts(
|
||||||
|
bands_full: dict[int, list[dict]],
|
||||||
|
bands_hq: dict[int, list[dict]],
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
n_bands = len(band_edges)
|
||||||
|
band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
|
||||||
|
x = np.arange(n_bands)
|
||||||
|
w = 0.35
|
||||||
|
|
||||||
|
counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)]
|
||||||
|
counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 5))
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
_style_ax(ax)
|
||||||
|
|
||||||
|
ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed")
|
||||||
|
ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ")
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
|
||||||
|
ax.set_ylabel("Chunk count", color=SUB)
|
||||||
|
ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13)
|
||||||
|
ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
|
||||||
|
fig.tight_layout()
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Summary figure ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def plot_summary(
|
||||||
|
var_full: np.ndarray,
|
||||||
|
var_hq: np.ndarray,
|
||||||
|
band_edges: list[tuple[float, float]],
|
||||||
|
ks_full: list[int | None],
|
||||||
|
ks_hq: list[int | None],
|
||||||
|
bands_full: dict[int, list[dict]],
|
||||||
|
bands_hq: dict[int, list[dict]],
|
||||||
|
out_path: Path,
|
||||||
|
) -> None:
|
||||||
|
with np.errstate(invalid="ignore"):
|
||||||
|
mean_full = np.nanmean(var_full, axis=1)
|
||||||
|
mean_hq = np.nanmean(var_hq, axis=1)
|
||||||
|
ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
|
||||||
|
mean_full / (mean_hq + 1e-8))
|
||||||
|
valid_ratio = ratio[~np.isnan(ratio)]
|
||||||
|
mean_ratio = float(np.mean(valid_ratio)) if len(valid_ratio) > 0 else float("nan")
|
||||||
|
peak_idx = int(np.argmax(valid_ratio)) if len(valid_ratio) > 0 else 0
|
||||||
|
peak_ratio = float(valid_ratio[peak_idx]) if len(valid_ratio) > 0 else float("nan")
|
||||||
|
lo, hi = band_edges[peak_idx]
|
||||||
|
peak_band = f"{lo:.0%}–{hi:.0%}"
|
||||||
|
|
||||||
|
valid_kf = [k for k in ks_full if k is not None]
|
||||||
|
valid_kh = [k for k in ks_hq if k is not None]
|
||||||
|
mean_k_full = np.mean(valid_kf) if valid_kf else float("nan")
|
||||||
|
mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan")
|
||||||
|
|
||||||
|
n_bands = len(band_edges)
|
||||||
|
deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])]
|
||||||
|
deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])]
|
||||||
|
mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan")
|
||||||
|
mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan")
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"),
|
||||||
|
("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"),
|
||||||
|
("Mean GMM K — Full", f"{mean_k_full:.1f}"),
|
||||||
|
("Mean GMM K — HQ", f"{mean_k_hq:.1f}"),
|
||||||
|
("Mean progress Δ — Full", f"{mean_delta_full:.4f}"),
|
||||||
|
("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 3))
|
||||||
|
fig.patch.set_facecolor(BG)
|
||||||
|
ax.set_facecolor(CARD)
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
|
table = ax.table(
|
||||||
|
cellText=[[m, v] for m, v in rows],
|
||||||
|
colLabels=["Metric", "Value"],
|
||||||
|
loc="center",
|
||||||
|
cellLoc="left",
|
||||||
|
)
|
||||||
|
table.auto_set_font_size(False)
|
||||||
|
table.set_fontsize(10)
|
||||||
|
for key, cell in table.get_celld().items():
|
||||||
|
cell.set_edgecolor(BORDER)
|
||||||
|
cell.set_facecolor(CARD)
|
||||||
|
cell.set_text_props(color=TEXT)
|
||||||
|
if key[0] == 0:
|
||||||
|
cell.set_text_props(color=TEXT, fontweight="bold")
|
||||||
|
table.scale(1, 1.6)
|
||||||
|
ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)
|
||||||
|
fig.tight_layout()
|
||||||
|
_save(fig, out_path)
|
||||||
|
|
||||||
|
for metric, value in rows:
|
||||||
|
logger.info(" %s: %s", metric, value)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace) -> None:
|
||||||
|
out = Path(args.output_dir)
|
||||||
|
out.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("Loading FULL dataset: %s", args.full_dataset)
|
||||||
|
episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
|
||||||
|
logger.info("Loading HQ dataset: %s", args.hq_dataset)
|
||||||
|
episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
|
||||||
|
logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))
|
||||||
|
|
||||||
|
# Step 1: length threshold + filter
|
||||||
|
if args.min_episode_length is not None:
|
||||||
|
threshold = args.min_episode_length
|
||||||
|
else:
|
||||||
|
threshold = auto_length_threshold(episodes_full, episodes_hq)
|
||||||
|
logger.info("Episode length threshold: %d", threshold)
|
||||||
|
|
||||||
|
plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png")
|
||||||
|
episodes_full = filter_episodes(episodes_full, threshold)
|
||||||
|
episodes_hq = filter_episodes(episodes_hq, threshold)
|
||||||
|
|
||||||
|
# Step 2: extract chunks
|
||||||
|
chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
|
||||||
|
chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
|
||||||
|
logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))
|
||||||
|
|
||||||
|
# Step 3: fixed equal-width bands over episode-relative progress
|
||||||
|
band_edges = make_bands(args.n_bands)
|
||||||
|
n_bands = len(band_edges)
|
||||||
|
logger.info("Progress bands (%d): %s", n_bands,
|
||||||
|
[f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges])
|
||||||
|
|
||||||
|
chunks_full = assign_bands(chunks_full, band_edges)
|
||||||
|
chunks_hq = assign_bands(chunks_hq, band_edges)
|
||||||
|
bands_full = split_by_band(chunks_full, n_bands)
|
||||||
|
bands_hq = split_by_band(chunks_hq, n_bands)
|
||||||
|
|
||||||
|
# Plot 1: chunk counts
|
||||||
|
plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png")
|
||||||
|
|
||||||
|
# Step 4: variance heatmap
|
||||||
|
var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
|
||||||
|
var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
|
||||||
|
plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png")
|
||||||
|
|
||||||
|
# Step 5: progress delta
|
||||||
|
plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png")
|
||||||
|
|
||||||
|
# Step 6: GMM BIC
|
||||||
|
ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png")
|
||||||
|
|
||||||
|
# Step 7: PCA scatter
|
||||||
|
plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png")
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq,
|
||||||
|
bands_full, bands_hq, out / "6_summary.png")
|
||||||
|
|
||||||
|
logger.info("All figures saved to %s", out)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
p.add_argument("--full-dataset", default="lerobot-data-collection/level12_rac_2_2026-02-08_1")
|
||||||
|
p.add_argument("--hq-dataset", default="lerobot-data-collection/level2_final_quality3_trim_0_hil_data")
|
||||||
|
p.add_argument("--output-dir", default="./chunk_analysis")
|
||||||
|
p.add_argument("--chunk-size", type=int, default=30)
|
||||||
|
p.add_argument("--chunk-stride", type=int, default=15)
|
||||||
|
p.add_argument("--n-bands", type=int, default=5, help="Number of equal-width progress bands")
|
||||||
|
p.add_argument("--max-episodes", type=int, default=500)
|
||||||
|
p.add_argument("--n-joints", type=int, default=16)
|
||||||
|
p.add_argument("--min-episode-length", type=int, default=None,
|
||||||
|
help="Override auto-detected length filter threshold")
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
args = p.parse_args()
|
||||||
|
main(args)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#SBATCH --job-name=smolvla_libero_plus
|
||||||
|
#SBATCH --partition=hopper-prod
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --ntasks-per-node=1
|
||||||
|
#SBATCH --gpus-per-node=4
|
||||||
|
#SBATCH --cpus-per-task=48
|
||||||
|
#SBATCH --mem=200G
|
||||||
|
#SBATCH --time=12:00:00
|
||||||
|
#SBATCH --output=logs/smolvla_libero_plus_%j.out
|
||||||
|
#SBATCH --error=logs/smolvla_libero_plus_%j.err
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
conda activate lerobot312
|
||||||
|
|
||||||
|
cd /admin/home/pepijn/lerobot_wt_robocasa
|
||||||
|
|
||||||
|
lerobot-benchmark train \
|
||||||
|
--benchmarks libero_plus \
|
||||||
|
--policy-path lerobot/smolvla_base \
|
||||||
|
--hub-user pepijn223 \
|
||||||
|
--num-gpus 4 \
|
||||||
|
--steps 30000 \
|
||||||
|
--batch-size 32 \
|
||||||
|
--eval-freq 0 \
|
||||||
|
--wandb \
|
||||||
|
--dataset.repo_id=pepijn223/libero_plus_lerobot
|
||||||
+92
-12
@@ -16,6 +16,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -28,12 +29,51 @@ import torch
|
|||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from libero.libero import benchmark, get_libero_path
|
import libero as _libero_pkg # noqa: F401
|
||||||
from libero.libero.envs import OffScreenRenderEnv
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# LIBERO-plus may be installed from source with an extra nested package level.
|
raise ImportError(
|
||||||
from libero.libero.libero import benchmark, get_libero_path
|
"Could not import libero. Install benchmark dependencies with one of:\n"
|
||||||
from libero.libero.libero.envs import OffScreenRenderEnv
|
" pip install -e \".[libero]\"\n"
|
||||||
|
" pip install -e \".[libero_plus]\" (alias: \".[libero-plus]\")"
|
||||||
|
)
|
||||||
|
|
||||||
|
# LIBERO's env_wrapper unconditionally imports wand (ImageMagick Python binding)
|
||||||
|
# which requires the system-level libMagickWand library. The wand features are only
|
||||||
|
# used for visual noise perturbations and are not needed for standard evaluation.
|
||||||
|
# Pre-install a stub so the import succeeds even without ImageMagick.
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
if "wand" not in sys.modules:
|
||||||
|
try:
|
||||||
|
import wand.api # noqa: F401
|
||||||
|
except (ImportError, OSError):
|
||||||
|
|
||||||
|
class _AttrSink:
|
||||||
|
"""Accepts any attribute get/set without error."""
|
||||||
|
|
||||||
|
def __getattr__(self, _name):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __setattr__(self, _name, _value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, *a, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
_wand = types.ModuleType("wand")
|
||||||
|
_wand_api = types.ModuleType("wand.api")
|
||||||
|
_wand_api.library = _AttrSink()
|
||||||
|
_wand_image = types.ModuleType("wand.image")
|
||||||
|
_wand_image.Image = type("Image", (), {})
|
||||||
|
_wand.api = _wand_api
|
||||||
|
_wand.image = _wand_image
|
||||||
|
sys.modules["wand"] = _wand
|
||||||
|
sys.modules["wand.api"] = _wand_api
|
||||||
|
sys.modules["wand.image"] = _wand_image
|
||||||
|
|
||||||
|
from libero.libero import benchmark, get_libero_path
|
||||||
|
from libero.libero.envs import OffScreenRenderEnv
|
||||||
|
|
||||||
from lerobot.processor import RobotObservation
|
from lerobot.processor import RobotObservation
|
||||||
|
|
||||||
@@ -74,13 +114,30 @@ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[i
|
|||||||
|
|
||||||
|
|
||||||
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
||||||
init_states_path = (
|
init_states_dir = Path(get_libero_path("init_states")) / task_suite.tasks[i].problem_folder
|
||||||
Path(get_libero_path("init_states"))
|
init_states_file = task_suite.tasks[i].init_states_file
|
||||||
/ task_suite.tasks[i].problem_folder
|
|
||||||
/ task_suite.tasks[i].init_states_file
|
candidate_names = [init_states_file]
|
||||||
|
# Some LIBERO-plus task names include a "_table_<n>" suffix while shipped
|
||||||
|
# init files use the base name without that table suffix.
|
||||||
|
if "_table_" in init_states_file:
|
||||||
|
candidate_names.append(re.sub(r"_table_\d+(?=\.pruned_init$|\.init$)", "", init_states_file))
|
||||||
|
|
||||||
|
for name in candidate_names:
|
||||||
|
candidate_path = init_states_dir / name
|
||||||
|
if candidate_path.exists():
|
||||||
|
return torch.load(candidate_path, weights_only=False) # nosec B614
|
||||||
|
|
||||||
|
# Last-resort fallback: pick any file matching the base prefix + extension.
|
||||||
|
stem, suffix = os.path.splitext(init_states_file)
|
||||||
|
stem = re.sub(r"_table_\d+$", "", stem)
|
||||||
|
fallback_matches = sorted(init_states_dir.glob(f"{stem}*{suffix}"))
|
||||||
|
if fallback_matches:
|
||||||
|
return torch.load(fallback_matches[0], weights_only=False) # nosec B614
|
||||||
|
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find init states for task {i}. Tried {candidate_names} in '{init_states_dir}'."
|
||||||
)
|
)
|
||||||
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
|
||||||
return init_states
|
|
||||||
|
|
||||||
|
|
||||||
def get_libero_dummy_action():
|
def get_libero_dummy_action():
|
||||||
@@ -100,6 +157,29 @@ TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_offscreen_env_with_renderer_fallback(env_args: dict[str, Any]) -> Any:
|
||||||
|
"""Create OffScreenRenderEnv and fallback to OSMesa if EGL is unavailable."""
|
||||||
|
try:
|
||||||
|
return OffScreenRenderEnv(**env_args)
|
||||||
|
except ImportError as exc:
|
||||||
|
msg = str(exc)
|
||||||
|
if "EGL" not in msg and "PLATFORM_DEVICE" not in msg:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Headless clusters often miss EGL PLATFORM_DEVICE support. Retry with
|
||||||
|
# software rendering to keep evaluation working.
|
||||||
|
os.environ["MUJOCO_GL"] = "osmesa"
|
||||||
|
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||||
|
try:
|
||||||
|
return OffScreenRenderEnv(**env_args)
|
||||||
|
except Exception as fallback_exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Failed to initialize robosuite offscreen renderer with both EGL and "
|
||||||
|
"OSMesa backends. Set up EGL-capable drivers or install OSMesa (e.g. "
|
||||||
|
"`conda install -c conda-forge mesalib`) and retry."
|
||||||
|
) from fallback_exc
|
||||||
|
|
||||||
|
|
||||||
class LiberoEnv(gym.Env):
|
class LiberoEnv(gym.Env):
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||||
|
|
||||||
@@ -244,7 +324,7 @@ class LiberoEnv(gym.Env):
|
|||||||
"camera_heights": self.observation_height,
|
"camera_heights": self.observation_height,
|
||||||
"camera_widths": self.observation_width,
|
"camera_widths": self.observation_width,
|
||||||
}
|
}
|
||||||
env = OffScreenRenderEnv(**env_args)
|
env = _make_offscreen_env_with_renderer_fallback(env_args)
|
||||||
env.reset()
|
env.reset()
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,462 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Benchmark runner: train and evaluate policies across simulation benchmarks.
|
||||||
|
|
||||||
|
Orchestrates per-benchmark training and evaluation using the existing
|
||||||
|
``lerobot-train`` and ``lerobot-eval`` CLI tools.
|
||||||
|
|
||||||
|
Typical usage::
|
||||||
|
|
||||||
|
# Train SmolVLA on LIBERO-plus (4 GPUs, 50k steps):
|
||||||
|
lerobot-benchmark train \\
|
||||||
|
--benchmarks libero_plus \\
|
||||||
|
--policy-path lerobot/smolvla_base \\
|
||||||
|
--hub-user $HF_USER \\
|
||||||
|
--num-gpus 4 --steps 50000
|
||||||
|
|
||||||
|
# Evaluate the trained policies:
|
||||||
|
lerobot-benchmark eval \\
|
||||||
|
--benchmarks libero_plus \\
|
||||||
|
--hub-user $HF_USER
|
||||||
|
|
||||||
|
# Full pipeline (train → upload → eval) for multiple benchmarks:
|
||||||
|
lerobot-benchmark all \\
|
||||||
|
--benchmarks libero_plus,robocasa,robomme \\
|
||||||
|
--policy-path lerobot/smolvla_base \\
|
||||||
|
--hub-user $HF_USER \\
|
||||||
|
--num-gpus 4 --steps 50000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkEntry:
|
||||||
|
"""Training + evaluation settings for a single benchmark.
|
||||||
|
|
||||||
|
When ``eval_tasks`` is set, evaluation runs once per task in the list
|
||||||
|
(e.g. libero_spatial, libero_object, …). ``env_task`` is still used as
|
||||||
|
the task for mid-training evaluation during ``lerobot-train``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_repo_id: str
|
||||||
|
env_type: str
|
||||||
|
env_task: str
|
||||||
|
eval_tasks: list[str] | None = None
|
||||||
|
train_overrides: dict[str, str] = field(default_factory=dict)
|
||||||
|
eval_overrides: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
LIBERO_SUITES = ["libero_spatial", "libero_object", "libero_goal", "libero_10"]
|
||||||
|
|
||||||
|
# Each benchmark maps a human-readable name to its dataset and eval env.
|
||||||
|
# ``dataset_repo_id`` can contain ``{hub_user}`` which is interpolated at
|
||||||
|
# runtime from ``--hub-user``.
|
||||||
|
BENCHMARK_REGISTRY: dict[str, BenchmarkEntry] = {
|
||||||
|
"libero": BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/libero",
|
||||||
|
env_type="libero",
|
||||||
|
env_task="libero_spatial",
|
||||||
|
eval_tasks=LIBERO_SUITES,
|
||||||
|
),
|
||||||
|
"libero_plus": BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/libero_plus",
|
||||||
|
env_type="libero_plus",
|
||||||
|
env_task="libero_spatial",
|
||||||
|
eval_tasks=LIBERO_SUITES,
|
||||||
|
),
|
||||||
|
"metaworld": BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/metaworld",
|
||||||
|
env_type="metaworld",
|
||||||
|
env_task="metaworld-push-v2",
|
||||||
|
),
|
||||||
|
"robocasa": BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/robocasa",
|
||||||
|
env_type="robocasa",
|
||||||
|
env_task="PickPlaceCounterToCabinet",
|
||||||
|
),
|
||||||
|
"robomme": BenchmarkEntry(
|
||||||
|
dataset_repo_id="{hub_user}/robomme",
|
||||||
|
env_type="robomme",
|
||||||
|
env_task="PickXtimes",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _policy_repo_id(hub_user: str, policy_name: str, benchmark: str) -> str:
|
||||||
|
return f"{hub_user}/{policy_name}_{benchmark}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extra_keys(extra_args: list[str]) -> set[str]:
|
||||||
|
"""Extract ``--key`` prefixes from extra CLI args for override detection."""
|
||||||
|
keys: set[str] = set()
|
||||||
|
for arg in extra_args:
|
||||||
|
if arg.startswith("--") and "=" in arg:
|
||||||
|
keys.add(arg.split("=", 1)[0])
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def _build_train_cmd(
|
||||||
|
benchmark: BenchmarkEntry,
|
||||||
|
*,
|
||||||
|
policy_path: str,
|
||||||
|
hub_user: str,
|
||||||
|
policy_name: str,
|
||||||
|
benchmark_name: str,
|
||||||
|
num_gpus: int,
|
||||||
|
steps: int,
|
||||||
|
batch_size: int,
|
||||||
|
eval_freq: int,
|
||||||
|
save_freq: int,
|
||||||
|
wandb: bool,
|
||||||
|
extra_args: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build the ``accelerate launch lerobot-train`` command list."""
|
||||||
|
lerobot_train = shutil.which("lerobot-train")
|
||||||
|
if lerobot_train is None:
|
||||||
|
raise RuntimeError("lerobot-train not found on PATH. Is lerobot installed?")
|
||||||
|
|
||||||
|
# Strip bare "--" separators that argparse may pass through
|
||||||
|
cleaned_extra = [a for a in extra_args if a != "--"]
|
||||||
|
overridden = _extra_keys(cleaned_extra)
|
||||||
|
|
||||||
|
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||||
|
dataset_id = benchmark.dataset_repo_id.format(hub_user=hub_user)
|
||||||
|
|
||||||
|
defaults: list[tuple[str, str]] = [
|
||||||
|
("--policy.path", policy_path),
|
||||||
|
("--dataset.repo_id", dataset_id),
|
||||||
|
("--policy.repo_id", repo_id),
|
||||||
|
("--env.type", benchmark.env_type),
|
||||||
|
("--env.task", benchmark.env_task),
|
||||||
|
("--steps", str(steps)),
|
||||||
|
("--batch_size", str(batch_size)),
|
||||||
|
("--eval_freq", str(eval_freq)),
|
||||||
|
("--save_freq", str(save_freq)),
|
||||||
|
("--output_dir", f"outputs/train/{policy_name}_{benchmark_name}"),
|
||||||
|
("--job_name", f"{policy_name}_{benchmark_name}"),
|
||||||
|
("--policy.push_to_hub", "true"),
|
||||||
|
]
|
||||||
|
if wandb:
|
||||||
|
defaults.append(("--wandb.enable", "true"))
|
||||||
|
for k, v in benchmark.train_overrides.items():
|
||||||
|
defaults.append((f"--{k}", v))
|
||||||
|
|
||||||
|
cmd: list[str] = [
|
||||||
|
"accelerate", "launch",
|
||||||
|
"--multi_gpu",
|
||||||
|
f"--num_processes={num_gpus}",
|
||||||
|
lerobot_train,
|
||||||
|
]
|
||||||
|
for key, val in defaults:
|
||||||
|
if key not in overridden:
|
||||||
|
cmd.append(f"{key}={val}")
|
||||||
|
cmd.extend(cleaned_extra)
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _build_eval_cmd(
|
||||||
|
benchmark: BenchmarkEntry,
|
||||||
|
*,
|
||||||
|
hub_user: str,
|
||||||
|
policy_name: str,
|
||||||
|
benchmark_name: str,
|
||||||
|
eval_task: str | None = None,
|
||||||
|
n_episodes: int,
|
||||||
|
batch_size_eval: int,
|
||||||
|
extra_args: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build the ``lerobot-eval`` command list.
|
||||||
|
|
||||||
|
``eval_task`` overrides the benchmark's ``env_task`` so the same
|
||||||
|
benchmark can be evaluated on multiple suites (e.g. LIBERO).
|
||||||
|
"""
|
||||||
|
lerobot_eval = shutil.which("lerobot-eval")
|
||||||
|
if lerobot_eval is None:
|
||||||
|
raise RuntimeError("lerobot-eval not found on PATH. Is lerobot installed?")
|
||||||
|
|
||||||
|
task = eval_task or benchmark.env_task
|
||||||
|
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||||
|
out_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=task)
|
||||||
|
|
||||||
|
cleaned_extra = [a for a in extra_args if a != "--"]
|
||||||
|
overridden = _extra_keys(cleaned_extra)
|
||||||
|
|
||||||
|
defaults: list[tuple[str, str]] = [
|
||||||
|
("--policy.path", repo_id),
|
||||||
|
("--env.type", benchmark.env_type),
|
||||||
|
("--env.task", task),
|
||||||
|
("--eval.n_episodes", str(n_episodes)),
|
||||||
|
("--eval.batch_size", str(batch_size_eval)),
|
||||||
|
("--output_dir", out_dir),
|
||||||
|
("--policy.device", "cuda"),
|
||||||
|
]
|
||||||
|
for k, v in benchmark.eval_overrides.items():
|
||||||
|
defaults.append((f"--{k}", v))
|
||||||
|
|
||||||
|
cmd: list[str] = [lerobot_eval]
|
||||||
|
for key, val in defaults:
|
||||||
|
if key not in overridden:
|
||||||
|
cmd.append(f"{key}={val}")
|
||||||
|
cmd.extend(cleaned_extra)
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_output_dir(policy_name: str, benchmark_name: str, eval_task: str | None = None) -> Path:
|
||||||
|
if eval_task:
|
||||||
|
return Path(f"outputs/eval/{policy_name}_{benchmark_name}/{eval_task}")
|
||||||
|
return Path(f"outputs/eval/{policy_name}_{benchmark_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _run(cmd: list[str], *, dry_run: bool) -> None:
|
||||||
|
log.info("Command: %s", " \\\n ".join(cmd))
|
||||||
|
if dry_run:
|
||||||
|
log.info("[dry-run] Skipping execution.")
|
||||||
|
return
|
||||||
|
result = subprocess.run(cmd, check=False)
|
||||||
|
if result.returncode != 0:
|
||||||
|
log.error("Command failed with exit code %d", result.returncode)
|
||||||
|
sys.exit(result.returncode)
|
||||||
|
|
||||||
|
|
||||||
|
def _push_eval_to_hub(
|
||||||
|
*,
|
||||||
|
hub_user: str,
|
||||||
|
policy_name: str,
|
||||||
|
benchmark_name: str,
|
||||||
|
eval_task: str | None = None,
|
||||||
|
dry_run: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Upload eval results (metrics + videos) to the policy repo on the Hub."""
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
|
||||||
|
local_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=eval_task)
|
||||||
|
hub_path = f"eval/{eval_task}" if eval_task else f"eval/{benchmark_name}"
|
||||||
|
|
||||||
|
if not local_dir.exists():
|
||||||
|
log.warning("Eval output dir %s does not exist, skipping hub upload.", local_dir)
|
||||||
|
return
|
||||||
|
|
||||||
|
log.info("Uploading eval results from %s to %s (path_in_repo=%s)", local_dir, repo_id, hub_path)
|
||||||
|
if dry_run:
|
||||||
|
log.info("[dry-run] Skipping upload.")
|
||||||
|
return
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
api.upload_folder(
|
||||||
|
folder_path=str(local_dir),
|
||||||
|
repo_id=repo_id,
|
||||||
|
path_in_repo=hub_path,
|
||||||
|
repo_type="model",
|
||||||
|
commit_message=f"Upload eval results for {eval_task or benchmark_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_benchmarks(names: str) -> list[tuple[str, BenchmarkEntry]]:
|
||||||
|
out = []
|
||||||
|
for name in names.split(","):
|
||||||
|
name = name.strip()
|
||||||
|
if name not in BENCHMARK_REGISTRY:
|
||||||
|
available = ", ".join(BENCHMARK_REGISTRY)
|
||||||
|
raise ValueError(f"Unknown benchmark '{name}'. Available: {available}")
|
||||||
|
out.append((name, BENCHMARK_REGISTRY[name]))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_train(args: argparse.Namespace) -> None:
|
||||||
|
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||||
|
for bname, bentry in benchmarks:
|
||||||
|
log.info("=== Training on benchmark: %s ===", bname)
|
||||||
|
cmd = _build_train_cmd(
|
||||||
|
bentry,
|
||||||
|
policy_path=args.policy_path,
|
||||||
|
hub_user=args.hub_user,
|
||||||
|
policy_name=args.policy_name,
|
||||||
|
benchmark_name=bname,
|
||||||
|
num_gpus=args.num_gpus,
|
||||||
|
steps=args.steps,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
eval_freq=args.eval_freq,
|
||||||
|
save_freq=args.save_freq,
|
||||||
|
wandb=args.wandb,
|
||||||
|
extra_args=args.extra,
|
||||||
|
)
|
||||||
|
_run(cmd, dry_run=args.dry_run)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_eval_for_benchmark(
|
||||||
|
bname: str,
|
||||||
|
bentry: BenchmarkEntry,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
) -> None:
|
||||||
|
"""Run evaluation for a single benchmark, iterating over all its eval_tasks."""
|
||||||
|
tasks = bentry.eval_tasks or [bentry.env_task]
|
||||||
|
for task in tasks:
|
||||||
|
log.info("=== Evaluating %s / %s ===", bname, task)
|
||||||
|
cmd = _build_eval_cmd(
|
||||||
|
bentry,
|
||||||
|
hub_user=args.hub_user,
|
||||||
|
policy_name=args.policy_name,
|
||||||
|
benchmark_name=bname,
|
||||||
|
eval_task=task if bentry.eval_tasks else None,
|
||||||
|
n_episodes=args.n_episodes,
|
||||||
|
batch_size_eval=args.batch_size_eval,
|
||||||
|
extra_args=args.extra,
|
||||||
|
)
|
||||||
|
_run(cmd, dry_run=args.dry_run)
|
||||||
|
if args.push_eval_to_hub:
|
||||||
|
_push_eval_to_hub(
|
||||||
|
hub_user=args.hub_user,
|
||||||
|
policy_name=args.policy_name,
|
||||||
|
benchmark_name=bname,
|
||||||
|
eval_task=task if bentry.eval_tasks else None,
|
||||||
|
dry_run=args.dry_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_eval(args: argparse.Namespace) -> None:
|
||||||
|
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||||
|
for bname, bentry in benchmarks:
|
||||||
|
_run_eval_for_benchmark(bname, bentry, args)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_all(args: argparse.Namespace) -> None:
|
||||||
|
"""Train on each benchmark, then evaluate each."""
|
||||||
|
benchmarks = _resolve_benchmarks(args.benchmarks)
|
||||||
|
|
||||||
|
log.info("Phase 1: Training on %d benchmark(s)", len(benchmarks))
|
||||||
|
for bname, bentry in benchmarks:
|
||||||
|
log.info("=== Training on benchmark: %s ===", bname)
|
||||||
|
cmd = _build_train_cmd(
|
||||||
|
bentry,
|
||||||
|
policy_path=args.policy_path,
|
||||||
|
hub_user=args.hub_user,
|
||||||
|
policy_name=args.policy_name,
|
||||||
|
benchmark_name=bname,
|
||||||
|
num_gpus=args.num_gpus,
|
||||||
|
steps=args.steps,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
eval_freq=args.eval_freq,
|
||||||
|
save_freq=args.save_freq,
|
||||||
|
wandb=args.wandb,
|
||||||
|
extra_args=args.extra,
|
||||||
|
)
|
||||||
|
_run(cmd, dry_run=args.dry_run)
|
||||||
|
|
||||||
|
log.info("Phase 2: Evaluating %d benchmark(s)", len(benchmarks))
|
||||||
|
for bname, bentry in benchmarks:
|
||||||
|
_run_eval_for_benchmark(bname, bentry, args)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_common_args(p: argparse.ArgumentParser) -> None:
|
||||||
|
p.add_argument(
|
||||||
|
"--benchmarks", required=True,
|
||||||
|
help="Comma-separated benchmark names (e.g. libero_plus,robocasa,robomme).",
|
||||||
|
)
|
||||||
|
p.add_argument("--hub-user", required=True, help="HuggingFace Hub username.")
|
||||||
|
p.add_argument(
|
||||||
|
"--policy-name", default="smolvla",
|
||||||
|
help="Short policy name used in repo IDs and output dirs (default: smolvla).",
|
||||||
|
)
|
||||||
|
p.add_argument("--dry-run", action="store_true", help="Print commands without executing.")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_train_args(p: argparse.ArgumentParser) -> None:
|
||||||
|
p.add_argument("--policy-path", default="lerobot/smolvla_base", help="Pretrained policy path.")
|
||||||
|
p.add_argument("--num-gpus", type=int, default=4, help="Number of GPUs.")
|
||||||
|
p.add_argument("--steps", type=int, default=50_000, help="Total training steps.")
|
||||||
|
p.add_argument("--batch-size", type=int, default=32, help="Per-GPU batch size.")
|
||||||
|
p.add_argument("--eval-freq", type=int, default=10_000, help="Eval every N steps (0 to disable).")
|
||||||
|
p.add_argument("--save-freq", type=int, default=10_000, help="Save checkpoint every N steps.")
|
||||||
|
p.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging.")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_eval_args(p: argparse.ArgumentParser) -> None:
|
||||||
|
p.add_argument("--n-episodes", type=int, default=50, help="Number of eval episodes.")
|
||||||
|
p.add_argument("--batch-size-eval", type=int, default=10, help="Eval batch size (parallel envs).")
|
||||||
|
p.add_argument(
|
||||||
|
"--push-eval-to-hub", action="store_true",
|
||||||
|
help="Upload eval results (metrics + videos) to the policy repo on the Hub.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="lerobot-benchmark",
|
||||||
|
description="Train and evaluate policies across simulation benchmarks.",
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
# train
|
||||||
|
p_train = sub.add_parser("train", help="Train a policy on each selected benchmark.")
|
||||||
|
_add_common_args(p_train)
|
||||||
|
_add_train_args(p_train)
|
||||||
|
p_train.set_defaults(func=cmd_train)
|
||||||
|
|
||||||
|
# eval
|
||||||
|
p_eval = sub.add_parser("eval", help="Evaluate trained policies on each benchmark.")
|
||||||
|
_add_common_args(p_eval)
|
||||||
|
_add_eval_args(p_eval)
|
||||||
|
p_eval.set_defaults(func=cmd_eval)
|
||||||
|
|
||||||
|
# all (train + eval)
|
||||||
|
p_all = sub.add_parser("all", help="Train then evaluate on each benchmark.")
|
||||||
|
_add_common_args(p_all)
|
||||||
|
_add_train_args(p_all)
|
||||||
|
_add_eval_args(p_all)
|
||||||
|
p_all.set_defaults(func=cmd_all)
|
||||||
|
|
||||||
|
# list
|
||||||
|
p_list = sub.add_parser("list", help="List available benchmarks.")
|
||||||
|
p_list.set_defaults(func=lambda _args: _list_benchmarks())
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def _list_benchmarks() -> None:
|
||||||
|
print("Available benchmarks:\n")
|
||||||
|
for name, entry in BENCHMARK_REGISTRY.items():
|
||||||
|
print(f" {name}")
|
||||||
|
print(f" dataset: {entry.dataset_repo_id}")
|
||||||
|
print(f" env: {entry.env_type}")
|
||||||
|
if entry.eval_tasks:
|
||||||
|
print(f" eval on: {', '.join(entry.eval_tasks)}")
|
||||||
|
else:
|
||||||
|
print(f" eval on: {entry.env_task}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = build_parser()
|
||||||
|
args, extra = parser.parse_known_args()
|
||||||
|
args.extra = extra
|
||||||
|
args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user