diff --git a/pyproject.toml b/pyproject.toml index 6d284d5ff..9ba9a3229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,6 +246,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" +lerobot-leaderboard="lerobot.scripts.lerobot_leaderboard:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" lerobot-benchmark="lerobot.scripts.lerobot_benchmark:main" diff --git a/src/lerobot/scripts/lerobot_leaderboard.py b/src/lerobot/scripts/lerobot_leaderboard.py new file mode 100644 index 000000000..de3208e43 --- /dev/null +++ b/src/lerobot/scripts/lerobot_leaderboard.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generate an interactive eval leaderboard from Hub model repos. + +Reads eval results (as pushed by ``lerobot-eval --push_to_hub``) from one or +more Hugging Face model repos and produces a self-contained HTML page with a +sortable, filterable leaderboard table. + +Usage:: + + lerobot-leaderboard \ + --repo-ids user/model_a,user/model_b \ + --output leaderboard.html + + # Or from a file listing repo IDs (one per line): + lerobot-leaderboard \ + --repo-ids-file models.txt \ + --output leaderboard.html +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import dataclass, field +from pathlib import Path + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +logger = logging.getLogger(__name__) + + +@dataclass +class ModelEntry: + repo_id: str + policy_type: str = "—" + dataset: str = "—" + training_steps: str = "—" + batch_size: str = "—" + # env_type -> {group_name -> pc_success} + eval_results: dict[str, dict[str, float]] = field(default_factory=dict) + # env_type -> overall pc_success + eval_overall: dict[str, float] = field(default_factory=dict) + # env_type -> n_episodes + eval_n_episodes: dict[str, int] = field(default_factory=dict) + + +def _try_download(repo_id: str, filename: str) -> dict | None: + """Download a JSON file from a Hub repo, return parsed dict or None.""" + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + + try: + path = hf_hub_download(repo_id, filename, repo_type="model") + with open(path) as f: + return json.load(f) + except (EntryNotFoundError, RepositoryNotFoundError, OSError): + return None + + +def _list_eval_dirs(repo_id: str) -> list[str]: + """List env_type subdirectories under eval/ in a Hub repo.""" + from huggingface_hub import HfApi + from huggingface_hub.utils import RepositoryNotFoundError + + api = HfApi() + try: + files = api.list_repo_files(repo_id, repo_type="model") + except RepositoryNotFoundError: + return [] + + env_types = set() + for f in files: + if f.startswith("eval/") and f.count("/") >= 2: + env_types.add(f.split("/")[1]) + return sorted(env_types) + + +def fetch_model_entry(repo_id: str) -> ModelEntry: + """Fetch all available metadata and eval results for a single model.""" + entry = ModelEntry(repo_id=repo_id) + + # Policy config + policy_cfg = _try_download(repo_id, "config.json") + if policy_cfg: + entry.policy_type = policy_cfg.get("type", "—") + + # Training config + train_cfg = _try_download(repo_id, "train_config.json") + if train_cfg: + ds = train_cfg.get("dataset", {}) + entry.dataset = ds.get("repo_id", "—") if isinstance(ds, dict) else str(ds) + entry.training_steps = str(train_cfg.get("steps", "—")) + entry.batch_size = str(train_cfg.get("batch_size", "—")) + + # Eval results per env_type + for env_type in _list_eval_dirs(repo_id): + eval_info = _try_download(repo_id, f"eval/{env_type}/eval_info.json") + if not eval_info: + continue + + per_group = eval_info.get("per_group", {}) + group_results = {} + for group_name, stats in per_group.items(): + group_results[group_name] = stats.get("pc_success", float("nan")) + + entry.eval_results[env_type] = group_results + + overall = eval_info.get("overall", {}) + entry.eval_overall[env_type] = overall.get("pc_success", float("nan")) + entry.eval_n_episodes[env_type] = overall.get("n_episodes", 0) + + return entry + + +def fetch_all(repo_ids: list[str]) -> list[ModelEntry]: + entries = [] + for repo_id in repo_ids: + logger.info(f"Fetching {repo_id}...") + try: + entries.append(fetch_model_entry(repo_id)) + except Exception as e: + logger.warning(f"Failed to fetch {repo_id}: {e}") + return entries + + +def collect_all_env_types(entries: list[ModelEntry]) -> list[str]: + """Collect all unique env_types across all entries, sorted.""" + env_types: set[str] = set() + for e in entries: + env_types.update(e.eval_overall.keys()) + return sorted(env_types) + + +def collect_all_groups(entries: list[ModelEntry]) -> dict[str, list[str]]: + """Collect all unique group names per env_type.""" + groups: dict[str, set[str]] = {} + for e in entries: + for env_type, group_results in e.eval_results.items(): + groups.setdefault(env_type, set()).update(group_results.keys()) + return {k: sorted(v) for k, v in groups.items()} + + +def build_html(entries: list[ModelEntry], title: str = "LeRobot Eval Leaderboard") -> str: + env_types = collect_all_env_types(entries) + all_groups = collect_all_groups(entries) + + # Build column structure: fixed cols + per env_type (overall + per-group sub-columns) + # We'll build the data as JSON and let JS handle rendering + table_data = [] + for e in entries: + row = { + "repo_id": e.repo_id, + "policy_type": e.policy_type, + "dataset": e.dataset, + "training_steps": e.training_steps, + "batch_size": e.batch_size, + } + for env_type in env_types: + overall = e.eval_overall.get(env_type) + row[f"{env_type}__overall"] = round(overall, 1) if overall is not None else None + n_ep = e.eval_n_episodes.get(env_type) + row[f"{env_type}__n_episodes"] = n_ep if n_ep else None + for group in all_groups.get(env_type, []): + val = e.eval_results.get(env_type, {}).get(group) + row[f"{env_type}__{group}"] = round(val, 1) if val is not None else None + table_data.append(row) + + # Build column definitions for the JS table + columns_json = json.dumps(_build_column_defs(env_types, all_groups)) + data_json = json.dumps(table_data) + + return _HTML_TEMPLATE.format( + title=title, + columns_json=columns_json, + data_json=data_json, + ) + + +def _build_column_defs(env_types: list[str], all_groups: dict[str, list[str]]) -> list[dict]: + cols = [ + {"key": "repo_id", "label": "Model", "group": "Model Info", "sortable": True, "type": "link"}, + {"key": "policy_type", "label": "Policy", "group": "Model Info", "sortable": True, "type": "text"}, + {"key": "dataset", "label": "Dataset", "group": "Model Info", "sortable": True, "type": "text"}, + { + "key": "training_steps", + "label": "Steps", + "group": "Training", + "sortable": True, + "type": "number", + }, + { + "key": "batch_size", + "label": "Batch", + "group": "Training", + "sortable": True, + "type": "number", + }, + ] + for env_type in env_types: + cols.append( + { + "key": f"{env_type}__overall", + "label": "Overall %", + "group": env_type, + "sortable": True, + "type": "pct", + } + ) + for group in all_groups.get(env_type, []): + cols.append( + { + "key": f"{env_type}__{group}", + "label": f"{group} %", + "group": env_type, + "sortable": True, + "type": "pct", + } + ) + cols.append( + { + "key": f"{env_type}__n_episodes", + "label": "Episodes", + "group": env_type, + "sortable": True, + "type": "number", + } + ) + return cols + + +_HTML_TEMPLATE = """\ + + + + + +{title} + + + + +

🤖 {title}

+

Click any column header to sort. Filter by typing below.

+ +
+ + +
+ +
+ + + +
+
+ + + + + + +""" + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Generate an interactive eval leaderboard from Hub model repos.", + ) + p.add_argument( + "--repo-ids", + type=str, + default=None, + help="Comma-separated list of HF model repo IDs.", + ) + p.add_argument( + "--repo-ids-file", + type=str, + default=None, + help="Path to a text file with one repo ID per line.", + ) + p.add_argument( + "--output", + type=str, + default="leaderboard.html", + help="Output HTML file path (default: leaderboard.html).", + ) + p.add_argument( + "--title", + type=str, + default="LeRobot Eval Leaderboard", + help="Title shown in the leaderboard page.", + ) + return p.parse_args(argv) + + +def main(argv: list[str] | None = None): + args = parse_args(argv) + + repo_ids: list[str] = [] + if args.repo_ids: + repo_ids.extend(r.strip() for r in args.repo_ids.split(",") if r.strip()) + if args.repo_ids_file: + path = Path(args.repo_ids_file) + if not path.exists(): + logger.error(f"File not found: {path}") + sys.exit(1) + repo_ids.extend(line.strip() for line in path.read_text().splitlines() if line.strip()) + if not repo_ids: + logger.error("No repo IDs provided. Use --repo-ids or --repo-ids-file.") + sys.exit(1) + + entries = fetch_all(repo_ids) + if not entries: + logger.error("No valid entries found.") + sys.exit(1) + + html = build_html(entries, title=args.title) + out = Path(args.output) + out.write_text(html) + logger.info(f"Leaderboard written to {out.resolve()}") + + +if __name__ == "__main__": + main()