Compare commits

..

3 Commits

Author SHA1 Message Date
Pepijn 2ab59a3099 feat(benchmarks): add matrix runner and leaderboard 2026-04-15 21:31:33 +02:00
Pepijn dab511dbb1 Merge branch 'main' into feat/libero-benchmark 2026-04-14 10:43:49 +02:00
Pepijn fd00e38851 feat(benchmarks): add LIBERO training benchmark pipeline
Single-script benchmark that trains and evaluates all 9 LeRobot policies
on LIBERO. Each SLURM job self-publishes its result row to a HuggingFace
leaderboard dataset — no separate collection step needed.

Policies: pi0, pi0_fast, pi05, groot, act, diffusion, smolvla, xvla,
multi_task_dit. 5000 steps, BS 256, with per-policy GPU allocation and
default LR/scheduler presets.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 17:01:49 +02:00
130 changed files with 6036 additions and 7003 deletions
+22 -4
View File
@@ -2,6 +2,11 @@
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
## Type / Scope
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
- **Scope**: (optional — name of module or package affected)
## Summary / Motivation
- One-paragraph description of what changes and why.
@@ -14,14 +19,28 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
## What changed
- Short, concrete bullets explaining the functional changes (how the behavior or output differs now).
- Short, concrete bullets of the modifications (files/behaviour).
- Short note if this introduces breaking changes and migration steps.
## How was this tested (or how to run locally)
- Tests added: list new tests or test files. `pytest -q tests/ -k <keyword>`
- Tests added: list new tests or test files.
- Manual checks / dataset runs performed.
- Instructions for the reviewer for reproducing with a quick example or CLI (if applicable)
- Instructions for the reviewer
Example:
- Ran the relevant tests:
```bash
pytest -q tests/ -k <keyword>
```
- Reproduce with a quick example or CLI (if applicable):
```bash
lerobot-train --some.option=true
```
## Checklist (required before merge)
@@ -29,7 +48,6 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
- [ ] All tests pass locally (`pytest`)
- [ ] Documentation updated
- [ ] CI is green
- [ ] Community Review: I have reviewed another contributor's open PR and linked it here: # (insert PR number/link)
## Reviewer notes
+178
View File
@@ -310,3 +310,181 @@ jobs:
name: metaworld-metrics
path: /tmp/metaworld-artifacts/metrics.json
if-no-files-found: warn
# ── LIBERO-plus ───────────────────────────────────────────────────────────
libero-plus-integration-test:
name: LIBERO-plus — build image + 1-episode eval
runs-on:
group: aws-g6-4xlarge-plus
env:
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
lfs: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
with:
cache-binary: false
- name: Build LIBERO-plus benchmark image
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
with:
context: .
file: docker/Dockerfile.benchmark.libero_plus
push: false
load: true
tags: lerobot-benchmark-libero-plus:ci
cache-from: type=local,src=/tmp/.buildx-cache-libero-plus
cache-to: type=local,dest=/tmp/.buildx-cache-libero-plus,mode=max
- name: Run LIBERO-plus smoke eval (1 episode)
if: env.HF_USER_TOKEN != ''
run: |
docker run --name libero-plus-eval --gpus all \
--shm-size=4g \
-e HF_HOME=/tmp/hf \
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
lerobot-benchmark-libero-plus:ci \
bash -c "
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
lerobot-eval \
--policy.path=lerobot/smolvla_libero_plus \
--env.type=libero_plus \
--env.task=libero_spatial \
'--env.task_ids=[0,100,260,500,1000,1500,2000,2400]' \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
--policy.device=cuda \
'--env.camera_name_mapping={\"agentview_image\": \"camera1\", \"robot0_eye_in_hand_image\": \"camera2\"}' \
--policy.empty_cameras=1 \
--output_dir=/tmp/eval-artifacts
python scripts/ci/extract_task_descriptions.py \
--env libero_plus --task libero_spatial \
--output /tmp/eval-artifacts/task_descriptions.json
"
- name: Copy LIBERO-plus artifacts from container
if: always()
run: |
mkdir -p /tmp/libero-plus-artifacts
docker cp libero-plus-eval:/tmp/eval-artifacts/. /tmp/libero-plus-artifacts/ 2>/dev/null || true
docker rm -f libero-plus-eval || true
- name: Parse LIBERO-plus eval metrics
if: always()
run: |
python3 scripts/ci/parse_eval_metrics.py \
--artifacts-dir /tmp/libero-plus-artifacts \
--env libero_plus \
--task libero_spatial \
--policy lerobot/smolvla_libero_plus
- name: Upload LIBERO-plus rollout video
if: always()
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
with:
name: libero-plus-rollout-video
path: /tmp/libero-plus-artifacts/videos/
if-no-files-found: warn
- name: Upload LIBERO-plus eval metrics
if: always()
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
with:
name: libero-plus-metrics
path: /tmp/libero-plus-artifacts/metrics.json
if-no-files-found: warn
# ── ROBOMME ───────────────────────────────────────────────────────────────
robomme-integration-test:
name: RoboMME — build image + 1-episode eval
runs-on:
group: aws-g6-4xlarge-plus
env:
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
lfs: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
with:
cache-binary: false
- name: Build RoboMME benchmark image
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
with:
context: .
file: docker/Dockerfile.benchmark.robomme
push: false
load: true
tags: lerobot-benchmark-robomme:ci
- name: Run RoboMME smoke eval (1 episode)
if: env.HF_USER_TOKEN != ''
run: |
docker run --name robomme-eval --gpus all \
--shm-size=4g \
-e HF_HOME=/tmp/hf \
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
lerobot-benchmark-robomme:ci \
bash -c "
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
lerobot-eval \
--policy.path=lerobot/smolvla_robomme \
--env.type=robomme \
--env.task=PickXtimes,BinFill,StopCube,MoveCube,InsertPeg \
--env.dataset_split=test \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
--policy.device=cuda \
'--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.wrist_image\": \"observation.images.camera2\"}' \
--policy.empty_cameras=3 \
--output_dir=/tmp/eval-artifacts
python scripts/ci/extract_task_descriptions.py \
--env robomme --task PickXtimes,BinFill,StopCube,MoveCube,InsertPeg \
--output /tmp/eval-artifacts/task_descriptions.json
"
- name: Copy RoboMME artifacts from container
if: always()
run: |
mkdir -p /tmp/robomme-artifacts
docker cp robomme-eval:/tmp/eval-artifacts/. /tmp/robomme-artifacts/ 2>/dev/null || true
docker rm -f robomme-eval || true
- name: Parse RoboMME eval metrics
if: always()
run: |
python3 scripts/ci/parse_eval_metrics.py \
--artifacts-dir /tmp/robomme-artifacts \
--env robomme \
--task PickXtimes \
--policy lerobot/smolvla_robomme
- name: Upload RoboMME rollout video
if: always()
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
with:
name: robomme-rollout-video
path: /tmp/robomme-artifacts/videos/
if-no-files-found: warn
- name: Upload RoboMME eval metrics
if: always()
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
with:
name: robomme-metrics
path: /tmp/robomme-artifacts/metrics.json
if-no-files-found: warn
@@ -33,7 +33,7 @@ jobs:
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
with:
package_name: lerobot
secrets:
-18
View File
@@ -217,24 +217,6 @@ jobs:
- name: Run end-to-end tests
run: make test-end-to-end
slack-notification:
name: Slack Notification
needs: [cpu-tests, gpu-tests, upgrade-lock]
if: always() && needs.upgrade-lock.outputs.changed == 'true'
runs-on: ubuntu-latest
permissions:
contents: read
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_SLACK_CHANNEL }}
steps:
- name: Post to a Slack channel
uses: huggingface/hf-workflows/.github/actions/post-slack@a88e7fa2eaee28de5a4d6142381b1fb792349b67 # main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: "Results of the latest dependency tests (CPU + GPU)"
status: ${{ (needs.cpu-tests.result == 'success' && needs.gpu-tests.result == 'success') && 'success' || 'failure' }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
# This job creates or updates a PR with the upgraded lockfile
open-pr:
name: Open PR
+1 -4
View File
@@ -78,9 +78,6 @@ Use the templates for required fields and examples.
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
> [!IMPORTANT]
> Community Review Policy: To help scale our efforts and foster a collaborative environment, we ask contributors to review at least one other person's open PR before their own receives attention. This shared responsibility multiplies our review capacity and helps everyone's code get merged faster!
Once you have submitted your PR and completed a peer review, a member of the LeRobot team will review your contribution.
One member of the LeRobot team will then review your contribution.
Thank you for contributing to LeRobot!
+1
View File
@@ -0,0 +1 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+60
View File
@@ -0,0 +1,60 @@
# LeRobot LIBERO Training Benchmark
Train and evaluate all LeRobot policies on [LIBERO](https://libero-project.github.io/) and publish results as a HuggingFace leaderboard dataset.
## Policies
| Policy | Base Model | GPUs | LR | Chunk | Notes |
| -------------- | -------------------- | ---- | ------ | ----- | ------------------------------------- |
| pi0 | lerobot/pi0_base | 8 | 2.5e-5 | 30 | PaliGemma + Gemma flow matching |
| pi0_fast | lerobot/pi0fast-base | 8 | 2.5e-5 | 30 | Requires tokenizer pre-training |
| pi05 | lerobot/pi05_base | 8 | 2.5e-5 | 30 | Quantiles normalization |
| groot | nvidia/GR00T-N1.5-3B | 8 | 1e-4 | 30 | bf16, diffusion head + projector only |
| act | From scratch | 1 | 1e-5 | 30 | ResNet-18, lightweight |
| diffusion | From scratch | 1 | 1e-4 | 32\* | U-Net, horizon must be divisible by 8 |
| smolvla | lerobot/smolvla_base | 8 | 1e-4 | 30 | SmolVLM2-500M |
| xvla | lerobot/xvla-widowx | 4 | 1e-4 | 32\* | Florence2 + CLIP |
| multi_task_dit | From scratch | 1 | 2e-5 | 32\* | CLIP + DiT |
\* These policies use `horizon` rather than `chunk_size`. Set to 32 (nearest valid value to 30).
## Training spec
- **Steps**: 5,000 per policy
- **Batch size**: 32 per GPU (effective BS = 256 for multi-GPU)
- **Dataset**: `lerobot/libero` (libero_spatial)
- **Evaluation**: 20 episodes after training
- **LR**: each policy's default optimizer/scheduler preset
- **Results**: each SLURM job publishes its own row to the HF leaderboard dataset automatically
## Quick start
### 1. Generate SLURM scripts
```bash
python benchmarks/libero/run_benchmark.py \
--output_dir /scratch/lerobot-benchmark \
--hub_org lerobot
```
### 2. Submit jobs
```bash
# If using pi0_fast, submit tokenizer first:
sbatch /scratch/lerobot-benchmark/slurm_scripts/00_tokenizer.sh
# Wait, then submit pi0_fast
# All other policies can run in parallel:
for script in /scratch/lerobot-benchmark/slurm_scripts/[0-9][0-9]_*.sh; do
[[ "$script" == *pi0_fast* ]] && continue
sbatch "$script"
done
```
Each job publishes its result to `lerobot/benchmark-libero` on the Hub when it finishes.
## Prerequisites
- SLURM cluster with CUDA GPUs (A100 80GB recommended for VLM policies)
- `pip install lerobot[pi,smolvla,groot,xvla,multi_task_dit,libero] datasets`
- `huggingface-cli login`
+606
View File
@@ -0,0 +1,606 @@
#!/usr/bin/env python
"""Generate SLURM sbatch scripts for training all LeRobot policies on LIBERO.
Each generated script trains one policy, evaluates it, and publishes its
results row to a HuggingFace leaderboard dataset — no separate collection
step needed.
Usage:
# Generate scripts for all policies:
python benchmarks/libero/run_benchmark.py \\
--output_dir /scratch/lerobot-benchmark --hub_org lerobot
# Generate for a subset:
python benchmarks/libero/run_benchmark.py \\
--policies pi0 smolvla act \\
--output_dir /scratch/lerobot-benchmark --hub_org lerobot
"""
from __future__ import annotations
import argparse
import json
import subprocess
import textwrap
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
# ──────────────────────────────────────────────────────────────────────
# Policy benchmark configs
# ──────────────────────────────────────────────────────────────────────
@dataclass
class PolicyBenchmarkConfig:
"""Training configuration for a single policy on a benchmark."""
policy_type: str
policy_path: str | None = None
num_gpus: int = 1
chunk_size: int | None = None # Set on policies that use chunk_size (not horizon)
extra_policy_args: dict[str, str] = field(default_factory=dict)
needs_tokenizer: bool = False
tokenizer_args: dict[str, str] = field(default_factory=dict)
COMMON_TRAINING_ARGS: dict[str, str] = {
"dataset.repo_id": "lerobot/libero",
"dataset.use_imagenet_stats": "false",
"env.type": "libero",
"env.task": "libero_spatial",
"steps": "5000",
"batch_size": "32",
"eval_freq": "0",
"save_freq": "5000",
"save_checkpoint": "true",
"log_freq": "100",
"wandb.enable": "true",
"policy.push_to_hub": "true",
"rename_map": (
'{"observation.images.image":"observation.images.camera1",'
'"observation.images.image2":"observation.images.camera2"}'
),
}
EVAL_ARGS: dict[str, str] = {
"env.type": "libero",
"env.task": "libero_spatial",
"eval.n_episodes": "20",
"eval.batch_size": "10",
}
POLICY_CONFIGS: dict[str, PolicyBenchmarkConfig] = {
"pi0": PolicyBenchmarkConfig(
policy_type="pi0",
policy_path="lerobot/pi0_base",
num_gpus=8,
chunk_size=30,
extra_policy_args={
"policy.n_action_steps": "30",
"policy.scheduler_decay_steps": "5000",
},
),
"pi0_fast": PolicyBenchmarkConfig(
policy_type="pi0_fast",
policy_path="lerobot/pi0fast-base",
num_gpus=8,
chunk_size=30,
extra_policy_args={
"policy.n_action_steps": "30",
"policy.scheduler_decay_steps": "5000",
},
needs_tokenizer=True,
tokenizer_args={
"repo_id": "lerobot/libero",
"action_horizon": "30",
"encoded_dims": "0:7",
"normalization_mode": "QUANTILES",
"vocab_size": "1024",
"scale": "10.0",
"push_to_hub": "true",
},
),
"pi05": PolicyBenchmarkConfig(
policy_type="pi05",
policy_path="lerobot/pi05_base",
num_gpus=8,
chunk_size=30,
extra_policy_args={
"policy.n_action_steps": "30",
"policy.scheduler_decay_steps": "5000",
},
),
"groot": PolicyBenchmarkConfig(
policy_type="groot",
policy_path=None,
num_gpus=8,
chunk_size=30,
extra_policy_args={
"policy.n_action_steps": "30",
"policy.base_model_path": "nvidia/GR00T-N1.5-3B",
"policy.tune_diffusion_model": "true",
"policy.tune_projector": "true",
"policy.tune_llm": "false",
"policy.tune_visual": "false",
"policy.use_bf16": "true",
},
),
"act": PolicyBenchmarkConfig(
policy_type="act",
policy_path=None,
num_gpus=1,
chunk_size=30,
extra_policy_args={"policy.n_action_steps": "30"},
),
"diffusion": PolicyBenchmarkConfig(
policy_type="diffusion",
policy_path=None,
num_gpus=1,
chunk_size=None,
extra_policy_args={
"policy.horizon": "32",
"policy.n_action_steps": "30",
"policy.n_obs_steps": "2",
},
),
"smolvla": PolicyBenchmarkConfig(
policy_type="smolvla",
policy_path="lerobot/smolvla_base",
num_gpus=8,
chunk_size=30,
extra_policy_args={
"policy.n_action_steps": "30",
"policy.load_vlm_weights": "true",
"policy.freeze_vision_encoder": "false",
"policy.train_expert_only": "false",
"policy.scheduler_decay_steps": "5000",
},
),
"xvla": PolicyBenchmarkConfig(
policy_type="xvla",
policy_path="lerobot/xvla-widowx",
num_gpus=4,
chunk_size=32,
extra_policy_args={
"policy.n_action_steps": "32",
"policy.scheduler_decay_steps": "5000",
},
),
"multi_task_dit": PolicyBenchmarkConfig(
policy_type="multi_task_dit",
policy_path=None,
num_gpus=1,
chunk_size=None,
extra_policy_args={
"policy.horizon": "32",
"policy.n_action_steps": "30",
},
),
}
ALL_POLICY_NAMES = list(POLICY_CONFIGS.keys())
# GPU memory estimates (GB) for SLURM --mem allocation
GPU_MEM_ESTIMATES: dict[str, int] = {
"pi0": 320,
"pi0_fast": 320,
"pi05": 280,
"groot": 320,
"act": 64,
"diffusion": 64,
"smolvla": 160,
"xvla": 160,
"multi_task_dit": 64,
}
# ──────────────────────────────────────────────────────────────────────
# SLURM script generation
# ──────────────────────────────────────────────────────────────────────
def _cli_args(args: dict[str, str]) -> str:
"""Build a backslash-continued CLI arg string with proper shell quoting."""
lines = []
for key, value in args.items():
if any(c in str(value) for c in ["{", "}", " ", '"', "'"]):
lines.append(f" --{key}='{value}'")
else:
lines.append(f" --{key}={value}")
return " \\\n".join(lines)
def _training_cli_args(
policy_name: str,
output_dir: Path,
hub_org: str,
benchmark_uuid: str,
) -> str:
cfg = POLICY_CONFIGS[policy_name]
args: dict[str, str] = {}
args.update(COMMON_TRAINING_ARGS)
args["policy.type"] = cfg.policy_type
if cfg.policy_path:
args["policy.path"] = cfg.policy_path
if cfg.chunk_size is not None:
args["policy.chunk_size"] = str(cfg.chunk_size)
args.update(cfg.extra_policy_args)
args["output_dir"] = str(output_dir / "train" / policy_name)
args["policy.repo_id"] = f"{hub_org}/{policy_name}_libero"
args["wandb.project"] = "lerobot-libero-benchmark"
args["wandb.run_name"] = f"{policy_name}_{benchmark_uuid[:8]}"
return _cli_args(args)
def _publish_snippet(
policy_name: str,
output_dir: Path,
hub_org: str,
benchmark_uuid: str,
hub_dataset: str,
) -> str:
"""Inline Python that each SLURM job runs to publish its own result row."""
cfg = POLICY_CONFIGS[policy_name]
steps = int(COMMON_TRAINING_ARGS["steps"])
bs = int(COMMON_TRAINING_ARGS["batch_size"])
eff_bs = bs * cfg.num_gpus
train_dir = output_dir / "train" / policy_name
return textwrap.dedent(f"""\
python3 -c "
import json, os, re, sys
from pathlib import Path
from datetime import datetime, timezone
timing = {{}}
tp = Path('{output_dir}/logs/{policy_name}_timing.txt')
if tp.exists():
for ln in tp.read_text().splitlines():
if '=' in ln:
k, _, v = ln.partition('=')
timing[k.strip()] = v.strip()
# Parse eval results
eval_sr, eval_per_task, eval_n = None, '{{}}', 0
eval_dir = Path('{train_dir}/eval_results')
if eval_dir.exists():
for jf in eval_dir.glob('**/*.json'):
try:
d = json.loads(jf.read_text())
except Exception:
continue
if 'avg_success_rate' in d:
eval_sr = d['avg_success_rate']
elif 'eval_info' in d and 'avg_success_rate' in d.get('eval_info', {{}}):
eval_sr = d['eval_info']['avg_success_rate']
pt = {{k: v for k, v in d.items() if 'success_rate' in k and k != 'avg_success_rate'}}
if pt:
eval_per_task = json.dumps(pt)
if 'n_episodes' in d:
eval_n = d['n_episodes']
# Parse final loss from SLURM stdout
final_loss = None
for lf in sorted(Path('{output_dir}/logs').glob('{policy_name}_*.out'), reverse=True):
losses = re.findall(r'\\\"loss\\\"\\s*:\\s*([\\d.e+-]+)', lf.read_text())
if losses:
final_loss = float(losses[-1])
break
# Parse peak GPU mem
peak_mem = 0.0
csv_p = Path('{output_dir}/logs/{policy_name}_gpu_mem.csv')
if csv_p.exists():
for ln in csv_p.read_text().splitlines():
parts = ln.strip().split(',')
if len(parts) >= 2:
try:
peak_mem = max(peak_mem, float(parts[1].strip()))
except ValueError:
pass
# Parse train config for optimizer details
lr, opt_wd, sched_type, sched_warmup, sched_decay = 0.0, 0.0, '', 0, 0
freeze_ve, train_eo, grad_ckpt = False, False, False
cfg_path = Path('{train_dir}/checkpoints/{steps:06d}/pretrained_model/train_config.json')
if cfg_path.exists():
tc = json.loads(cfg_path.read_text())
o = tc.get('optimizer', {{}})
lr = o.get('lr', 0.0)
opt_wd = o.get('weight_decay', 0.0)
s = tc.get('scheduler', {{}})
sched_type = s.get('type', '')
sched_warmup = s.get('num_warmup_steps', 0)
sched_decay = s.get('num_decay_steps', 0)
p = tc.get('policy', {{}})
freeze_ve = p.get('freeze_vision_encoder', False)
train_eo = p.get('train_expert_only', False)
grad_ckpt = p.get('gradient_checkpointing', False)
row = {{
'benchmark_uuid': '{benchmark_uuid}',
'policy_type': '{policy_name}',
'policy_repo_id': '{hub_org}/{policy_name}_libero',
'base_model_repo_id': '{cfg.policy_path or ""}',
'dataset_repo_id': '{COMMON_TRAINING_ARGS["dataset.repo_id"]}',
'env_type': '{COMMON_TRAINING_ARGS["env.type"]}',
'env_task': '{COMMON_TRAINING_ARGS["env.task"]}',
'steps': {steps},
'batch_size_per_gpu': {bs},
'num_gpus': {cfg.num_gpus},
'effective_batch_size': {eff_bs},
'total_samples_seen': {steps * eff_bs},
'chunk_size': {cfg.chunk_size or 0},
'learning_rate': lr,
'optimizer_type': 'AdamW',
'optimizer_weight_decay': opt_wd,
'scheduler_type': sched_type,
'scheduler_warmup_steps': sched_warmup,
'scheduler_decay_steps': sched_decay,
'freeze_vision_encoder': freeze_ve,
'train_expert_only': train_eo,
'gradient_checkpointing': grad_ckpt,
'eval_success_rate': eval_sr,
'eval_success_rate_per_task': eval_per_task,
'eval_n_episodes': eval_n,
'final_train_loss': final_loss,
'training_time_s': float(timing.get('TRAINING_TIME_S', 0)),
'peak_gpu_memory_mb': peak_mem or float(timing.get('MAX_GPU_MEM_MB', 0)),
'gpu_type': timing.get('GPU_TYPE', 'unknown'),
'lerobot_commit': timing.get('LEROBOT_COMMIT', 'unknown'),
'timestamp': datetime.now(timezone.utc).isoformat(),
}}
# Save locally
Path('{train_dir}/benchmark_result.json').write_text(json.dumps(row, indent=2, default=str))
# Push to HF dataset
try:
from datasets import Dataset, load_dataset
try:
existing = load_dataset('{hub_dataset}', split='train')
rows = existing.to_list() + [row]
except Exception:
rows = [row]
Dataset.from_list(rows).push_to_hub('{hub_dataset}', split='train')
print('Published result to {hub_dataset}')
except ImportError:
print('datasets library not installed — result saved locally only')
except Exception as e:
print(f'Failed to push to hub: {{e}} — result saved locally')
"
""")
def _generate_sbatch_script(
policy_name: str,
output_dir: Path,
hub_org: str,
benchmark_uuid: str,
hub_dataset: str,
lerobot_commit: str,
) -> str:
cfg = POLICY_CONFIGS[policy_name]
steps = int(COMMON_TRAINING_ARGS["steps"])
log_dir = output_dir / "logs"
train_dir = output_dir / "train" / policy_name
checkpoint_path = train_dir / f"checkpoints/{steps:06d}/pretrained_model"
training_args = _training_cli_args(policy_name, output_dir, hub_org, benchmark_uuid)
eval_args = _cli_args(EVAL_ARGS)
publish = _publish_snippet(policy_name, output_dir, hub_org, benchmark_uuid, hub_dataset)
return textwrap.dedent(f"""\
#!/bin/bash
#SBATCH --job-name=bench_{policy_name}
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:{cfg.num_gpus}
#SBATCH --cpus-per-task={cfg.num_gpus * 8}
#SBATCH --mem={GPU_MEM_ESTIMATES.get(policy_name, 128)}G
#SBATCH --time=06:00:00
#SBATCH --output={log_dir}/{policy_name}_%j.out
#SBATCH --error={log_dir}/{policy_name}_%j.err
set -euo pipefail
echo "=========================================="
echo "LeRobot LIBERO Benchmark — {policy_name}"
echo "UUID: {benchmark_uuid}"
echo "Start: $(date -Iseconds)"
echo "Host: $(hostname) | GPUs: {cfg.num_gpus}"
echo "=========================================="
START_TIME=$(date +%s)
# GPU memory monitoring (every 30s)
nvidia-smi --query-gpu=index,memory.used,memory.total,gpu_name \\
--format=csv,noheader,nounits -l 30 \\
> "{log_dir}/{policy_name}_gpu_mem.csv" &
GPU_MONITOR_PID=$!
# ── Training ──────────────────────────────────────────────────
echo "[$(date -Iseconds)] Starting training..."
accelerate launch --num_processes={cfg.num_gpus} \\
$(which lerobot-train) \\
{training_args}
TRAIN_EXIT=$?
TRAIN_END=$(date +%s)
echo "[$(date -Iseconds)] Training exit code: $TRAIN_EXIT"
# ── Evaluation ────────────────────────────────────────────────
EVAL_EXIT=1
if [ $TRAIN_EXIT -eq 0 ]; then
echo "[$(date -Iseconds)] Starting evaluation..."
lerobot-eval \\
--policy.path="{checkpoint_path}" \\
{eval_args} \\
--output_dir="{train_dir}/eval_results"
EVAL_EXIT=$?
echo "[$(date -Iseconds)] Eval exit code: $EVAL_EXIT"
else
echo "[$(date -Iseconds)] Skipping eval — training failed."
fi
# ── Timing ────────────────────────────────────────────────────
END_TIME=$(date +%s)
kill $GPU_MONITOR_PID 2>/dev/null || true
cat > "{log_dir}/{policy_name}_timing.txt" <<TIMING_EOF
BENCHMARK_UUID={benchmark_uuid}
POLICY_TYPE={policy_name}
TRAINING_TIME_S=$((TRAIN_END - START_TIME))
TOTAL_TIME_S=$((END_TIME - START_TIME))
TRAIN_EXIT=$TRAIN_EXIT
EVAL_EXIT=$EVAL_EXIT
MAX_GPU_MEM_MB=$(awk -F',' '{{print $2}}' "{log_dir}/{policy_name}_gpu_mem.csv" 2>/dev/null | sort -n | tail -1)
GPU_TYPE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | head -1 | xargs)
LEROBOT_COMMIT={lerobot_commit}
TIMING_EOF
# ── Publish result to HF dataset ──────────────────────────────
echo "[$(date -Iseconds)] Publishing result..."
{publish}
echo "=========================================="
echo "Done: $(date -Iseconds)"
echo "Training: $((TRAIN_END - START_TIME))s | Total: $((END_TIME - START_TIME))s"
echo "=========================================="
""")
def _generate_tokenizer_script(
output_dir: Path,
hub_org: str,
benchmark_uuid: str,
) -> str:
cfg = POLICY_CONFIGS["pi0_fast"]
log_dir = output_dir / "logs"
tokenizer_hub_repo = f"{hub_org}/fast-tokenizer-libero"
tok_args = dict(cfg.tokenizer_args)
tok_args["hub_repo_id"] = tokenizer_hub_repo
return textwrap.dedent(f"""\
#!/bin/bash
#SBATCH --job-name=bench_tokenizer
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=01:00:00
#SBATCH --output={log_dir}/tokenizer_%j.out
#SBATCH --error={log_dir}/tokenizer_%j.err
set -euo pipefail
echo "LeRobot — FAST Tokenizer | UUID: {benchmark_uuid}"
lerobot-train-tokenizer \\
{_cli_args(tok_args)}
echo "Tokenizer pushed to: {tokenizer_hub_repo}"
""")
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description="Generate SLURM scripts for LeRobot LIBERO benchmark.")
parser.add_argument(
"--policies",
nargs="+",
default=ALL_POLICY_NAMES,
choices=ALL_POLICY_NAMES,
help="Policies to benchmark (default: all).",
)
parser.add_argument("--output_dir", type=Path, required=True, help="Root output directory.")
parser.add_argument("--hub_org", type=str, default="lerobot", help="HuggingFace org.")
parser.add_argument("--hub_dataset", type=str, default=None, help="HF dataset repo for results.")
parser.add_argument("--uuid", type=str, default=None, help="Override benchmark UUID.")
args = parser.parse_args()
benchmark_uuid = args.uuid or str(uuid.uuid4())
output_dir: Path = args.output_dir.resolve()
policies: list[str] = args.policies
hub_org: str = args.hub_org
hub_dataset: str = args.hub_dataset or f"{hub_org}/benchmark-libero"
try:
commit = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "unknown"
scripts_dir = output_dir / "slurm_scripts"
log_dir = output_dir / "logs"
scripts_dir.mkdir(parents=True, exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)
for p in policies:
(output_dir / "train" / p).mkdir(parents=True, exist_ok=True)
generated: dict[str, Path] = {}
# Tokenizer job for pi0_fast
tokenizer_path = None
if "pi0_fast" in policies:
script = _generate_tokenizer_script(output_dir, hub_org, benchmark_uuid)
tokenizer_path = scripts_dir / "00_tokenizer.sh"
tokenizer_path.write_text(script)
tokenizer_path.chmod(0o755)
generated["tokenizer"] = tokenizer_path
tokenizer_hub_repo = f"{hub_org}/fast-tokenizer-libero"
POLICY_CONFIGS["pi0_fast"].extra_policy_args["policy.action_tokenizer_name"] = tokenizer_hub_repo
# Per-policy scripts
for i, name in enumerate(sorted(policies), start=1):
script = _generate_sbatch_script(name, output_dir, hub_org, benchmark_uuid, hub_dataset, commit)
path = scripts_dir / f"{i:02d}_{name}.sh"
path.write_text(script)
path.chmod(0o755)
generated[name] = path
# Manifest
manifest = {
"benchmark_uuid": benchmark_uuid,
"timestamp": datetime.now(UTC).isoformat(),
"lerobot_commit": commit,
"hub_org": hub_org,
"hub_dataset": hub_dataset,
"policies": policies,
"output_dir": str(output_dir),
"scripts": {k: str(v) for k, v in generated.items()},
}
manifest_path = output_dir / "benchmark_manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))
# Instructions
print("=" * 60)
print("LeRobot LIBERO Benchmark — Scripts Generated")
print(f"UUID: {benchmark_uuid}")
print(f"Output: {output_dir}")
print(f"Results dataset: {hub_dataset}")
print("=" * 60)
print()
for _name, path in sorted(generated.items()):
print(f" {path}")
print()
if tokenizer_path:
print("IMPORTANT: pi0_fast requires tokenizer training FIRST.")
print(f" 1. sbatch {tokenizer_path}")
print(" 2. Wait for completion")
print(f" 3. sbatch {generated.get('pi0_fast', 'N/A')}")
print(" 4. All other policies can run in parallel")
else:
print("All scripts can be submitted in parallel.")
print()
print("Each job publishes its result to the HF dataset automatically.")
if __name__ == "__main__":
main()
+156
View File
@@ -0,0 +1,156 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Publish benchmark rows and lightweight artifacts to a Hub dataset."""
from __future__ import annotations
import argparse
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from lerobot.utils.history_repo import UploadTarget, make_hub_file_url, upload_targets, utc_timestamp_slug
def load_json_if_exists(path: Path) -> dict[str, Any] | None:
if not path.exists():
return None
return json.loads(path.read_text())
def find_latest_train_config_path(run_root: Path) -> Path | None:
checkpoints_dir = run_root / "train" / "checkpoints"
if not checkpoints_dir.exists():
return None
candidates = sorted(
checkpoints_dir.glob("*/pretrained_model/train_config.json"),
key=lambda path: path.parts[-3],
)
return candidates[-1] if candidates else None
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--benchmark", required=True)
parser.add_argument("--policy", required=True)
parser.add_argument("--run_root", required=True, type=Path)
parser.add_argument("--results_repo", required=True)
parser.add_argument("--git_commit", required=True)
parser.add_argument("--num_gpus", required=True, type=int)
parser.add_argument("--microbatch_per_gpu", required=True, type=int)
parser.add_argument("--gradient_accumulation_steps", required=True, type=int)
parser.add_argument("--effective_batch_size", required=True, type=int)
parser.add_argument("--train_wall_time_s", required=True, type=float)
parser.add_argument("--eval_wall_time_s", required=True, type=float)
parser.add_argument("--slurm_job_id", default="")
parser.add_argument("--docker_image", required=True)
return parser.parse_args()
def build_row(args: argparse.Namespace) -> tuple[dict[str, Any], list[UploadTarget]]:
now = datetime.now(UTC)
created_at = now.isoformat()
timestamp = utc_timestamp_slug(now)
run_id = f"{timestamp}__{args.benchmark}__{args.policy}__{args.slurm_job_id or 'manual'}"
eval_info = load_json_if_exists(args.run_root / "eval" / "eval_info.json") or {}
train_config_path = find_latest_train_config_path(args.run_root)
train_config = load_json_if_exists(train_config_path) or {}
artifact_prefix = f"artifacts/{args.benchmark}/{args.policy}/{run_id}"
row_path_in_repo = f"rows/{args.benchmark}/{args.policy}/{run_id}.json"
row = {
"schema_version": 1,
"created_at": created_at,
"run_id": run_id,
"benchmark": args.benchmark,
"policy": args.policy,
"git_commit": args.git_commit,
"slurm_job_id": args.slurm_job_id or None,
"docker_image": args.docker_image,
"resources": {
"num_gpus": args.num_gpus,
"microbatch_per_gpu": args.microbatch_per_gpu,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
"effective_batch_size": args.effective_batch_size,
},
"timings": {
"train_wall_time_s": args.train_wall_time_s,
"eval_wall_time_s": args.eval_wall_time_s,
"total_wall_time_s": args.train_wall_time_s + args.eval_wall_time_s,
},
"eval": {
"overall": eval_info.get("overall", {}),
"per_group": eval_info.get("per_group", {}),
"per_task_count": len(eval_info.get("per_task", [])),
},
"paths": {
"run_root": str(args.run_root),
"train_dir": str(args.run_root / "train"),
"eval_dir": str(args.run_root / "eval"),
},
"train_config": train_config,
"artifact_urls": {
"row": make_hub_file_url(args.results_repo, row_path_in_repo),
},
}
row_path = args.run_root / "benchmark_row.json"
row_path.parent.mkdir(parents=True, exist_ok=True)
upload_list = [UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)]
eval_info_path = args.run_root / "eval" / "eval_info.json"
if eval_info_path.exists():
row["artifact_urls"]["eval_info"] = make_hub_file_url(
args.results_repo, f"{artifact_prefix}/eval_info.json"
)
upload_list.append(
UploadTarget(local_path=eval_info_path, path_in_repo=f"{artifact_prefix}/eval_info.json")
)
if train_config_path is not None and train_config_path.exists():
row["artifact_urls"]["train_config"] = make_hub_file_url(
args.results_repo, f"{artifact_prefix}/train_config.json"
)
upload_list.append(
UploadTarget(local_path=train_config_path, path_in_repo=f"{artifact_prefix}/train_config.json")
)
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
return row, upload_list
def main() -> int:
args = parse_args()
row, upload_list = build_row(args)
uploaded = upload_targets(
repo_id=args.results_repo,
targets=upload_list,
repo_type="dataset",
private=False,
commit_message=f"Add benchmark row {row['run_id']}",
)
row["uploaded_paths"] = uploaded
row_path = args.run_root / "benchmark_row.json"
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True))
return 0
if __name__ == "__main__":
raise SystemExit(main())
+647
View File
@@ -0,0 +1,647 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate lightweight SLURM jobs for policy x benchmark benchmarking."""
from __future__ import annotations
import argparse
import json
import math
import subprocess
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from lerobot.utils.history_repo import utc_timestamp_slug
MAX_GPUS = 8
MIN_GPUS = 1
DEFAULT_STEPS = 20_000
DEFAULT_EFFECTIVE_BATCH_SIZE = 256
DEFAULT_MICROBATCH_PER_GPU = 32
DEFAULT_EVAL_BATCH_SIZE = 1
DEFAULT_CPUS_PER_GPU = 8
DEFAULT_MEMORY_PER_GPU_GB = 40
@dataclass(frozen=True)
class BenchmarkSpec:
name: str
dataset_repo_id: str
docker_image: str
eval_env_type: str
eval_task: str
eval_n_episodes: int
train_steps: int = DEFAULT_STEPS
effective_batch_size: int = DEFAULT_EFFECTIVE_BATCH_SIZE
train_extra_args: dict[str, Any] = field(default_factory=dict)
eval_extra_args: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class PolicySpec:
name: str
policy_type: str
num_gpus: int
policy_path: str | None = None
microbatch_per_gpu: int = DEFAULT_MICROBATCH_PER_GPU
extra_train_args: dict[str, Any] = field(default_factory=dict)
extra_eval_args: dict[str, Any] = field(default_factory=dict)
needs_tokenizer: bool = False
tokenizer_args: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class PlannedJob:
benchmark: str
policy: str
run_rel: str
num_gpus: int
microbatch_per_gpu: int
gradient_accumulation_steps: int
effective_batch_size: int
docker_image: str
train_args: dict[str, Any]
eval_args: dict[str, Any]
tokenizer_args: dict[str, Any] | None
script_path: str
BENCHMARKS: dict[str, BenchmarkSpec] = {
"libero_plus": BenchmarkSpec(
name="libero_plus",
dataset_repo_id="lerobot/libero_plus",
docker_image="lerobot-benchmark-libero-plus:latest",
eval_env_type="libero_plus",
eval_task="libero_spatial,libero_object,libero_goal,libero_10",
eval_n_episodes=10,
train_extra_args={
"rename_map": {
"observation.images.image": "observation.images.camera1",
"observation.images.image2": "observation.images.camera2",
},
},
eval_extra_args={
"env.camera_name_mapping": {
"agentview_image": "camera1",
"robot0_eye_in_hand_image": "camera2",
},
"env.max_parallel_tasks": 1,
"eval.batch_size": DEFAULT_EVAL_BATCH_SIZE,
"eval.use_async_envs": False,
"eval.max_episodes_rendered": 0,
"policy.device": "cuda",
},
),
"robomme": BenchmarkSpec(
name="robomme",
dataset_repo_id="lerobot/robomme",
docker_image="lerobot-benchmark-robomme:latest",
eval_env_type="robomme",
eval_task=(
"BinFill,PickXtimes,SwingXtimes,StopCube,VideoUnmask,VideoUnmaskSwap,"
"ButtonUnmask,ButtonUnmaskSwap,PickHighlight,VideoRepick,VideoPlaceButton,"
"VideoPlaceOrder,MoveCube,InsertPeg,PatternLock,RouteStick"
),
eval_n_episodes=50,
train_extra_args={
"rename_map": {
"observation.images.image": "observation.images.camera1",
"observation.images.wrist_image": "observation.images.camera2",
},
},
eval_extra_args={
"env.dataset_split": "test",
"env.max_parallel_tasks": 1,
"rename_map": {
"observation.images.image": "observation.images.camera1",
"observation.images.wrist_image": "observation.images.camera2",
},
"eval.batch_size": DEFAULT_EVAL_BATCH_SIZE,
"eval.use_async_envs": False,
"eval.max_episodes_rendered": 0,
"policy.device": "cuda",
},
),
}
POLICIES: dict[str, PolicySpec] = {
"pi0": PolicySpec(
name="pi0",
policy_type="pi0",
policy_path="lerobot/pi0_base",
num_gpus=8,
extra_train_args={
"policy.n_action_steps": 30,
"policy.scheduler_decay_steps": DEFAULT_STEPS,
"policy.empty_cameras": 0,
},
),
"pi0_fast": PolicySpec(
name="pi0_fast",
policy_type="pi0_fast",
policy_path="lerobot/pi0fast-base",
num_gpus=8,
extra_train_args={
"policy.n_action_steps": 30,
"policy.scheduler_decay_steps": DEFAULT_STEPS,
"policy.empty_cameras": 0,
},
needs_tokenizer=True,
tokenizer_args={
"action_horizon": 30,
"encoded_dims": "0:7",
"normalization_mode": "QUANTILES",
"vocab_size": 1024,
"scale": 10.0,
"push_to_hub": True,
},
),
"pi05": PolicySpec(
name="pi05",
policy_type="pi05",
policy_path="lerobot/pi05_base",
num_gpus=8,
extra_train_args={
"policy.n_action_steps": 30,
"policy.scheduler_decay_steps": DEFAULT_STEPS,
"policy.empty_cameras": 0,
},
),
"groot": PolicySpec(
name="groot",
policy_type="groot",
num_gpus=8,
extra_train_args={
"policy.n_action_steps": 30,
"policy.base_model_path": "nvidia/GR00T-N1.5-3B",
"policy.tune_diffusion_model": True,
"policy.tune_projector": True,
"policy.tune_llm": False,
"policy.tune_visual": False,
"policy.use_bf16": True,
},
),
"act": PolicySpec(
name="act",
policy_type="act",
num_gpus=1,
extra_train_args={
"policy.n_action_steps": 30,
},
),
"diffusion": PolicySpec(
name="diffusion",
policy_type="diffusion",
num_gpus=1,
extra_train_args={
"policy.horizon": 32,
"policy.n_action_steps": 30,
"policy.n_obs_steps": 2,
},
),
"smolvla": PolicySpec(
name="smolvla",
policy_type="smolvla",
policy_path="lerobot/smolvla_base",
num_gpus=8,
extra_train_args={
"policy.n_action_steps": 30,
"policy.load_vlm_weights": True,
"policy.freeze_vision_encoder": False,
"policy.train_expert_only": False,
"policy.scheduler_decay_steps": DEFAULT_STEPS,
"policy.empty_cameras": 1,
},
),
"xvla": PolicySpec(
name="xvla",
policy_type="xvla",
policy_path="lerobot/xvla-widowx",
num_gpus=4,
extra_train_args={
"policy.n_action_steps": 32,
"policy.scheduler_decay_steps": DEFAULT_STEPS,
"policy.empty_cameras": 1,
},
),
"multi_task_dit": PolicySpec(
name="multi_task_dit",
policy_type="multi_task_dit",
num_gpus=1,
extra_train_args={
"policy.horizon": 32,
"policy.n_action_steps": 30,
},
),
}
def normalize_repo_id(hub_org: str, repo_or_id: str) -> str:
return repo_or_id if "/" in repo_or_id else f"{hub_org}/{repo_or_id}"
def get_requested_names(
requested: list[str] | None,
available: dict[str, Any],
*,
kind: str,
) -> list[str]:
if not requested:
return list(available)
unknown = sorted(set(requested) - set(available))
if unknown:
raise ValueError(f"Unknown {kind}: {', '.join(unknown)}. Available: {', '.join(available)}")
return requested
def compute_gradient_accumulation_steps(
*,
effective_batch_size: int,
num_gpus: int,
microbatch_per_gpu: int,
) -> int:
per_step_batch = num_gpus * microbatch_per_gpu
if effective_batch_size % per_step_batch != 0:
raise ValueError(
f"Cannot reach effective batch {effective_batch_size} with {num_gpus=} and "
f"{microbatch_per_gpu=}."
)
return effective_batch_size // per_step_batch
def make_run_slug() -> str:
return utc_timestamp_slug()
def shell_value(value: Any) -> str:
if isinstance(value, bool):
value = "true" if value else "false"
elif isinstance(value, (dict, list)):
value = json.dumps(value, sort_keys=True)
else:
value = str(value)
escaped = (
value.replace("\\", "\\\\")
.replace('"', '\\"')
.replace("$", "\\$")
.replace("`", "\\`")
)
return f'"{escaped}"'
def format_cli_args(args: dict[str, Any]) -> str:
lines = []
for key, value in args.items():
lines.append(f" --{key}={shell_value(value)}")
return " \\\n".join(lines)
def build_train_args(
*,
benchmark: BenchmarkSpec,
policy: PolicySpec,
train_dir: str,
gradient_accumulation_steps: int,
) -> dict[str, Any]:
args: dict[str, Any] = {
"dataset.repo_id": benchmark.dataset_repo_id,
"output_dir": train_dir,
"steps": benchmark.train_steps,
"batch_size": policy.microbatch_per_gpu,
"gradient_accumulation_steps": gradient_accumulation_steps,
"eval_freq": 0,
"save_freq": benchmark.train_steps,
"save_checkpoint": True,
"log_freq": 100,
"wandb.enable": False,
"policy.push_to_hub": False,
"policy.device": "cuda",
}
if policy.policy_path:
args["policy.path"] = policy.policy_path
else:
args["policy.type"] = policy.policy_type
args.update(benchmark.train_extra_args)
args.update(policy.extra_train_args)
return args
def build_eval_args(
*,
benchmark: BenchmarkSpec,
policy: PolicySpec,
checkpoint_path: str,
eval_dir: str,
) -> dict[str, Any]:
args: dict[str, Any] = {
"policy.path": checkpoint_path,
"env.type": benchmark.eval_env_type,
"env.task": benchmark.eval_task,
"eval.n_episodes": benchmark.eval_n_episodes,
"output_dir": eval_dir,
}
args.update(benchmark.eval_extra_args)
args.update(policy.extra_eval_args)
return args
def plan_jobs(
*,
output_dir: Path,
hub_org: str,
results_repo: str,
policies: list[str],
benchmarks: list[str],
) -> list[PlannedJob]:
_ = hub_org
_ = results_repo
scripts_dir = output_dir / "slurm"
jobs: list[PlannedJob] = []
for benchmark_name in benchmarks:
benchmark = BENCHMARKS[benchmark_name]
for policy_name in policies:
policy = POLICIES[policy_name]
num_gpus = max(MIN_GPUS, min(policy.num_gpus, MAX_GPUS))
run_rel = f"runs/{benchmark_name}/{policy_name}/{make_run_slug()}"
run_root = f"/benchmark-output/{run_rel}"
gradient_accumulation_steps = compute_gradient_accumulation_steps(
effective_batch_size=benchmark.effective_batch_size,
num_gpus=num_gpus,
microbatch_per_gpu=policy.microbatch_per_gpu,
)
train_dir = f"{run_root}/train"
checkpoint_path = f"{train_dir}/checkpoints/{benchmark.train_steps:06d}/pretrained_model"
eval_dir = f"{run_root}/eval"
train_args = build_train_args(
benchmark=benchmark,
policy=policy,
train_dir=train_dir,
gradient_accumulation_steps=gradient_accumulation_steps,
)
eval_args = build_eval_args(
benchmark=benchmark,
policy=policy,
checkpoint_path=checkpoint_path,
eval_dir=eval_dir,
)
tokenizer_args = None
if policy.needs_tokenizer:
tokenizer_repo_id = f"{hub_org}/{policy_name}-{benchmark_name}-tokenizer"
tokenizer_args = {
"repo_id": benchmark.dataset_repo_id,
"output_dir": f"{run_root}/tokenizer",
"hub_repo_id": tokenizer_repo_id,
**policy.tokenizer_args,
}
train_args["policy.action_tokenizer_name"] = tokenizer_repo_id
script_path = str(scripts_dir / f"{benchmark_name}__{policy_name}.sbatch")
jobs.append(
PlannedJob(
benchmark=benchmark_name,
policy=policy_name,
run_rel=run_rel,
num_gpus=num_gpus,
microbatch_per_gpu=policy.microbatch_per_gpu,
gradient_accumulation_steps=gradient_accumulation_steps,
effective_batch_size=benchmark.effective_batch_size,
docker_image=benchmark.docker_image,
train_args=train_args,
eval_args=eval_args,
tokenizer_args=tokenizer_args,
script_path=script_path,
)
)
return jobs
def render_sbatch_script(
*,
job: PlannedJob,
output_dir: Path,
results_repo_id: str,
git_commit: str,
) -> str:
host_output_dir = output_dir.resolve()
run_root = f"/benchmark-output/{job.run_rel}"
host_run_root = host_output_dir / job.run_rel
cpus_per_task = max(DEFAULT_CPUS_PER_GPU, DEFAULT_CPUS_PER_GPU * job.num_gpus)
mem_gb = max(DEFAULT_MEMORY_PER_GPU_GB, DEFAULT_MEMORY_PER_GPU_GB * job.num_gpus)
gpu_ids_expr = "${GPU_IDS}"
train_cli = format_cli_args(job.train_args)
eval_cli = format_cli_args(job.eval_args)
tokenizer_command = ""
if job.tokenizer_args:
tokenizer_cli = format_cli_args(job.tokenizer_args)
tokenizer_command = f"""
docker run --rm --gpus all \\
--shm-size=16g \\
-e CUDA_VISIBLE_DEVICES={gpu_ids_expr} \\
-e HF_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_USER_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_HOME=/tmp/hf \\
-v "{host_output_dir}:/benchmark-output" \\
-w /lerobot \\
"{job.docker_image}" \\
bash -lc '
set -euo pipefail
if [[ -n "${{HF_TOKEN:-}}" ]]; then
hf auth login --token "${{HF_TOKEN}}" --add-to-git-credential 2>/dev/null || true
fi
lerobot-train-tokenizer \\
{tokenizer_cli}
'
"""
return f"""#!/bin/bash
#SBATCH --job-name=bench-{job.benchmark}-{job.policy}
#SBATCH --gres=gpu:{job.num_gpus}
#SBATCH --cpus-per-task={cpus_per_task}
#SBATCH --mem={mem_gb}G
#SBATCH --output={output_dir.resolve()}/logs/{job.benchmark}__{job.policy}__%j.out
#SBATCH --error={output_dir.resolve()}/logs/{job.benchmark}__{job.policy}__%j.err
set -euo pipefail
HF_TOKEN="${{HF_TOKEN:-${{HF_USER_TOKEN:-}}}}"
GPU_IDS="$(seq -s, 0 $(({job.num_gpus} - 1)))"
RUN_ROOT="{run_root}"
mkdir -p "{host_output_dir}/logs"
mkdir -p "{host_run_root.parent}"
{tokenizer_command}
TRAIN_START="$(date +%s)"
docker run --rm --gpus all \\
--shm-size=16g \\
-e CUDA_VISIBLE_DEVICES="${{GPU_IDS}}" \\
-e HF_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_USER_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_HOME=/tmp/hf \\
-v "{host_output_dir}:/benchmark-output" \\
-w /lerobot \\
"{job.docker_image}" \\
bash -lc '
set -euo pipefail
if [[ -n "${{HF_TOKEN:-}}" ]]; then
hf auth login --token "${{HF_TOKEN}}" --add-to-git-credential 2>/dev/null || true
fi
accelerate launch --num_processes={job.num_gpus} $(which lerobot-train) \\
{train_cli}
'
TRAIN_END="$(date +%s)"
EVAL_START="$(date +%s)"
docker run --rm --gpus all \\
--shm-size=16g \\
-e CUDA_VISIBLE_DEVICES="${{GPU_IDS}}" \\
-e HF_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_USER_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_HOME=/tmp/hf \\
-v "{host_output_dir}:/benchmark-output" \\
-w /lerobot \\
"{job.docker_image}" \\
bash -lc '
set -euo pipefail
if [[ -n "${{HF_TOKEN:-}}" ]]; then
hf auth login --token "${{HF_TOKEN}}" --add-to-git-credential 2>/dev/null || true
fi
lerobot-eval \\
{eval_cli}
'
EVAL_END="$(date +%s)"
TRAIN_WALL_TIME_S="$((TRAIN_END - TRAIN_START))"
EVAL_WALL_TIME_S="$((EVAL_END - EVAL_START))"
docker run --rm --gpus all \\
--shm-size=16g \\
-e CUDA_VISIBLE_DEVICES="${{GPU_IDS}}" \\
-e HF_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_USER_TOKEN="${{HF_TOKEN:-}}" \\
-e HF_HOME=/tmp/hf \\
-e RUN_ROOT="${{RUN_ROOT}}" \\
-e TRAIN_WALL_TIME_S="${{TRAIN_WALL_TIME_S}}" \\
-e EVAL_WALL_TIME_S="${{EVAL_WALL_TIME_S}}" \\
-v "{host_output_dir}:/benchmark-output" \\
-w /lerobot \\
"{job.docker_image}" \\
bash -lc '
set -euo pipefail
if [[ -n "${{HF_TOKEN:-}}" ]]; then
hf auth login --token "${{HF_TOKEN}}" --add-to-git-credential 2>/dev/null || true
fi
uv run python benchmarks/publish_benchmark_result.py \\
--benchmark={job.benchmark} \\
--policy={job.policy} \\
--run_root="${{RUN_ROOT}}" \\
--results_repo={results_repo_id} \\
--git_commit={git_commit} \\
--num_gpus={job.num_gpus} \\
--microbatch_per_gpu={job.microbatch_per_gpu} \\
--gradient_accumulation_steps={job.gradient_accumulation_steps} \\
--effective_batch_size={job.effective_batch_size} \\
--train_wall_time_s="${{TRAIN_WALL_TIME_S}}" \\
--eval_wall_time_s="${{EVAL_WALL_TIME_S}}" \\
--slurm_job_id="${{SLURM_JOB_ID:-}}" \\
--docker_image={job.docker_image}
'
"""
def write_manifest(
*,
output_dir: Path,
jobs: list[PlannedJob],
git_commit: str,
hub_org: str,
results_repo: str,
) -> Path:
manifest = {
"generated_at": datetime.now(UTC).isoformat(),
"git_commit": git_commit,
"hub_org": hub_org,
"results_repo": results_repo,
"jobs": [asdict(job) for job in jobs],
}
manifest_path = output_dir / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True))
return manifest_path
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policies", nargs="*", default=None)
parser.add_argument("--benchmarks", nargs="*", default=None)
parser.add_argument("--output_dir", required=True, type=Path)
parser.add_argument("--hub_org", required=True)
parser.add_argument("--results_repo", required=True)
parser.add_argument("--submit", action="store_true")
return parser.parse_args()
def get_git_commit() -> str:
return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
def main() -> int:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
(args.output_dir / "slurm").mkdir(parents=True, exist_ok=True)
(args.output_dir / "logs").mkdir(parents=True, exist_ok=True)
selected_policies = get_requested_names(args.policies, POLICIES, kind="policies")
selected_benchmarks = get_requested_names(args.benchmarks, BENCHMARKS, kind="benchmarks")
git_commit = get_git_commit()
results_repo_id = normalize_repo_id(args.hub_org, args.results_repo)
jobs = plan_jobs(
output_dir=args.output_dir,
hub_org=args.hub_org,
results_repo=results_repo_id,
policies=selected_policies,
benchmarks=selected_benchmarks,
)
for job in jobs:
script = render_sbatch_script(
job=job,
output_dir=args.output_dir,
results_repo_id=results_repo_id,
git_commit=git_commit,
)
script_path = Path(job.script_path)
script_path.write_text(script)
script_path.chmod(0o755)
if args.submit:
subprocess.run(["sbatch", str(script_path)], check=True)
manifest_path = write_manifest(
output_dir=args.output_dir,
jobs=jobs,
git_commit=git_commit,
hub_org=args.hub_org,
results_repo=results_repo_id,
)
print(f"Wrote {len(jobs)} benchmark jobs to {args.output_dir}")
print(f"Manifest: {manifest_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
+48
View File
@@ -0,0 +1,48 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM huggingface/lerobot-gpu:latest
USER root
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
unzip libexpat1 libfontconfig1-dev libmagickwand-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
USER user_lerobot
RUN uv pip install --no-cache \
"robosuite==1.4.1" bddl easydict mujoco matplotlib wand scikit-image gym
ENV LIBERO_PLUS_ROOT=/home/user_lerobot/libero-plus/libero/libero
RUN git clone --depth=1 https://github.com/sylvestf/LIBERO-plus.git /home/user_lerobot/libero-plus \
&& cd /home/user_lerobot/libero-plus && uv pip install --no-cache --no-deps -e "." \
&& uv pip uninstall hf-libero 2>/dev/null || true
ENV PYTHONPATH="/home/user_lerobot/libero-plus:${PYTHONPATH}"
RUN python -c "\
from huggingface_hub import hf_hub_download; \
hf_hub_download(repo_id='Sylvest/LIBERO-plus', repo_type='dataset', \
filename='assets.zip', local_dir='/tmp/libero-plus-dl')" \
&& unzip -q /tmp/libero-plus-dl/assets.zip -d /tmp/libero-plus-dl/extract \
&& mv /tmp/libero-plus-dl/extract/inspire/hdd/project/embodied-multimodality/public/syfei/libero_new/release/dataset/LIBERO-plus-0/assets \
${LIBERO_PLUS_ROOT}/assets \
&& rm -rf /tmp/libero-plus-dl
RUN mkdir -p /home/user_lerobot/.libero \
&& printf "assets: ${LIBERO_PLUS_ROOT}/assets\nbddl_files: ${LIBERO_PLUS_ROOT}/bddl_files\ndatasets: ${LIBERO_PLUS_ROOT}/../datasets\ninit_states: ${LIBERO_PLUS_ROOT}/init_files\n" \
> /home/user_lerobot/.libero/config.yaml
COPY --chown=user_lerobot:user_lerobot . .
CMD ["/bin/bash"]
+39
View File
@@ -0,0 +1,39 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM huggingface/lerobot-gpu:latest
ENV NVIDIA_DRIVER_CAPABILITIES=all \
VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/nvidia_icd.json
USER root
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
libvulkan1 libvulkan-dev mesa-vulkan-drivers \
&& mkdir -p /usr/share/vulkan/icd.d \
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
> /usr/share/vulkan/icd.d/nvidia_icd.json \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
USER user_lerobot
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
RUN printf 'gymnasium==0.29.1\nnumpy==1.26.4\n' > /tmp/robomme_override.txt \
&& uv pip install --no-cache --override /tmp/robomme_override.txt \
-e ".[smolvla,av-dep]" \
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main" \
&& python -c "import robomme; print('robomme import OK')"
COPY --chown=user_lerobot:user_lerobot . .
CMD ["/bin/bash"]
-2
View File
@@ -61,8 +61,6 @@
title: SARM
title: "Reward Models"
- sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async
title: Use Async Inference
- local: rtc
+19 -17
View File
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
### Teleoperator Requirements
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators:**
**Compatible teleoperators in the current `examples/hil` scripts:**
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
## Script
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
| Mode | Flag | Models |
| ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
---
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
**Standard inference (ACT, Diffusion Policy):**
```bash
lerobot-rollout --strategy.type=dagger \
python examples/hil/hil_data_collection.py \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -111,7 +111,8 @@ lerobot-rollout --strategy.type=dagger \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--strategy.num_episodes=50 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=2
```
@@ -120,11 +121,11 @@ lerobot-rollout --strategy.type=dagger \
For models with high inference latency, enable RTC for smooth execution:
```bash
lerobot-rollout --strategy.type=dagger \
--inference.type=rtc \
--inference.rtc.execution_horizon=20 \
--inference.rtc.max_guidance_weight=5.0 \
--inference.rtc.prefix_attention_schedule=LINEAR \
python examples/hil/hil_data_collection.py \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -138,7 +139,8 @@ lerobot-rollout --strategy.type=dagger \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--strategy.num_episodes=50 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=3
```
@@ -233,7 +235,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
+2 -26
View File
@@ -685,10 +685,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
```json
{
"dataset": {
"repo_id": "hf_username/dataset_name",
"root": null
},
"policy": {
"type": "reward_classifier",
"model_name": "helper2424/resnet10",
@@ -709,28 +705,8 @@ Example configuration for training the [reward classifier](https://huggingface.c
"type": "VISUAL",
"shape": [3, 128, 128]
}
},
"push_to_hub": true,
"repo_id": "hf_username/model_repo"
},
"batch_size": 16,
"num_workers": 4,
"steps": 5000,
"log_freq": 10,
"eval_freq": 1000,
"save_freq": 1000,
"save_checkpoint": true,
"seed": 2,
"resume": false,
"optimizer": {
"grad_clip_norm": 10.0
},
"wandb": {
"enable": true,
"project": "reward-classifier",
"disable_artifact": false
},
"job_name": "reward-classifier"
}
}
}
```
+105 -32
View File
@@ -32,12 +32,6 @@ Once youve gathered enough trajectories, youll train a neural network to i
If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.
<Tip>
Want to quickly get the right commands for your setup? The [quickstart notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) lets you configure your robot once and generates all the commands below ready to paste.
</Tip>
## Set up and Calibrate
If you haven't yet set up and calibrated your robot and teleop device, please do so by following the robot-specific tutorial.
@@ -509,42 +503,121 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
<hfoptions id="eval">
<hfoption id="Base mode (no recording)">
<hfoption id="Command">
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the transparent box" \
--duration=60
```
</hfoption>
<hfoption id="Sentry mode (with recording)">
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--robot.id=my_awesome_follower_arm \
--display_data=false \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--duration=600
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_awesome_leader_arm \
--policy.path=${HF_USER}/my_policy
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.policies.act import ACTPolicy
from lerobot.policies import make_pre_post_processors
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Run the policy inference loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
# Clean up
robot.disconnect()
dataset.push_to_hub()
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
The `--strategy.type` flag selects the execution mode:
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
-261
View File
@@ -1,261 +0,0 @@
# Policy Deployment (lerobot-rollout)
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
## Quick Start
No extra dependencies are needed beyond your robot and policy extras.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=lerobot/act_koch_real \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--task="pick up cube" \
--duration=30
```
This runs the policy for 30 seconds with no recording.
---
## Strategies
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
### Base (`--strategy.type=base`)
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the box" \
--duration=60
```
| Flag | Description |
| ---------------- | ------------------------------------------------------ |
| `--duration` | Run time in seconds (0 = infinite) |
| `--task` | Task description passed to the policy |
| `--display_data` | Stream observations/actions to Rerun for visualization |
### Sentry (`--strategy.type=sentry`)
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_data \
--dataset.single_task="Put lego brick into the box" \
--duration=3600
```
| Flag | Description |
| -------------------------------------- | ----------------------------------------------------------- |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
### Highlight (`--strategy.type=highlight`)
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
```bash
lerobot-rollout \
--strategy.type=highlight \
--strategy.ring_buffer_seconds=30 \
--strategy.save_key=s \
--strategy.push_key=h \
--policy.path=${HF_USER}/my_policy \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--dataset.repo_id=${HF_USER}/highlight_data \
--dataset.single_task="Pick up the red cube"
```
**Keyboard controls:**
| Key | Action |
| ------------------ | -------------------------------------------------------- |
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
| `h` (configurable) | Push dataset to Hub |
| `ESC` | Stop the session |
| Flag | Description |
| -------------------------------------- | ---------------------------------------------- |
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
### DAgger (`--strategy.type=dagger`)
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/hil_data \
--dataset.single_task="Fold the T-shirt"
```
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.record_autonomous=true \
--strategy.num_episodes=50 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/dagger_data \
--dataset.single_task="Grasp the block"
```
**Keyboard controls** (default input device):
| Key | Action |
| ------- | ------------------------------------------- |
| `Space` | Pause / resume policy execution |
| `Tab` | Start / stop human correction |
| `Enter` | Push dataset to Hub (corrections-only mode) |
| `ESC` | Stop the session |
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
| Flag | Description |
| ------------------------------------ | ------------------------------------------------------- |
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
---
## Inference Backends
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
### Sync (default)
One policy call per control tick. The main loop blocks until the action is computed.
Works with all policies. No extra flags needed.
### Real-Time Chunking (`--inference.type=rtc`)
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
```bash
lerobot-rollout \
--strategy.type=base \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--policy.path=${HF_USER}/pi0_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Pick up the cube" \
--duration=60 \
--device=cuda
```
| Flag | Description |
| ------------------------------------------- | -------------------------------------------------------------- |
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
---
## Common Flags
| Flag | Description | Default |
| --------------------------------- | ----------------------------------------------------------------- | ------- |
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
| `--robot.port` | Serial port for the robot | -- |
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
| `--fps` | Control loop frequency | 30 |
| `--duration` | Run time in seconds (0 = infinite) | 0 |
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
| `--task` | Task description (used when no dataset is provided) | -- |
| `--display_data` | Stream telemetry to Rerun visualization | false |
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
| `--interpolation_multiplier` | Action interpolation factor | 1 |
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
| `--resume` | Resume a previous recording session | false |
| `--play_sounds` | Vocal synthesis for events | true |
---
## Programmatic Usage
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig(
robot=my_robot_config,
policy=my_policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=30,
duration=60,
task="my task",
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=my_custom_action_processor, # optional
robot_observation_processor=my_custom_obs_processor, # optional
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
```
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
+3 -7
View File
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
@@ -137,12 +137,8 @@ The script generates a visualization of the denoising process, comparing standar
## Testing RTC with a Real Robot
```bash
lerobot-rollout \
--strategy.type=base \
python examples/rtc/eval_with_real_robot.py \
--policy.path=${HF_USERNAME}/policy_repo_id \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
@@ -182,7 +178,7 @@ visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
## References
+1 -1
View File
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \
--duration=1000 \
--fps=30 \
--inference.type=rtc
--rtc.enabled=true
```
---
File diff suppressed because it is too large Load Diff
+226
View File
@@ -0,0 +1,226 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common.control_utils import is_headless
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)
+31 -62
View File
@@ -14,21 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
@@ -39,9 +35,6 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
@@ -90,67 +83,43 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot
robot_action_to_send = robot_action_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
log_say("Waiting for environment reset, press right arrow key when ready...")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
+9 -10
View File
@@ -45,6 +45,9 @@ def main():
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
@@ -74,10 +77,6 @@ def main():
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
teleop_action_processor, robot_action_processor, robot_observation_processor = (
make_default_processors()
)
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
@@ -88,14 +87,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
@@ -107,13 +106,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
-77
View File
@@ -1,77 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained policy on LeKiwi without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). For a CLI entry point with the same capabilities plus
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
"""
from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
# by ``build_rollout_context`` to reload the full model.
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
# Assemble the rollout config: base strategy (no recording) + sync inference.
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
signal_handler = ProcessSignalHandler(use_threads=True)
# Build the context (connects robot, loads policy, wires the inference strategy).
# No custom processors here — LeKiwi runs on raw joint features.
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
-342
View File
@@ -1,342 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🤗 LeRobot Quickstart\n",
"\n",
"Calibration → teleoperation → data collection → training → evaluation.\n",
"\n",
"Install the required dependencies: `pip install -e .[notebook,dataset,training,viz,hardware]`.\n",
"\n",
"**How to use:**\n",
"1. Edit the **Configuration** cell with your settings.\n",
"2. Run all cells (`Run All`).\n",
"3. Each section prints a ready-to-paste terminal command - copy it and run it.\n",
"\n",
"Each setup is different, please refer to the [LeRobot documentation](https://huggingface.co/docs/lerobot/il_robots) for more details on each step and available options. <br>\n",
"Feel free to make this notebook your own and adapt it to your needs!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _cameras_arg(cameras: dict) -> str:\n",
" if not cameras:\n",
" return \"\"\n",
" entries = [f\"{n}: {{{', '.join(f'{k}: {v}' for k, v in cfg.items())}}}\" for n, cfg in cameras.items()]\n",
" return \"{ \" + \", \".join(entries) + \" }\"\n",
"\n",
"\n",
"def print_cmd(*parts: str) -> None:\n",
" \"\"\"Print a shell command with line continuations, skipping empty parts.\"\"\"\n",
" non_empty = [p for p in parts if p]\n",
" print(\" \\\\\\n \".join(non_empty))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Configuration\n",
"\n",
"Edit this cell, then **Run All** to generate all commands below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Robot (follower) - run `lerobot-find-port` to discover the port\n",
"ROBOT_TYPE = \"so101_follower\"\n",
"ROBOT_PORT = \"/dev/ttyACM0\"\n",
"ROBOT_ID = \"my_follower_arm\"\n",
"\n",
"# Teleop (leader) - run `lerobot-find-port` to discover the port\n",
"TELEOP_TYPE = \"so101_leader\"\n",
"TELEOP_PORT = \"/dev/ttyACM1\"\n",
"TELEOP_ID = \"my_leader_arm\"\n",
"\n",
"# Cameras - set to {} to disable\n",
"# Run `lerobot-find-cameras opencv` to list available cameras and their indices\n",
"CAMERAS = {\n",
" \"top\": {\"type\": \"opencv\", \"index_or_path\": 2, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
" \"wrist\": {\"type\": \"opencv\", \"index_or_path\": 4, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
"}\n",
"\n",
"# Dataset\n",
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
"DATASET_NAME = \"my_so101_dataset\"\n",
"TASK_DESCRIPTION = \"pick and place the block\"\n",
"NUM_EPISODES = 10\n",
"\n",
"# Training\n",
"POLICY_TYPE = \"act\" # act, diffusion, smolvla, ...\n",
"POLICY_DEVICE = \"cuda\" # cuda / cpu / mps\n",
"TRAIN_STEPS = 10_000\n",
"SAVE_FREQ = 2_000\n",
"OUTPUT_DIR = f\"outputs/train/{DATASET_NAME}\"\n",
"\n",
"# Inference - Hub repo ID or local checkpoint path\n",
"# e.g. set to f\"{OUTPUT_DIR}/checkpoints/last\" to use a local checkpoint\n",
"POLICY_PATH = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"LAST_CHECKPOINT_PATH = f\"{OUTPUT_DIR}/checkpoints/last\"\n",
"\n",
"# Derived\n",
"DATASET_REPO_ID = f\"{HF_USER}/{DATASET_NAME}\"\n",
"DATASET_ROOT = f\"data/{DATASET_NAME}\"\n",
"POLICY_REPO_ID = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"EVAL_REPO_ID = f\"{HF_USER}/eval_{DATASET_NAME}\"\n",
"CAMERAS_ARG = _cameras_arg(CAMERAS)\n",
"CAMERAS_FLAG = f'--robot.cameras=\"{CAMERAS_ARG}\"' if CAMERAS_ARG else \"\"\n",
"\n",
"print(f\"Robot : {ROBOT_TYPE} @ {ROBOT_PORT}\")\n",
"print(f\"Teleop : {TELEOP_TYPE} @ {TELEOP_PORT}\")\n",
"print(f\"Cameras: {list(CAMERAS) or 'none'}\")\n",
"print(f\"Dataset: {DATASET_REPO_ID} ({NUM_EPISODES} episodes) saved to {DATASET_ROOT}\")\n",
"print(f\"Policy : {POLICY_TYPE} -> {POLICY_REPO_ID}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 1. Calibration\n",
"\n",
"Run once per arm before first use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Follower\n",
"print_cmd(\n",
" \"lerobot-calibrate\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Leader\n",
"print_cmd(\n",
" \"lerobot-calibrate\",\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 2. Teleoperation\n",
"\n",
"See the [teleoperation docs](https://huggingface.co/docs/lerobot/il_robots#teleoperate) and the [cameras guide](https://huggingface.co/docs/lerobot/cameras) for more options."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-teleoperate\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 3. Record Dataset\n",
"\n",
"See the [recording docs](https://huggingface.co/docs/lerobot/il_robots#record-a-dataset) for tips on gathering good data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
" \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted recording session\n",
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--dataset.root={DATASET_ROOT}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
" \"--display_data=true\",\n",
" \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 4. Train Policy\n",
"\n",
"See the [training docs](https://huggingface.co/docs/lerobot/il_robots#train-a-policy) for configuration options and tips."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-train\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--policy.type={POLICY_TYPE}\",\n",
" f\"--policy.device={POLICY_DEVICE}\",\n",
" f\"--policy.repo_id={POLICY_REPO_ID}\",\n",
" f\"--output_dir={OUTPUT_DIR}\",\n",
" f\"--steps={TRAIN_STEPS}\",\n",
" f\"--save_freq={SAVE_FREQ}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted training session\n",
"print_cmd(\n",
" \"lerobot-train\",\n",
" f\"--config_path={LAST_CHECKPOINT_PATH}/pretrained_model/train_config.json\",\n",
" \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 5. Inference\n",
"\n",
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
"\n",
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--policy.path={POLICY_PATH}\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={EVAL_REPO_ID}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot (3.12.3)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
+32 -63
View File
@@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -38,12 +34,11 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
@@ -54,9 +49,6 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -151,67 +143,43 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
log_say("Waiting for environment reset, press right arrow key when ready...")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -222,6 +190,7 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+13 -13
View File
@@ -65,15 +65,14 @@ def main():
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
# Build pipeline to convert phone action to EE action
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction
](
@@ -95,7 +94,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action (IK).
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
@@ -108,7 +107,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to EE observation (FK).
# Build pipeline to convert joint observation to EE observation
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -119,12 +118,13 @@ def main():
to_output=transition_to_observation,
)
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
@@ -163,14 +163,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
@@ -182,13 +182,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
-126
View File
@@ -1,126 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
with phone teleoperation in EE space, so at deployment we only need the
joint↔EE conversion on the robot side; the phone is not used.
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
(inline policy call). For recording during rollout, switch to Sentry,
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Peek at motor names once to build the kinematic solver.
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+673
View File
@@ -0,0 +1,673 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
make_default_robot_action_processor,
make_default_robot_observation_processor,
to_relative_actions,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
if config.use_peft:
from peft import PeftConfig, PeftModel
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+32 -63
View File
@@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -38,12 +34,11 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
@@ -54,9 +49,6 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -151,67 +143,43 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
log_say("Waiting for environment reset, press right arrow key when ready...")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -222,6 +190,7 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+17 -15
View File
@@ -62,20 +62,21 @@ def main():
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# Build pipeline to convert follower joints to EE observation.
# Build pipeline to convert follower joints to EE observation
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -86,7 +87,7 @@ def main():
to_output=transition_to_observation,
)
# Build pipeline to convert leader joints to EE action.
# Build pipeline to convert leader joints to EE action
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
ForwardKinematicsJointsToEE(
@@ -97,9 +98,9 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to follower joints (with safety bounds).
# Build pipeline to convert EE action to follower joints
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
@@ -114,12 +115,13 @@ def main():
to_output=transition_to_robot_action,
)
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features),
@@ -142,7 +144,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_so100_ee")
init_rerun(session_name="recording_phone")
try:
if not leader.is_connected or not follower.is_connected:
@@ -158,14 +160,14 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
@@ -177,13 +179,13 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
-134
View File
@@ -1,134 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). The custom observation/action processors convert between
joint space (robot hardware) and end-effector space (policy I/O) via
forward/inverse kinematics.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Kinematic solver: we need the motor-name list, so peek at the robot once.
# (The rollout engine owns the connected instance; we only use this for introspection.)
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
# Joint-space observation → EE-space observation (consumed by the policy).
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# EE-space action (produced by the policy) → joint-space action (sent to robot).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Policy config (full model is loaded inside build_rollout_context).
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
# otherwise skip the joint↔EE conversion and the policy would receive the
# wrong observation/action space.
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+8 -15
View File
@@ -108,9 +108,9 @@ training = [
"wandb>=0.24.0,<0.25.0",
]
hardware = [
"lerobot[pynput-dep]",
"lerobot[pyserial-dep]",
"lerobot[deepdiff-dep]",
"pynput>=1.7.8,<1.9.0",
"pyserial>=3.5,<4.0",
"deepdiff>=7.0.1,<9.0.0",
]
viz = [
"rerun-sdk>=0.24.0,<0.27.0",
@@ -136,14 +136,10 @@ scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]
@@ -151,11 +147,10 @@ robstride = ["lerobot[can-dep]"]
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "lerobot[pyzmq-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"lerobot[pyzmq-dep]",
"lerobot[pyserial-dep]",
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"onnx>=1.16.0,<2.0.0",
"meshcat>=0.3.0,<0.4.0",
@@ -201,8 +196,7 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
@@ -275,7 +269,6 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
+27 -2
View File
@@ -31,10 +31,22 @@ from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
# LIBERO-plus derives task.language by space-joining the perturbation-variant
# filename, so strip the perturbation metadata blob to recover the base prompt.
_LIBERO_PERTURBATION_TAIL_RE = re.compile(
r"(?:\s(?:view|initstate|noise|add|tb|table|light|level)(?:\s\d+)+)+$"
)
def _strip_libero_perturbation_tail(instruction: str) -> str:
return _LIBERO_PERTURBATION_TAIL_RE.sub("", instruction).strip()
def _libero_descriptions(task_suite: str) -> dict[str, str]:
from libero.libero import benchmark # type: ignore[import-untyped]
@@ -47,7 +59,10 @@ def _libero_descriptions(task_suite: str) -> dict[str, str]:
)
return {}
suite = suite_dict[task_suite]()
return {f"{task_suite}_{i}": suite.get_task(i).language for i in range(suite.n_tasks)}
return {
f"{task_suite}_{i}": _strip_libero_perturbation_tail(suite.get_task(i).language)
for i in range(suite.n_tasks)
}
def _metaworld_descriptions(task_name: str) -> dict[str, str]:
@@ -57,6 +72,14 @@ def _metaworld_descriptions(task_name: str) -> dict[str, str]:
return {f"{task_name}_0": label}
def _robomme_descriptions(task_names: str) -> dict[str, str]:
return {
f"{task_name}_0": task_name.replace("_", " ").strip()
for task_name in (task.strip() for task in task_names.split(","))
if task_name
}
def main() -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--env", required=True, help="Environment family (libero, metaworld, ...)")
@@ -66,10 +89,12 @@ def main() -> int:
descriptions: dict[str, str] = {}
try:
if args.env == "libero":
if args.env in {"libero", "libero_plus"}:
descriptions = _libero_descriptions(args.task)
elif args.env == "metaworld":
descriptions = _metaworld_descriptions(args.task)
elif args.env == "robomme":
descriptions = _robomme_descriptions(args.task)
else:
print(
f"[extract_task_descriptions] No description extractor for env '{args.env}'.",
+27
View File
@@ -0,0 +1,27 @@
---
title: LeRobot Benchmark Leaderboard
emoji: 🤖
colorFrom: yellow
colorTo: orange
sdk: gradio
sdk_version: 5.29.0
app_file: app.py
pinned: false
license: apache-2.0
short_description: Benchmark history for LeRobot policy x benchmark runs
---
# LeRobot Benchmark Leaderboard
This Space reads immutable benchmark rows from a Hugging Face dataset and shows:
- Latest result per policy and benchmark
- Historical trends over time
- Direct links to uploaded eval and config artifacts
## Configuration
Set `BENCHMARK_RESULTS_REPO` in the Space settings if you want to point the UI
at a different public dataset. The default is:
- `lerobot/benchmark-history`
+226
View File
@@ -0,0 +1,226 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
import time
from pathlib import Path
from typing import Any
import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import HfApi, hf_hub_download
RESULTS_REPO = os.environ.get("BENCHMARK_RESULTS_REPO", "lerobot/benchmark-history")
CACHE_DIR = Path("/tmp/benchmark-leaderboard-cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
CACHE_TTL_S = 300
_CACHE: dict[str, tuple[float, pd.DataFrame]] = {}
def _row_to_record(row: dict[str, Any]) -> dict[str, Any]:
overall = row.get("eval", {}).get("overall", {})
resources = row.get("resources", {})
timings = row.get("timings", {})
artifact_urls = row.get("artifact_urls", {})
return {
"created_at": row.get("created_at"),
"benchmark": row.get("benchmark"),
"policy": row.get("policy"),
"success_rate": overall.get("pc_success"),
"n_episodes": overall.get("n_episodes"),
"avg_sum_reward": overall.get("avg_sum_reward"),
"train_wall_time_s": timings.get("train_wall_time_s"),
"eval_wall_time_s": timings.get("eval_wall_time_s"),
"total_wall_time_s": timings.get("total_wall_time_s"),
"num_gpus": resources.get("num_gpus"),
"microbatch_per_gpu": resources.get("microbatch_per_gpu"),
"gradient_accumulation_steps": resources.get("gradient_accumulation_steps"),
"effective_batch_size": resources.get("effective_batch_size"),
"git_commit": row.get("git_commit"),
"row_url": artifact_urls.get("row"),
"eval_info_url": artifact_urls.get("eval_info"),
"train_config_url": artifact_urls.get("train_config"),
}
def load_rows(repo_id: str = RESULTS_REPO) -> pd.DataFrame:
cache_key = f"rows::{repo_id}"
cached = _CACHE.get(cache_key)
if cached is not None and (time.monotonic() - cached[0]) < CACHE_TTL_S:
return cached[1]
api = HfApi()
files = [path for path in api.list_repo_files(repo_id=repo_id, repo_type="dataset") if path.startswith("rows/")]
records: list[dict[str, Any]] = []
for path_in_repo in sorted(files, reverse=True):
local_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=path_in_repo, cache_dir=CACHE_DIR)
with open(local_path) as f:
row = json.load(f)
records.append(_row_to_record(row))
df = pd.DataFrame.from_records(records)
if not df.empty:
df["created_at"] = pd.to_datetime(df["created_at"], utc=True)
df = df.sort_values("created_at", ascending=False).reset_index(drop=True)
_CACHE[cache_key] = (time.monotonic(), df)
return df
def make_latest_table(df: pd.DataFrame) -> pd.DataFrame:
if df.empty:
return df
latest = (
df.sort_values("created_at", ascending=False)
.groupby(["benchmark", "policy"], as_index=False)
.first()
.sort_values(["benchmark", "success_rate"], ascending=[True, False], na_position="last")
)
return latest[
[
"benchmark",
"policy",
"success_rate",
"n_episodes",
"train_wall_time_s",
"eval_wall_time_s",
"num_gpus",
"effective_batch_size",
"git_commit",
"row_url",
"eval_info_url",
"train_config_url",
]
]
def make_history_figure(df: pd.DataFrame, benchmark: str, policy: str | None) -> Any:
filtered = df[df["benchmark"] == benchmark]
if policy and policy != "All":
filtered = filtered[filtered["policy"] == policy]
if filtered.empty:
return px.line(title="No benchmark rows found")
fig = px.line(
filtered.sort_values("created_at"),
x="created_at",
y="success_rate",
color="policy",
markers=True,
hover_data=["git_commit", "num_gpus", "train_wall_time_s", "eval_wall_time_s"],
title=f"{benchmark} success rate history",
)
fig.update_layout(yaxis_title="Success rate (%)", xaxis_title="Run time")
return fig
def make_run_markdown(df: pd.DataFrame, benchmark: str, policy: str | None) -> str:
filtered = df[df["benchmark"] == benchmark]
if policy and policy != "All":
filtered = filtered[filtered["policy"] == policy]
if filtered.empty:
return "No matching runs yet."
latest = filtered.sort_values("created_at", ascending=False).iloc[0]
row_link = latest["row_url"] if pd.notna(latest["row_url"]) else None
eval_link = latest["eval_info_url"] if pd.notna(latest["eval_info_url"]) else None
train_link = latest["train_config_url"] if pd.notna(latest["train_config_url"]) else None
lines = [
f"Latest run: `{latest['policy']}` on `{latest['benchmark']}`",
f"Success rate: `{latest['success_rate']}`",
f"GPUs: `{latest['num_gpus']}`",
f"Effective batch size: `{latest['effective_batch_size']}`",
f"Commit: `{latest['git_commit']}`",
]
if row_link:
lines.append(f"Row JSON: [open]({row_link})")
if eval_link:
lines.append(f"Eval Info: [open]({eval_link})")
if train_link:
lines.append(f"Train Config: [open]({train_link})")
return "\n\n".join(lines)
def refresh_view(benchmark: str, policy: str) -> tuple[pd.DataFrame, dict[str, Any], Any, str]:
df = load_rows()
latest_table = make_latest_table(df)
benchmark_names = sorted(df["benchmark"].dropna().unique().tolist()) if not df.empty else []
if benchmark not in benchmark_names and benchmark_names:
benchmark = benchmark_names[0]
policy_choices = ["All"]
if benchmark and not df.empty:
policy_choices.extend(sorted(df[df["benchmark"] == benchmark]["policy"].dropna().unique().tolist()))
if policy not in policy_choices:
policy = "All"
history = make_history_figure(df, benchmark, policy)
summary = make_run_markdown(df, benchmark, policy)
return latest_table, gr.update(choices=policy_choices, value=policy), history, summary
with gr.Blocks(title="LeRobot Benchmark Leaderboard") as demo:
gr.Markdown(
f"""
# LeRobot Benchmark Leaderboard
Results dataset: [`{RESULTS_REPO}`](https://huggingface.co/datasets/{RESULTS_REPO})
"""
)
with gr.Row():
benchmark_dropdown = gr.Dropdown(label="Benchmark", choices=[])
policy_dropdown = gr.Dropdown(label="Policy", choices=["All"], value="All")
refresh_button = gr.Button("Refresh")
latest_table = gr.Dataframe(label="Latest Results", interactive=False)
history_plot = gr.Plot(label="History")
latest_summary = gr.Markdown()
def _initial_state():
df = load_rows()
benchmarks = sorted(df["benchmark"].dropna().unique().tolist()) if not df.empty else []
benchmark = benchmarks[0] if benchmarks else ""
latest, policy_choices, history, summary = refresh_view(benchmark, "All")
return (
gr.update(choices=benchmarks, value=benchmark),
policy_choices,
latest,
history,
summary,
)
demo.load(
_initial_state,
outputs=[benchmark_dropdown, policy_dropdown, latest_table, history_plot, latest_summary],
)
refresh_button.click(
refresh_view,
inputs=[benchmark_dropdown, policy_dropdown],
outputs=[latest_table, policy_dropdown, history_plot, latest_summary],
)
benchmark_dropdown.change(
refresh_view,
inputs=[benchmark_dropdown, policy_dropdown],
outputs=[latest_table, policy_dropdown, history_plot, latest_summary],
)
policy_dropdown.change(
refresh_view,
inputs=[benchmark_dropdown, policy_dropdown],
outputs=[latest_table, policy_dropdown, history_plot, latest_summary],
)
if __name__ == "__main__":
demo.launch()
@@ -0,0 +1,4 @@
gradio>=5.0.0,<6.0.0
plotly>=5.18.0
pandas>=2.0.0
huggingface-hub>=1.0.0,<2.0.0
@@ -33,7 +33,7 @@ import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk.media.camera import CameraView
@@ -76,7 +76,6 @@ class Reachy2Camera(Camera):
Args:
config: The configuration settings for the camera.
"""
require_package("reachy2_sdk", extra="reachy2")
super().__init__(config)
self.config = config
@@ -19,18 +19,16 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
import logging
import time
from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any
from typing import Any
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from lerobot.utils.import_utils import _pyrealsense2_available, require_package
if TYPE_CHECKING or _pyrealsense2_available:
import pyrealsense2 as rs
else:
rs = None
try:
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
@@ -114,7 +112,7 @@ class RealSenseCamera(Camera):
Args:
config: The configuration settings for the camera.
"""
require_package("pyrealsense2", extra="intelrealsense")
super().__init__(config)
self.config = config
+9 -11
View File
@@ -28,19 +28,12 @@ import json
import logging
import time
from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any
from typing import Any
import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.import_utils import _zmq_available, require_package
if TYPE_CHECKING or _zmq_available:
import zmq
else:
zmq = None
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
@@ -81,8 +74,8 @@ class ZMQCamera(Camera):
"""
def __init__(self, config: ZMQCameraConfig):
require_package("pyzmq", extra="pyzmq-dep", import_name="zmq")
super().__init__(config)
import zmq
self.config = config
self.server_address = config.server_address
@@ -124,6 +117,8 @@ class ZMQCamera(Camera):
logger.info(f"Connecting to {self}...")
try:
import zmq
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
@@ -185,8 +180,11 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except zmq.Again as e:
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
except Exception as e:
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
# Decode JSON message
data = json.loads(message)
+4 -7
View File
@@ -28,12 +28,6 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
from deepdiff import DeepDiff
else:
DeepDiff = None
if TYPE_CHECKING:
from lerobot.datasets import LeRobotDataset
@@ -223,7 +217,10 @@ def sanity_check_dataset_robot_compatibility(
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
require_package("deepdiff", extra="deepdiff-dep")
from lerobot.utils.import_utils import require_package
require_package("deepdiff", extra="hardware")
from deepdiff import DeepDiff
from lerobot.utils.constants import DEFAULT_FEATURES
-2
View File
@@ -21,7 +21,6 @@ are intentionally NOT re-exported here to avoid circular dependencies
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
"""
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .types import (
@@ -40,7 +39,6 @@ __all__ = [
"PolicyFeature",
"RTCAttentionSchedule",
# Config classes
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"PeftConfig",
-77
View File
@@ -1,77 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str = ""
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str = ""
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+6 -3
View File
@@ -35,9 +35,6 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
def __post_init__(self) -> None:
@@ -70,11 +67,17 @@ class EvalConfig:
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
# Set to 0 for auto-tuning based on available CPU cores and n_episodes.
batch_size: int = 0
# Number of rollout videos to save per evaluated task. Set to 0 to disable videos.
max_episodes_rendered: int = 10
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
def __post_init__(self) -> None:
if self.max_episodes_rendered < 0:
raise ValueError(
f"`max_episodes_rendered` must be non-negative, got {self.max_episodes_rendered}."
)
if self.batch_size == 0:
self.batch_size = self._auto_batch_size()
if self.batch_size > self.n_episodes:
+6 -2
View File
@@ -56,8 +56,7 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
gradient_accumulation_steps: int = 1
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
@@ -134,6 +133,11 @@ class TrainPipelineConfig(HubMixin):
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if self.gradient_accumulation_steps <= 0:
raise ValueError(
f"`gradient_accumulation_steps` must be strictly positive, got {self.gradient_accumulation_steps}."
)
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume:
+10 -25
View File
@@ -16,7 +16,6 @@
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import datasets
@@ -50,7 +49,6 @@ class DatasetReader:
video_backend: str,
delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None,
return_uint8: bool = False,
):
"""Initialize the reader with metadata, filtering, and transform config.
@@ -75,7 +73,6 @@ class DatasetReader:
self._tolerance_s = tolerance_s
self._video_backend = video_backend
self._image_transforms = image_transforms
self._return_uint8 = return_uint8
self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -108,8 +105,10 @@ class DatasetReader:
"""Build absolute-to-relative index mapping from loaded hf_dataset."""
self._absolute_to_relative_idx = None
if self.episodes is not None and self.hf_dataset is not None:
indices = self.hf_dataset.data.column("index").to_numpy()
self._absolute_to_relative_idx = dict(zip(indices.tolist(), range(len(indices)), strict=True))
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
@property
def num_frames(self) -> int:
@@ -236,30 +235,16 @@ class DatasetReader:
Segmentation Fault.
"""
ep = self._meta.episodes[ep_idx]
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
item = {}
for vid_key, query_ts in query_timestamps.items():
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
return vid_key, frames.squeeze(0)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)
items = list(query_timestamps.items())
# Single camera: no threading overhead
if len(items) <= 1:
return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}
# Multi-camera: decode in parallel (video decoding releases the GIL)
with ThreadPoolExecutor(max_workers=len(items)) as pool:
futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
return dict(f.result() for f in futures)
return item
def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
+1 -1
View File
@@ -597,7 +597,7 @@ class DatasetWriter:
def cleanup_interrupted_episode(self, episode_index: int) -> None:
"""Remove temporary image directories for an interrupted episode."""
for key in self._meta.camera_keys:
for key in self._meta.video_keys:
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
).parent
-2
View File
@@ -92,7 +92,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
else:
@@ -105,7 +104,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
+2 -2
View File
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BaseException:
except Exception as e:
dataset = kwargs.get("dataset")
writer = getattr(dataset, "writer", None) if dataset else None
if writer is not None and writer.image_writer is not None:
logger.warning("Waiting for image writer to terminate...")
writer.image_writer.stop()
raise
raise e
return wrapper
-6
View File
@@ -56,7 +56,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
return_uint8: bool = False,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
@@ -203,7 +202,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
@@ -227,7 +225,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend=self._video_backend,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
return_uint8=self._return_uint8,
)
# Load actual data
@@ -291,7 +288,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend=self._video_backend,
delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms,
return_uint8=self._return_uint8,
)
return self.reader
@@ -687,7 +683,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
@@ -780,7 +775,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
+1 -7
View File
@@ -251,7 +251,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
seed: int = 42,
rng: np.random.Generator | None = None,
shuffle: bool = True,
return_uint8: bool = False,
):
"""Initialize a StreamingLeRobotDataset.
@@ -289,7 +288,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming
self.buffer_size = buffer_size
self._return_uint8 = return_uint8
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
@@ -555,11 +553,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path,
query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
)
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+2 -2
View File
@@ -71,8 +71,8 @@ class ForwardCompatibilityError(CompatibilityError):
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 50 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
+10 -22
View File
@@ -123,7 +123,6 @@ def decode_video_frames(
timestamps: list[float],
tolerance_s: float,
backend: str | None = None,
return_uint8: bool = False,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@@ -132,23 +131,19 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
Returns:
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
)
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -159,7 +154,6 @@ def decode_video_frames_torchvision(
tolerance_s: float,
backend: str = "pyav",
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
@@ -246,17 +240,14 @@ def decode_video_frames_torchvision(
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
if return_uint8:
return closest_frames
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
return closest_frames
@@ -315,7 +306,6 @@ def decode_video_frames_torchcodec(
tolerance_s: float,
log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None,
return_uint8: bool = False,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
@@ -383,16 +373,14 @@ def decode_video_frames_torchcodec(
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
if not len(timestamps) == len(closest_frames):
raise FrameTimestampError(
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
)
if return_uint8:
return closest_frames
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
return closest_frames
+11 -1
View File
@@ -18,7 +18,15 @@
# from lerobot.utils.import_utils import require_package
# require_package("gymnasium", extra="<update_extra>", import_name="gymnasium")
from .configs import AlohaEnv, EnvConfig, HILSerlRobotEnvConfig, HubEnvConfig, PushtEnv
from .configs import (
AlohaEnv,
EnvConfig,
HILSerlRobotEnvConfig,
HubEnvConfig,
LiberoPlusEnv,
PushtEnv,
RoboMMEEnv,
)
from .factory import make_env, make_env_config, make_env_pre_post_processors
from .utils import check_env_attributes_and_types, close_envs, env_to_policy_features, preprocess_observation
@@ -27,7 +35,9 @@ __all__ = [
"EnvConfig",
"HILSerlRobotEnvConfig",
"HubEnvConfig",
"LiberoPlusEnv",
"PushtEnv",
"RoboMMEEnv",
"check_env_attributes_and_types",
"close_envs",
"env_to_policy_features",
+55
View File
@@ -574,3 +574,58 @@ class IsaaclabArenaEnv(HubEnvConfig):
),
PolicyProcessorPipeline(steps=[]),
)
@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
"""Config for LIBERO-plus robustness benchmark evaluation."""
task: str = "libero_spatial"
@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
"""RoboMME memory-augmented manipulation benchmark."""
task: str = "PickXtimes"
fps: int = 10
episode_length: int = 300
action_space: str = "joint_angle"
dataset_split: str = "test"
task_ids: list[int] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
"image": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
"wrist_image": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"image": f"{OBS_IMAGES}.image",
"wrist_image": f"{OBS_IMAGES}.wrist_image",
OBS_STATE: OBS_STATE,
}
)
@property
def gym_kwargs(self) -> dict:
return {}
def create_envs(self, n_envs: int, use_async_envs: bool = True):
from .robomme import create_robomme_envs
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
return create_robomme_envs(
task=self.task,
n_envs=n_envs,
action_space_type=self.action_space,
dataset=self.dataset_split,
episode_length=self.episode_length,
task_ids=self.task_ids,
env_cls=env_cls,
)
+22 -7
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import os
import re
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
@@ -69,14 +70,28 @@ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[i
return ids
# LIBERO-plus perturbation variants encode the perturbation in the filename
# but on disk only the base `.pruned_init` exists — strip the suffix to match
# LIBERO-plus's own suite.get_task_init_states() (we reimplement it here so we
# can pass weights_only=False for PyTorch 2.6+ numpy pickles).
_LIBERO_PERTURBATION_SUFFIX_RE = re.compile(r"_(?:language|view|light)_[^.]*|_(?:table|tb)_\d+")
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
init_states_path = (
Path(get_libero_path("init_states"))
/ task_suite.tasks[i].problem_folder
/ task_suite.tasks[i].init_states_file
)
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
return init_states
task = task_suite.tasks[i]
filename = Path(task.init_states_file)
root = Path(get_libero_path("init_states"))
# `_add_` / `_level` variants store extra-object layouts under libero_newobj/
# as a flat array that must be reshaped to (1, -1).
if "_add_" in filename.name or "_level" in filename.name:
init_states_path = root / "libero_newobj" / task.problem_folder / filename.name
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
return init_states.reshape(1, -1)
stripped = _LIBERO_PERTURBATION_SUFFIX_RE.sub("", filename.stem) + filename.suffix
init_states_path = root / task.problem_folder / stripped
return torch.load(init_states_path, weights_only=False) # nosec B614
def get_libero_dummy_action():
+209
View File
@@ -0,0 +1,209 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RoboMME environment wrapper for LeRobot evaluation."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
ROBOMME_TASKS = [
"BinFill",
"PickXtimes",
"SwingXtimes",
"StopCube",
"VideoUnmask",
"VideoUnmaskSwap",
"ButtonUnmask",
"ButtonUnmaskSwap",
"PickHighlight",
"VideoRepick",
"VideoPlaceButton",
"VideoPlaceOrder",
"MoveCube",
"InsertPeg",
"PatternLock",
"RouteStick",
]
class RoboMMEGymEnv(gym.Env):
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 10}
def __init__(
self,
task: str = "PickXtimes",
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_idx: int = 0,
max_steps: int = 300,
):
super().__init__()
from robomme.env_record_wrapper import BenchmarkEnvBuilder
self._builder = BenchmarkEnvBuilder(
env_id=task,
dataset=dataset,
action_space=action_space_type,
gui_render=False,
max_steps=max_steps,
)
self._max_episode_steps = max_steps
self._episode_idx = episode_idx
self._max_steps = max_steps
self._env = None
self._last_raw_obs: dict | None = None
action_dim = 8 if action_space_type == "joint_angle" else 7
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
self.observation_space = spaces.Dict(
{
"image": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"wrist_image": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
}
)
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
self._env = self._builder.make_env_for_episode(
episode_idx=self._episode_idx,
max_steps=self._max_steps,
)
obs, info = self._env.reset()
self._last_raw_obs = obs
return self._convert_obs(obs), self._convert_info(info)
def step(self, action):
obs, reward, terminated, truncated, info = self._env.step(action)
self._last_raw_obs = obs
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
status = info.get("status", "ongoing")
conv_info = self._convert_info(info)
conv_info["is_success"] = status == "success"
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
def render(self) -> np.ndarray | None:
if self._last_raw_obs is None:
return np.zeros((256, 256, 3), dtype=np.uint8)
front = self._last_raw_obs.get("front_rgb_list")
if front is None:
return np.zeros((256, 256, 3), dtype=np.uint8)
frame = front[-1] if isinstance(front, list) else front
return np.asarray(frame, dtype=np.uint8)
def _convert_obs(self, obs: dict) -> dict:
front_rgb = (
obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
)
wrist_rgb = (
obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
)
joint_state = (
obs["joint_state_list"][-1]
if isinstance(obs["joint_state_list"], list)
else obs["joint_state_list"]
)
gripper_state = (
obs["gripper_state_list"][-1]
if isinstance(obs["gripper_state_list"], list)
else obs["gripper_state_list"]
)
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
state = np.concatenate([joint, gripper])
return {
"image": np.asarray(front_rgb, dtype=np.uint8),
"wrist_image": np.asarray(wrist_rgb, dtype=np.uint8),
"state": state,
}
def _convert_info(self, info: dict) -> dict:
return {
"status": info.get("status", "ongoing"),
"task_goal": info.get("task_goal", ""),
}
def _make_env_fns(
*,
task: str,
n_envs: int,
action_space_type: str,
dataset: str,
episode_length: int,
task_id: int,
) -> list[Callable[[], RoboMMEGymEnv]]:
def _make_one(episode_index: int) -> RoboMMEGymEnv:
return RoboMMEGymEnv(
task=task,
action_space_type=action_space_type,
dataset=dataset,
episode_idx=episode_index,
max_steps=episode_length,
)
return [partial(_make_one, task_id + i) for i in range(n_envs)]
def create_robomme_envs(
task: str,
n_envs: int = 1,
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_length: int = 300,
task_ids: list[int] | None = None,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create vectorized RoboMME environments for evaluation."""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable that wraps a list of env factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
if task_ids is None:
task_ids = [0]
task_names = [t.strip() for t in task.split(",") if t.strip()]
out: dict[str, dict[int, gym.vector.VectorEnv]] = {}
for task_name in task_names:
envs_by_task: dict[int, gym.vector.VectorEnv] = {}
for task_id in task_ids:
fns = _make_env_fns(
task=task_name,
n_envs=n_envs,
action_space_type=action_space_type,
dataset=dataset,
episode_length=episode_length,
task_id=task_id,
)
envs_by_task[task_id] = env_cls(fns)
out[task_name] = envs_by_task
return out
+7 -12
View File
@@ -12,19 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _placo_available, require_package
if TYPE_CHECKING or _placo_available:
import placo # type: ignore[import-not-found]
else:
placo = None
class RobotKinematics:
"""Robot kinematics using placo library for forward and inverse kinematics."""
@@ -43,7 +32,13 @@ class RobotKinematics:
target_frame_name (str): Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
try:
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
"Please install the optional dependencies of `kinematics` in the package."
) from e
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)
+1 -2
View File
@@ -24,7 +24,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available, require_package
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
@@ -111,7 +111,6 @@ class DamiaoMotorsBus(MotorsBusBase):
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
require_package("python-can", extra="damiao", import_name="can")
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
+2 -2
View File
@@ -356,8 +356,8 @@ class SerialMotorsBus(MotorsBusBase):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
require_package("pyserial", extra="pyserial-dep", import_name="serial")
require_package("deepdiff", extra="deepdiff-dep")
require_package("pyserial", extra="hardware", import_name="serial")
require_package("deepdiff", extra="hardware")
super().__init__(port, motors, calibration)
self.port_handler: PortHandler
+2 -3
View File
@@ -23,12 +23,12 @@ from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available, require_package
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
import can
else:
can = SimpleNamespace(Message=object, interface=None, BusABC=object)
can = SimpleNamespace(Message=object, interface=None)
import numpy as np
from lerobot.utils.errors import DeviceNotConnectedError
@@ -106,7 +106,6 @@ class RobstrideMotorsBus(MotorsBusBase):
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
require_package("python-can", extra="robstride", import_name="can")
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
+3 -7
View File
@@ -18,21 +18,14 @@ import logging
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.import_utils import _diffusers_available, require_package
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
if TYPE_CHECKING or _diffusers_available:
from diffusers.optimization import get_scheduler
else:
get_scheduler = None
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
@@ -54,7 +47,10 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="diffusion")
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
return get_scheduler(**kwargs)
+1 -2
View File
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -23,6 +21,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
@@ -23,7 +23,6 @@ TODO(alexander-soare):
import math
from collections import deque
from collections.abc import Callable
from typing import TYPE_CHECKING
import einops
import numpy as np
@@ -33,14 +32,6 @@ import torchvision
from torch import Tensor, nn
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
else:
DDIMScheduler = None
DDPMScheduler = None
from ..pretrained import PreTrainedPolicy
from ..utils import (
@@ -73,7 +64,6 @@ class DiffusionPolicy(PreTrainedPolicy):
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
require_package("diffusers", extra="diffusion")
super().__init__(config)
config.validate_features()
self.config = config
@@ -165,7 +155,11 @@ def _make_noise_scheduler(name: str, **kwargs: dict):
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="diffusion")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if name == "DDPM":
return DDPMScheduler(**kwargs)
@@ -204,9 +204,7 @@ class FlowmatchingActionHead(nn.Module):
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
@@ -251,9 +249,7 @@ class FlowmatchingActionHead(nn.Module):
self.model.eval()
def sample_time(self, batch_size, device, dtype):
if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
@@ -222,13 +222,6 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=None,
**output_kwargs["images_kwargs"],
)
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
@@ -240,13 +233,6 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
@@ -302,18 +288,8 @@ class Eagle25VLProcessor(ProcessorMixin):
text = replace_in_text(text)
if len(unified_frame_list) > 0:
def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
-1
View File
@@ -221,7 +221,6 @@ class GR00TN15(PreTrainedModel):
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
@@ -43,7 +43,6 @@ from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from .configuration_groot import GrootConfig
@@ -60,7 +59,6 @@ class GrootPolicy(PreTrainedPolicy):
def __init__(self, config: GrootConfig, **kwargs):
"""Initialize Groot policy wrapper."""
require_package("transformers", extra="groot")
super().__init__(config)
config.validate_features()
self.config = config
@@ -36,7 +36,7 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor
from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package
from lerobot.utils.import_utils import _transformers_available
from .configuration_multi_task_dit import MultiTaskDiTConfig
@@ -46,13 +46,6 @@ if TYPE_CHECKING or _transformers_available:
else:
CLIPTextModel = None
CLIPVisionModel = None
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
else:
DDIMScheduler = None
DDPMScheduler = None
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
@@ -72,8 +65,6 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
require_package("transformers", extra="multi_task_dit")
require_package("diffusers", extra="multi_task_dit")
super().__init__(config)
config.validate_features()
self.config = config
@@ -652,6 +643,12 @@ class DiffusionObjective(nn.Module):
"prediction_type": config.prediction_type,
}
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="multi_task_dit")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
+1 -2
View File
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _transformers_available, require_package
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -947,7 +947,6 @@ class PI0Policy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
+1 -2
View File
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _transformers_available, require_package
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -918,7 +918,6 @@ class PI05Policy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from lerobot.utils.import_utils import _scipy_available, _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _scipy_available:
@@ -35,7 +35,7 @@ else:
idct = None
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, AutoTokenizer
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from ..pi_gemma import (
@@ -44,7 +44,6 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
AutoProcessor = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
@@ -827,14 +826,14 @@ class PI0FastPolicy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
require_package("scipy", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
# Load tokenizers first
try:
from transformers import AutoProcessor, AutoTokenizer
# Load FAST tokenizer
self.action_tokenizer = AutoProcessor.from_pretrained(
config.action_tokenizer_name, trust_remote_code=True
+115 -3
View File
@@ -1,4 +1,116 @@
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
from lerobot.utils.action_interpolator import ActionInterpolator
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["ActionInterpolator"]
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
+10 -10
View File
@@ -92,10 +92,10 @@ class ActionQueue:
Returns:
int: Number of unconsumed actions.
"""
with self.lock:
if self.queue is None:
return 0
return len(self.queue) - self.last_index
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
@@ -103,10 +103,11 @@ class ActionQueue:
Returns:
bool: True if no actions remain, False otherwise.
"""
with self.lock:
if self.queue is None:
return True
return len(self.queue) - self.last_index <= 0
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
@@ -114,8 +115,7 @@ class ActionQueue:
Returns:
int: Index of the next action to be consumed.
"""
with self.lock:
return self.last_index
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
@@ -62,7 +62,6 @@ from torch import Tensor, nn
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.device_utils import get_safe_dtype
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..rtc.modeling_rtc import RTCProcessor
@@ -240,7 +239,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
the configuration class is used.
"""
require_package("transformers", extra="smolvla")
super().__init__(config)
config.validate_features()
self.config = config
+2 -2
View File
@@ -27,7 +27,7 @@ import torch.distributed as distributed
import torch.nn.functional as F # noqa: N812
from einops import pack, rearrange, reduce, repeat, unpack
from torch import einsum, nn
from torch.amp import autocast
from torch.cuda.amp import autocast
from torch.optim import Optimizer
from .configuration_vqbet import VQBeTConfig
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
self.replace(batch_samples, batch_mask=expired_codes)
@autocast("cuda", enabled=False)
@autocast(enabled=False)
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
+1 -1
View File
@@ -76,7 +76,6 @@ from lerobot.transport.utils import (
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
@@ -95,6 +94,7 @@ from .gym_manipulator import (
make_robot_env,
step_env_and_process_transition,
)
from .process import ProcessSignalHandler
from .queue import get_last_item_from_queue
# Main entry point
+58 -54
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import functools
import threading
from collections.abc import Callable, Sequence
from contextlib import suppress
from typing import TypedDict
@@ -116,7 +115,6 @@ class ReplayBuffer:
self.size = 0
self.initialized = False
self.optimize_memory = optimize_memory
self._lock = threading.Lock()
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -200,75 +198,68 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
with self._lock:
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
with self._lock:
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
batch_state = {}
batch_next_state = {}
# Create batched state and next_state
batch_state = {}
batch_next_state = {}
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
# First pass: load all state tensors to target device
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory:
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys:
# Concatenate all images from state and next_state
all_images = []
@@ -289,6 +280,19 @@ class ReplayBuffer:
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
+2 -2
View File
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO])
new_info = processed_action_transition[TransitionKey.INFO].copy()
new_info.update(info)
new_transition = create_transition(
observation=obs,
+1 -1
View File
@@ -90,7 +90,6 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
@@ -100,6 +99,7 @@ from lerobot.utils.utils import (
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .process import ProcessSignalHandler
@parser.wrap()
+1 -2
View File
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any
from lerobot.cameras import make_cameras_from_configs
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
from lerobot.utils.import_utils import _reachy2_sdk_available
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -81,7 +81,6 @@ class Reachy2Robot(Robot):
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
require_package("reachy2_sdk", extra="reachy2")
super().__init__(config)
self.config = config
+1 -2
View File
@@ -27,7 +27,7 @@ import numpy as np
from lerobot.cameras import make_cameras_from_configs
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import _unitree_sdk_available, require_package
from lerobot.utils.import_utils import _unitree_sdk_available
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
@@ -111,7 +111,6 @@ class UnitreeG1(Robot):
name = "unitree_g1"
def __init__(self, config: UnitreeG1Config):
require_package("unitree-sdk2py", extra="unitree_g1", import_name="unitree_sdk2py")
super().__init__(config)
logger.info("Initialize UnitreeG1...")
-82
View File
@@ -1,82 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Policy deployment engine with pluggable rollout strategies."""
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from .configs import (
BaseStrategyConfig,
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
DatasetRecordConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
SentryStrategyConfig,
)
from .context import (
DatasetContext,
HardwareContext,
PolicyContext,
ProcessorContext,
RolloutContext,
RuntimeContext,
build_rollout_context,
)
from .inference import (
InferenceEngine,
InferenceEngineConfig,
RTCInferenceConfig,
RTCInferenceEngine,
SyncInferenceConfig,
SyncInferenceEngine,
create_inference_engine,
)
from .ring_buffer import RolloutRingBuffer
from .robot_wrapper import ThreadSafeRobot
from .strategies import RolloutStrategy, create_strategy
__all__ = [
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategyConfig",
"DatasetContext",
"DatasetRecordConfig",
"HardwareContext",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
"ProcessorContext",
"RTCInferenceConfig",
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context",
"create_inference_engine",
"create_strategy",
]
-270
View File
@@ -1,270 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration dataclasses for the rollout deployment engine."""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
import draccus
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from .inference import InferenceEngineConfig, SyncInferenceConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry)
# ---------------------------------------------------------------------------
@dataclass
class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for rollout strategy configurations.
Use ``--strategy.type=<name>`` on the CLI to select a strategy.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@RolloutStrategyConfig.register_subclass("base")
@dataclass
class BaseStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with no data recording."""
pass
@RolloutStrategyConfig.register_subclass("sentry")
@dataclass
class SentryStrategyConfig(RolloutStrategyConfig):
"""Continuous autonomous rollout with always-on recording.
Episode duration is derived from camera resolution, FPS, and
``target_video_file_size_mb`` so that each saved episode produces a
video file that has crossed the target size. This aligns episode
boundaries with the dataset's video file chunking, so each
``push_to_hub`` call uploads complete video files rather than
re-uploading a growing file that hasn't crossed the chunk boundary.
"""
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation. Episodes are
# saved once the estimated video duration would exceed this limit.
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
target_video_file_size_mb: float | None = None
@RolloutStrategyConfig.register_subclass("highlight")
@dataclass
class HighlightStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with on-demand recording via ring buffer.
A memory-bounded ring buffer continuously captures telemetry. When
the user presses the save key, the buffer contents are flushed to
the dataset and live recording continues until the key is pressed
again.
"""
ring_buffer_seconds: float = 30.0
ring_buffer_max_memory_mb: float = 2048.0
save_key: str = "s"
push_key: str = "h"
@dataclass
class DAggerKeyboardConfig:
"""Keyboard key bindings for DAgger controls.
Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
special key names (``"space"``).
"""
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
@dataclass
class DAggerPedalConfig:
"""Foot pedal configuration for DAgger controls.
Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).
"""
device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
pause_resume: str = "KEY_A"
correction: str = "KEY_B"
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
"""Human-in-the-loop data collection (DAgger / RaC).
Alternates between autonomous policy execution and human intervention.
Intervention frames are tagged with ``intervention=True``.
Input is controlled via either a keyboard or foot pedal, selected by
``input_device``. Each device exposes three actions:
1. **pause_resume** toggle policy execution on/off.
2. **correction** toggle human correction recording.
3. **upload** push dataset to hub on demand (corrections-only mode).
When ``record_autonomous=True`` (default) both autonomous and correction
frames are recorded with size-based episode rotation (same as Sentry)
and background uploading. ``push_to_hub`` is blocked while a correction
is in progress. Set to ``False`` to record only the human-correction
windows, where each correction becomes its own episode.
"""
num_episodes: int = 10
record_autonomous: bool = False
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation (record_autonomous
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
target_video_file_size_mb: float | None = None
input_device: str = "keyboard"
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
def __post_init__(self):
if self.input_device not in ("keyboard", "pedal"):
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
# ---------------------------------------------------------------------------
# Top-level rollout config
# ---------------------------------------------------------------------------
@dataclass
class RolloutConfig:
"""Top-level configuration for the ``lerobot-rollout`` CLI.
Combines hardware, policy, strategy, and runtime settings. The
``__post_init__`` method performs fail-fast validation to reject
invalid flag combinations early.
"""
# Hardware
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
# Policy (loaded from --policy.path via __post_init__)
policy: PreTrainedConfig | None = None
# Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger)
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
# Inference backend (polymorphic: --inference.type=sync|rtc)
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
# Dataset (required for sentry, highlight, dagger; None for base)
dataset: DatasetRecordConfig | None = None
# Runtime
fps: float = 30.0
duration: float = 0.0 # 0 = infinite (24/7 mode)
interpolation_multiplier: int = 1
device: str | None = None
task: str = ""
display_data: bool = False
# Display data on a remote Rerun server
display_ip: str | None = None
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
resume: bool = False
# Torch compile
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
torch_compile_mode: str = "default"
compile_warmup_inferences: int = 2
def __post_init__(self):
"""Validate config invariants and load the policy config from ``--policy.path``."""
# --- Strategy-specific validation ---
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set")
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
raise ValueError(
"Base strategy does not record data. Use sentry, highlight, or dagger for recording."
)
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
if (
isinstance(self.strategy, SentryStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Sentry mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# Highlight writes frames while the policy is still running, so streaming is mandatory.
if (
isinstance(self.strategy, HighlightStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Highlight mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
if (
isinstance(self.strategy, DAggerStrategyConfig)
and self.strategy.record_autonomous
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# --- Policy loading ---
if self.robot is None:
raise ValueError("--robot.type is required for rollout")
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
-429
View File
@@ -1,429 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout context: shared state created once before strategy dispatch.
Grouped into five topical sub-contexts :class:`RuntimeContext`,
:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`,
and :class:`DatasetContext` assembled into :class:`RolloutContext`.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from threading import Event
import torch
from lerobot.configs import FeatureType, PreTrainedConfig
from lerobot.datasets import (
LeRobotDataset,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_processors,
rename_stats,
)
from lerobot.robots import make_robot_from_config
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .inference import (
InferenceEngine,
RTCInferenceConfig,
create_inference_engine,
)
from .robot_wrapper import ThreadSafeRobot
logger = logging.getLogger(__name__)
def _resolve_action_key_order(
policy_action_names: list[str] | None, dataset_action_names: list[str]
) -> list[str]:
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
if not policy_action_names:
return dataset_action_names
policy_action_names = list(policy_action_names)
if len(policy_action_names) != len(dataset_action_names):
logger.warning(
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
len(policy_action_names),
len(dataset_action_names),
)
return dataset_action_names
if set(dataset_action_names) != set(policy_action_names):
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
return dataset_action_names
return policy_action_names
# ---------------------------------------------------------------------------
# Sub-contexts
# ---------------------------------------------------------------------------
@dataclass
class RuntimeContext:
"""Runtime knobs shared with every strategy."""
cfg: RolloutConfig
shutdown_event: Event
@dataclass
class HardwareContext:
"""Connected hardware.
The raw robot is available via ``robot_wrapper.inner`` when needed
(e.g. for disconnect); strategies should otherwise go through the
thread-safe wrapper.
``initial_position`` stores the robot's joint positions at connect
time. Strategies use it to return the robot to a safe pose before
shutting down.
"""
robot_wrapper: ThreadSafeRobot
teleop: Teleoperator | None
initial_position: dict | None = None
@dataclass
class PolicyContext:
"""Loaded policy and its inference engine."""
policy: PreTrainedPolicy
preprocessor: PolicyProcessorPipeline
postprocessor: PolicyProcessorPipeline
inference: InferenceEngine
@dataclass
class ProcessorContext:
"""Robot-side pipelines (run outside the policy)."""
teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation]
@dataclass
class DatasetContext:
"""Dataset and feature bookkeeping."""
dataset: LeRobotDataset | None
dataset_features: dict = field(default_factory=dict)
hw_features: dict = field(default_factory=dict)
ordered_action_keys: list[str] = field(default_factory=list)
@dataclass
class RolloutContext:
"""Bundle of sub-contexts passed to every rollout strategy.
Built once by :func:`build_rollout_context` before strategy dispatch.
"""
runtime: RuntimeContext
hardware: HardwareContext
policy: PolicyContext
processors: ProcessorContext
data: DatasetContext
# ---------------------------------------------------------------------------
# Build
# ---------------------------------------------------------------------------
def build_rollout_context(
cfg: RolloutConfig,
shutdown_event: Event,
teleop_action_processor: RobotProcessorPipeline | None = None,
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
) -> RolloutContext:
"""Wire up policy, processors, hardware, dataset, and inference engine.
The order is policy-first / hardware-last so a bad ``--policy.path``
fails fast without touching the robot.
"""
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
policy_config = cfg.policy
policy_class = get_policy_class(policy_config.type)
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
for attr in ("device", "use_amp"):
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
cli_val = getattr(cfg.policy, attr)
if cli_val is not None:
setattr(full_config, attr, cli_val)
if hasattr(full_config, "compile_model"):
full_config.compile_model = cfg.use_torch_compile
if full_config.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
if full_config.use_peft:
from peft import PeftConfig, PeftModel
peft_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
)
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
if is_rtc:
policy.config.rtc_config = cfg.inference.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
try:
if hasattr(torch, "compile"):
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
"options": {"triton.cudagraphs": False},
}
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
logger.info("torch.compile applied to predict_action_chunk")
except Exception as e:
logger.warning("Failed to apply torch.compile: %s", e)
# --- 2. Robot-side processors (user-supplied or defaults) --------
if (
teleop_action_processor is None
or robot_action_processor is None
or robot_observation_processor is None
):
_t, _r, _o = make_default_processors()
teleop_action_processor = teleop_action_processor or _t
robot_action_processor = robot_action_processor or _r
robot_observation_processor = robot_observation_processor or _o
# --- 3. Hardware (heaviest side-effect, deferred) -----------------
logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
robot = make_robot_from_config(cfg.robot)
robot.connect()
logger.info("Robot connected: %s", robot.name)
# Store the initial joint positions so we can return to a safe pose on shutdown.
initial_obs = robot.get_observation()
initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
logger.info("Captured initial robot position (%d keys)", len(initial_position))
robot_wrapper = ThreadSafeRobot(robot)
teleop = None
if cfg.teleop is not None:
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
teleop = make_teleoperator_from_config(cfg.teleop)
teleop.connect()
logger.info("Teleoperator connected")
# DAgger requires teleop with motor control capabilities (enable_torque,
# disable_torque, write_goal_positions).
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
# if missing:
# teleop.disconnect()
# raise ValueError(
# f"DAgger strategy requires a teleoperator with motor control methods "
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
# )
# --- 4. Features + action-key reconciliation ---------------------
all_obs_features = robot.observation_features
observation_features_hw = {
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
}
action_features_hw = robot.action_features
# The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
action_dataset_features = aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
# Observation-side aggregation is needed because of build_dataset_frame
observation_dataset_features = aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=observation_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
raw_action_keys,
)
# Validate visual features if no rename_map is active
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
if not rename_map:
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {
f"observation.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
}
policy_subset = expected_visuals.issubset(provided_visuals)
hw_subset = provided_visuals.issubset(expected_visuals)
if not (policy_subset or hw_subset):
raise ValueError(
f"Visual feature mismatch between policy and robot hardware.\n"
f"Policy expects: {expected_visuals}\n"
f"Robot provides: {provided_visuals}"
)
# --- 5. Dataset -------------
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
)
else:
if isinstance(cfg.strategy, DAggerStrategyConfig):
dataset_features["intervention"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
)
if dataset is not None:
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
dataset_stats = None
if dataset is not None:
dataset_stats = rename_stats(
dataset.meta.stats,
cfg.dataset.rename_map if cfg.dataset else {},
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy_config,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats,
preprocessor_overrides={
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
},
)
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
logger.info(
"Creating inference engine (type=%s)...",
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
inference_strategy = create_inference_engine(
cfg.inference,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
hw_features=hw_features,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task_str,
fps=cfg.fps,
device=cfg.device,
use_torch_compile=cfg.use_torch_compile,
compile_warmup_inferences=cfg.compile_warmup_inferences,
shutdown_event=shutdown_event,
)
# --- 8. Assemble ---------------------------------------------------
logger.info("Rollout context assembled successfully")
return RolloutContext(
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
hardware=HardwareContext(
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
),
policy=PolicyContext(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
inference=inference_strategy,
),
processors=ProcessorContext(
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
),
data=DatasetContext(
dataset=dataset,
dataset_features=dataset_features,
hw_features=hw_features,
ordered_action_keys=ordered_action_keys,
),
)
-39
View File
@@ -1,39 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine package — backend-agnostic action production.
Concrete strategies (sync, RTC, ) expose the same small interface so
rollout strategies never branch on the inference backend.
"""
from .base import InferenceEngine
from .factory import (
InferenceEngineConfig,
RTCInferenceConfig,
SyncInferenceConfig,
create_inference_engine,
)
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
__all__ = [
"InferenceEngine",
"InferenceEngineConfig",
"RTCInferenceConfig",
"RTCInferenceEngine",
"SyncInferenceConfig",
"SyncInferenceEngine",
"create_inference_engine",
]
-88
View File
@@ -1,88 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine ABC.
Rollout strategies consume actions through this small interface so they
do not need to know whether the inference engine is synchronous, runs in
a background thread (RTC), or comes from an external source.
"""
from __future__ import annotations
import abc
import torch
class InferenceEngine(abc.ABC):
"""Abstract backend for producing actions during rollout.
Subclasses decide whether inference happens inline, in a background
thread, or externally. The contract is minimal so new backends can
be added without touching rollout strategies.
Lifecycle
---------
``start`` prepare the backend (e.g. launch a background thread).
``stop`` shut the backend down cleanly.
``reset`` clear episode-scoped state (policy hidden state, queues).
Action production
-----------------
``get_action(obs_frame)`` return the next action tensor, or
``None`` if none is available (e.g. async queue empty). Sync
backends always compute from ``obs_frame``; async backends may
ignore it (they get observations via ``notify_observation``).
Optional hooks
--------------
``notify_observation`` / ``pause`` / ``resume`` have a no-op default
so rollout strategies can invoke them unconditionally.
"""
@abc.abstractmethod
def start(self) -> None:
"""Initialise the backend."""
@abc.abstractmethod
def stop(self) -> None:
"""Tear the backend down."""
@abc.abstractmethod
def reset(self) -> None:
"""Clear episode-scoped state."""
@abc.abstractmethod
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Return the next action tensor, or ``None`` if unavailable."""
def notify_observation(self, obs: dict) -> None: # noqa: B027
"""Publish the latest processed observation. Default: no-op."""
def pause(self) -> None: # noqa: B027
"""Pause background inference. Default: no-op."""
def resume(self) -> None: # noqa: B027
"""Resume background inference. Default: no-op."""
@property
def ready(self) -> bool:
"""True once the backend can produce actions (e.g. warmup done)."""
return True
@property
def failed(self) -> bool:
"""True if an unrecoverable error occurred in the backend."""
return False
-129
View File
@@ -1,129 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine configs and factory.
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
backend requires registering its config subclass and dispatching it in
:func:`create_inference_engine`.
"""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
from threading import Event
import draccus
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------
@dataclass
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for inference backend configuration.
Use ``--inference.type=<name>`` on the CLI to select a backend.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@InferenceEngineConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceEngineConfig):
"""Inline synchronous inference (one policy call per control tick)."""
@InferenceEngineConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceEngineConfig):
"""Real-Time Chunking: async policy inference in a background thread."""
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
# constructing one here costs nothing and keeps draccus' CLI surface flat
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
rtc: RTCConfig = field(default_factory=RTCConfig)
queue_threshold: int = 30
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_inference_engine(
config: InferenceEngineConfig,
*,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
hw_features: dict,
dataset_features: dict,
ordered_action_keys: list[str],
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
shutdown_event: Event | None = None,
) -> InferenceEngine:
"""Instantiate the appropriate inference engine from a config object."""
logger.info("Creating inference engine: %s", config.type)
if isinstance(config, SyncInferenceConfig):
return SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task,
device=device,
robot_type=robot_wrapper.robot_type,
)
if isinstance(config, RTCInferenceConfig):
return RTCInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
rtc_config=config.rtc,
hw_features=hw_features,
task=task,
fps=fps,
device=device,
use_torch_compile=use_torch_compile,
compile_warmup_inferences=compile_warmup_inferences,
rtc_queue_threshold=config.queue_threshold,
shutdown_event=shutdown_event,
)
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
-391
View File
@@ -1,391 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking inference engine.
A background thread produces action chunks asynchronously via
:meth:`policy.predict_action_chunk`. The main control loop polls
``get_action`` for the next ready action; observations flow the other
way via ``notify_observation``.
"""
from __future__ import annotations
import logging
import math
import time
import traceback
from threading import Event, Lock, Thread
from typing import Any
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionQueue, LatencyTracker
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import (
NormalizerProcessorStep,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
to_relative_actions,
)
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
logger = logging.getLogger(__name__)
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Backoff between transient inference errors (per consecutive failure).
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive transient errors tolerated before giving up and propagating shutdown.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0
# ---------------------------------------------------------------------------
# RTC helpers
# ---------------------------------------------------------------------------
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: torch.Tensor,
current_state: torch.Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> torch.Tensor:
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
is stored in absolute coordinates. Before feeding it back to the policy, this
helper re-expresses those actions relative to the robot's current joint state
and optionally normalizes them so the policy receives correctly scaled inputs.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
if prev_actions.ndim != 2:
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
steps, action_dim = prev_actions.shape
if steps == target_steps:
return prev_actions
if steps > target_steps:
return prev_actions[:target_steps]
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
padded[:steps] = prev_actions
return padded
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
class RTCInferenceEngine(InferenceEngine):
"""Async RTC inference: a background thread produces action chunks.
``get_action`` pops the next action from the shared queue (or
returns ``None`` if the queue is empty). The main loop should call
``notify_observation`` every tick and ``pause``/``resume`` around
human-intervention phases.
"""
def __init__(
self,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
rtc_config: RTCConfig,
hw_features: dict,
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
rtc_queue_threshold: int = 30,
shutdown_event: Event | None = None,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
self._postprocessor = postprocessor
self._robot = robot_wrapper
self._rtc_config = rtc_config
self._hw_features = hw_features
self._task = task
self._fps = fps
self._device = device or "cpu"
self._use_torch_compile = use_torch_compile
self._compile_warmup_inferences = compile_warmup_inferences
self._rtc_queue_threshold = rtc_queue_threshold
self._action_queue: ActionQueue | None = None
self._obs_holder: dict[str, Any] = {}
self._obs_lock = Lock()
self._policy_active = Event()
self._compile_warmup_done = Event()
self._shutdown_event = Event()
self._rtc_error = Event()
self._global_shutdown_event = shutdown_event
self._rtc_thread: Thread | None = None
if not self._use_torch_compile:
self._compile_warmup_done.set()
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
else:
logger.info(
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
compile_warmup_inferences,
)
# Processor introspection for relative-action re-anchoring.
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
self._normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if self._relative_step is not None:
if self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
if cfg_names:
self._relative_step.action_names = list(cfg_names)
else:
self._relative_step.action_names = [
k for k in robot_wrapper.action_features if k.endswith(".pos")
]
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
@property
def ready(self) -> bool:
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
return self._compile_warmup_done.is_set()
@property
def failed(self) -> bool:
"""True if the RTC background thread exited due to an unrecoverable error."""
return self._rtc_error.is_set()
@property
def action_queue(self) -> ActionQueue | None:
"""The shared action queue between the RTC thread and the main loop."""
return self._action_queue
def start(self) -> None:
"""Launch the RTC background thread."""
self._action_queue = ActionQueue(self._rtc_config)
self._obs_holder = {
"obs": None,
"robot_type": self._robot.robot_type,
}
self._shutdown_event.clear()
self._rtc_thread = Thread(
target=self._rtc_loop,
daemon=True,
name="RTCInference",
)
self._rtc_thread.start()
logger.info("RTC inference thread started")
def stop(self) -> None:
"""Signal the RTC thread to stop and wait for it."""
logger.info("Stopping RTC inference thread...")
self._shutdown_event.set()
self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
if self._rtc_thread.is_alive():
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
else:
logger.info("RTC inference thread stopped")
self._rtc_thread = None
def pause(self) -> None:
"""Pause the RTC background thread."""
logger.info("Pausing RTC inference thread")
self._policy_active.clear()
def resume(self) -> None:
"""Resume the RTC background thread."""
logger.info("Resuming RTC inference thread")
self._policy_active.set()
def reset(self) -> None:
"""Reset the policy, processors, and action queue."""
logger.info("Resetting RTC inference state (policy + processors + queue)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
if self._action_queue is not None:
self._action_queue.clear()
# ------------------------------------------------------------------
# Action production (called from main thread)
# ------------------------------------------------------------------
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
if self._action_queue is None:
return None
return self._action_queue.get()
def notify_observation(self, obs: dict) -> None:
"""Publish the latest observation for the RTC thread to consume."""
with self._obs_lock:
self._obs_holder["obs"] = obs
# ------------------------------------------------------------------
# RTC: background inference thread
# ------------------------------------------------------------------
def _rtc_loop(self) -> None:
"""Background thread that generates action chunks via RTC."""
try:
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / self._fps
policy_device = torch.device(self._device)
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0
consecutive_errors = 0
while not self._shutdown_event.is_set():
if not self._policy_active.is_set():
time.sleep(_RTC_IDLE_SLEEP_S)
continue
queue = self._action_queue
with self._obs_lock:
obs = self._obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(_RTC_IDLE_SLEEP_S)
continue
if queue.qsize() <= self._rtc_queue_threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
obs_batch = prepare_observation_for_inference(
obs_batch, policy_device, self._task, self._robot.robot_type
)
obs_batch["task"] = [self._task]
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_abs,
current_state=state_tensor,
relative_step=self._relative_step,
normalizer_step=self._normalizer_step,
policy_device=policy_device,
)
if prev_actions is not None:
prev_actions = _normalize_prev_actions_length(
prev_actions, target_steps=self._rtc_config.execution_horizon
)
actions = self._policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = self._postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
inference_count += 1
consecutive_errors = 0
is_warmup = self._use_torch_compile and inference_count <= warmup_required
if is_warmup:
latency_tracker.reset()
else:
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
if (
is_warmup
and inference_count >= warmup_required
and not self._compile_warmup_done.is_set()
):
self._compile_warmup_done.set()
logger.info("Compile warmup complete (%d inferences)", inference_count)
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
except Exception as e:
consecutive_errors += 1
logger.error(
"RTC inference error (%d/%d): %s",
consecutive_errors,
_RTC_MAX_CONSECUTIVE_ERRORS,
e,
)
logger.debug(traceback.format_exc())
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
# Persistent failure: stop retrying and propagate shutdown.
raise
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
else:
time.sleep(_RTC_IDLE_SLEEP_S)
except Exception as e:
logger.error("Fatal error in RTC thread: %s", e)
logger.error(traceback.format_exc())
self._rtc_error.set()
# Unblock any warmup waiters so the main loop doesn't spin forever
self._compile_warmup_done.set()
# Signal the top-level shutdown so strategies exit their control loops
if self._global_shutdown_event is not None:
self._global_shutdown_event.set()
-107
View File
@@ -1,107 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synchronous inference engine: inline policy call per control tick."""
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from .base import InferenceEngine
logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
``select_action``) on the given observation frame and returns a
CPU action tensor reordered to match the dataset action keys.
"""
def __init__(
self,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset_features: dict,
ordered_action_keys: list[str],
task: str,
device: str | None,
robot_type: str,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
self._postprocessor = postprocessor
self._dataset_features = dataset_features
self._ordered_action_keys = ordered_action_keys
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
len(ordered_action_keys),
)
def start(self) -> None:
"""No background resources to start."""
logger.info("SyncInferenceEngine started (inline mode — no background thread)")
def stop(self) -> None:
"""No background resources to stop."""
logger.info("SyncInferenceEngine stopped")
def reset(self) -> None:
"""Reset the policy and pre/post-processors."""
logger.info("Resetting sync inference state (policy + processors)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
action = self._postprocessor(action)
action_tensor = action.squeeze(0).cpu()
# Reorder to match dataset action ordering so the caller can treat
# the returned tensor uniformly across backends.
action_dict = make_robot_action(action_tensor, self._dataset_features)
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
-112
View File
@@ -1,112 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Memory-bounded ring buffer for the Highlight Reel rollout strategy."""
from __future__ import annotations
from collections import deque
import numpy as np
import torch
class RolloutRingBuffer:
"""Fixed-capacity circular buffer for observation/action frames.
Stores the last *N* seconds of telemetry in memory, bounded by both
time (``max_frames``) and memory (``max_memory_bytes``). When either
limit is reached the oldest frames are evicted.
.. note::
This class is **single-threaded**. ``append``/``drain``/``clear``
must all be called from the same thread (the rollout main loop).
Concurrent access from a background thread will corrupt
``_current_bytes`` accounting.
Parameters
----------
max_seconds:
Maximum duration of buffered telemetry.
max_memory_mb:
Hard memory cap in MiB. Frames are evicted when the estimated
total size exceeds this.
fps:
Frames per second used to convert ``max_seconds`` to a frame
count.
"""
def __init__(self, max_seconds: float = 30.0, max_memory_mb: float = 2048.0, fps: float = 30.0) -> None:
self._max_frames = int(max_seconds * fps)
self._max_bytes = int(max_memory_mb * 1024 * 1024)
self._buffer: deque[dict] = deque(maxlen=self._max_frames)
self._current_bytes: int = 0
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def append(self, frame: dict) -> None:
"""Add *frame* to the buffer, evicting the oldest if at capacity."""
frame_bytes = _estimate_frame_bytes(frame)
# Evict oldest frames until we are under the memory cap
while self._current_bytes + frame_bytes > self._max_bytes and self._buffer:
evicted = self._buffer.popleft()
self._current_bytes -= _estimate_frame_bytes(evicted)
self._buffer.append(frame)
self._current_bytes += frame_bytes
def drain(self) -> list[dict]:
"""Return all buffered frames and clear the buffer."""
frames = list(self._buffer)
self._buffer.clear()
self._current_bytes = 0
return frames
def clear(self) -> None:
"""Discard all buffered frames."""
self._buffer.clear()
self._current_bytes = 0
def __len__(self) -> int:
return len(self._buffer)
@property
def estimated_bytes(self) -> int:
"""Estimated total byte size of all buffered frames."""
return self._current_bytes
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _estimate_frame_bytes(frame: dict) -> int:
"""Rough byte estimate for a single frame dictionary."""
total = 0
for v in frame.values():
if isinstance(v, torch.Tensor):
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
# memory cap is honoured even when frames hold unconverted tensors.
total += v.nelement() * v.element_size()
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
total += v.nbytes
elif isinstance(v, (int, float)):
total += 8
elif isinstance(v, (str, bytes)):
total += len(v)
return max(total, 1) # avoid zero-size frames
-79
View File
@@ -1,79 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thread-safe robot wrapper for concurrent observation/action access."""
from __future__ import annotations
from threading import Lock
from typing import Any
from lerobot.robots import Robot
class ThreadSafeRobot:
"""Lock-protected wrapper around a :class:`Robot` for use with background threads.
When RTC inference runs in a background thread while the main loop
executes actions, both threads may access the robot concurrently.
This wrapper serialises ``get_observation`` and ``send_action`` calls.
Read-only properties are proxied without the lock since they don't
mutate hardware state.
"""
def __init__(self, robot: Robot) -> None:
self._robot = robot
self._lock = Lock()
# -- Lock-protected I/O --------------------------------------------------
def get_observation(self) -> dict[str, Any]:
with self._lock:
return self._robot.get_observation()
def send_action(self, action: dict[str, Any] | Any) -> Any:
with self._lock:
return self._robot.send_action(action)
# -- Read-only proxies (no lock needed) -----------------------------------
@property
def observation_features(self) -> dict:
return self._robot.observation_features
@property
def action_features(self) -> dict:
return self._robot.action_features
@property
def name(self) -> str:
return self._robot.name
@property
def robot_type(self) -> str:
return self._robot.robot_type
@property
def cameras(self):
return getattr(self._robot, "cameras", {})
@property
def is_connected(self) -> bool:
return self._robot.is_connected
@property
def inner(self) -> Robot:
"""Access the underlying robot (e.g. for connect/disconnect)."""
return self._robot
@@ -1,36 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategies — public API re-exports."""
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
__all__ = [
"BaseStrategy",
"DAggerEvents",
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
"estimate_max_episode_seconds",
"safe_push_to_hub",
"send_next_action",
]
-79
View File
@@ -1,79 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base rollout strategy: autonomous policy execution with no data recording."""
from __future__ import annotations
import logging
import time
from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action
logger = logging.getLogger(__name__)
class BaseStrategy(RolloutStrategy):
"""Autonomous policy rollout with no data recording.
All actions flow through the ``robot_action_processor`` pipeline
before reaching the robot.
"""
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine."""
self._init_engine(ctx)
logger.info("Base strategy ready")
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous control loop until shutdown or duration expires."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(cfg.fps)
start_time = time.perf_counter()
engine.resume()
logger.info("Base strategy control loop started")
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference."""
self._teardown_hardware(ctx.hardware)
logger.info("Base strategy teardown complete")
-272
View File
@@ -1,272 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC and shared action-dispatch helper."""
from __future__ import annotations
import abc
import logging
import time
from typing import TYPE_CHECKING
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data
from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, RolloutContext, RuntimeContext
logger = logging.getLogger(__name__)
class RolloutStrategy(abc.ABC):
"""Abstract base for rollout execution strategies.
Each concrete strategy implements a self-contained control loop with
its own recording/interaction semantics. Strategies are mutually
exclusive only one runs per session.
"""
def __init__(self, config: RolloutStrategyConfig) -> None:
self.config = config
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._warmup_flushed: bool = False
def _init_engine(self, ctx: RolloutContext) -> None:
"""Attach the inference engine and action interpolator, then start the backend.
Creates an :class:`ActionInterpolator` from the config's
``interpolation_multiplier`` and starts the inference engine.
Call this from ``setup()`` so strategies share identical
initialisation without duplicating code.
"""
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
self._engine = ctx.policy.inference
logger.info("Starting inference engine...")
self._engine.start()
self._warmup_flushed = False
logger.info("Inference engine started")
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
"""Handle torch.compile warmup phase.
Returns ``True`` if the caller should ``continue`` (still warming
up). On the first post-warmup iteration the engine and
interpolator are reset so stale warmup state is discarded.
"""
engine = self._engine
interpolator = self._interpolator
if not use_torch_compile:
return False
if not engine.ready:
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
return True
if not self._warmup_flushed:
logger.info("Warmup complete — flushing stale state and resuming engine")
engine.reset()
interpolator.reset()
self._warmup_flushed = True
engine.resume()
return False
def _teardown_hardware(self, hw: HardwareContext) -> None:
"""Stop the inference engine, return robot to initial position, and disconnect hardware."""
if self._engine is not None:
logger.info("Stopping inference engine...")
self._engine.stop()
robot = hw.robot_wrapper.inner
if robot.is_connected:
if hw.initial_position:
logger.info("Returning robot to initial position before shutdown...")
self._return_to_initial_position(hw)
logger.info("Disconnecting robot...")
robot.disconnect()
teleop = hw.teleop
if teleop is not None and teleop.is_connected:
logger.info("Disconnecting teleoperator...")
teleop.disconnect()
@staticmethod
def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
"""Smoothly interpolate the robot back to its initial position."""
robot = hw.robot_wrapper
target = hw.initial_position
try:
current_obs = robot.get_observation()
current_pos = {k: v for k, v in current_obs.items() if k in target}
steps = max(int(duration_s * fps), 1)
for step in range(1, steps + 1):
t = step / steps
interp = {}
for k in current_pos:
interp[k] = current_pos[k] * (1 - t) + target[k] * t
robot.send_action(interp)
precise_sleep(1 / fps)
except Exception as e:
logger.warning("Could not return to initial position: %s", e)
@staticmethod
def _log_telemetry(
obs_processed: dict | None,
action_dict: dict | None,
runtime_ctx: RuntimeContext,
) -> None:
"""Log observation/action telemetry to Rerun if display_data is enabled."""
cfg = runtime_ctx.cfg
if not cfg.display_data:
return
log_rerun_data(
observation=obs_processed,
action=action_dict,
compress_images=cfg.display_compressed_images,
)
@abc.abstractmethod
def setup(self, ctx: RolloutContext) -> None:
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
@abc.abstractmethod
def run(self, ctx: RolloutContext) -> None:
"""Main rollout loop. Returns when shutdown is requested or duration expires."""
@abc.abstractmethod
def teardown(self, ctx: RolloutContext) -> None:
"""Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
"""Push dataset to hub, skipping if no episodes have been saved.
Returns ``True`` if the push was attempted, ``False`` if skipped.
"""
if dataset.num_episodes == 0:
logger.warning("No episodes saved — skipping push to hub")
return False
dataset.push_to_hub(tags=tags, private=private)
return True
def estimate_max_episode_seconds(
dataset_features: dict,
fps: float,
target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
) -> float:
"""Conservatively estimate how many seconds of video will exceed *target_size_mb*.
Each camera produces its own video file, so the episode duration is
driven by the **slowest** camera to fill ``target_size_mb`` i.e.
the one with the fewest pixels per frame (lowest bitrate).
Uses a deliberately **low** bits-per-pixel estimate so the computed
duration is *longer* than reality. By the time the timer fires the
actual video file is guaranteed to have crossed the target size,
which aligns episode boundaries with the dataset's video-file
chunking each ``push_to_hub`` uploads complete files rather than
re-uploading a still-growing one.
The estimate ignores codec-specific settings (CRF, preset) on purpose:
we only need a rough lower bound on bitrate, not a precise prediction.
Falls back to 600 s (10 min) when no video features are present.
"""
# 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
# robot footage (real-world is typically 0.1 0.3 bpp). Under-
# estimating the bitrate over-estimates the time → the episode will be
# *larger* than target_size_mb when we save, which is what we want.
conservative_bpp = 0.1
# Collect per-camera pixel counts — each camera has its own video file.
camera_pixels = []
for feat in dataset_features.values():
if feat.get("dtype") == "video":
shape = feat.get("shape", ())
# Assuming shape could be (C, H, W) or (T, C, H, W)
# We want to extract the spatial dimensions.
if len(shape) >= 3:
h, w = shape[-2], shape[-1]
pixels = h * w
if pixels > 0:
camera_pixels.append(pixels)
if not camera_pixels:
return 600.0
# Use the smallest camera: it produces the lowest bitrate and therefore
# takes the longest to reach the target — the conservative choice.
min_pixels = min(camera_pixels)
bits_per_frame = min_pixels * conservative_bpp
bytes_per_second = (bits_per_frame * fps) / 8
# Guard against division by zero just in case
if bytes_per_second <= 0:
return 600.0
return (target_size_mb * 1024 * 1024) / bytes_per_second
# ---------------------------------------------------------------------------
# Shared action-dispatch helper
# ---------------------------------------------------------------------------
def send_next_action(
obs_processed: dict,
obs_raw: dict,
ctx: RolloutContext,
interpolator: ActionInterpolator,
) -> dict | None:
"""Dispatch the next action to the robot.
Pulls the next action tensor from the inference engine, feeds the
interpolator, and sends the interpolated action through the
``robot_action_processor`` to the robot. Works identically for
sync and async backends the rollout strategy never needs to branch.
Returns the action dict that was sent, or ``None`` if no action was
ready (e.g. empty async queue, interpolator not yet primed).
"""
engine = ctx.policy.inference
features = ctx.data.dataset_features
ordered_keys = ctx.data.ordered_action_keys
if interpolator.needs_new_action():
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action(obs_frame)
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
interp = interpolator.get()
if interp is None:
return None
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
ctx.hardware.robot_wrapper.send_action(processed)
return action_dict
-733
View File
@@ -1,733 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DAgger rollout strategy: Human-in-the-Loop data collection.
Implements the RaC paradigm (Recovery and Correction) for interactive
imitation learning. Alternates between autonomous policy execution and
human intervention via teleoperator.
Input is controlled via either a keyboard or foot pedal, selected by
the ``input_device`` config field. Each device exposes three actions:
1. **pause_resume** Toggle policy execution (AUTONOMOUS <-> PAUSED).
2. **correction** Toggle correction recording (PAUSED <-> CORRECTING).
3. **upload** Push dataset to hub on demand (corrections-only mode).
ESC (keyboard only) Stop session.
Recording Modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Both autonomous and correction
frames are recorded; corrections tagged ``intervention=True``.
``record_autonomous=False``: Only correction windows are recorded.
Each correction (start to stop) becomes one episode.
"""
from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from typing import Any
import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
class DAggerPhase(enum.Enum):
"""Observable phases of a DAgger episode."""
AUTONOMOUS = "autonomous" # Policy driving
PAUSED = "paused" # Engine paused, teleop aligned, awaiting input
CORRECTING = "correcting" # Human driving via teleop, recording interventions
# Valid (current_phase, event) -> next_phase
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
(DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
(DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
(DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
}
class DAggerEvents:
"""Thread-safe container for DAgger input device events.
The keyboard/pedal threads write transition requests; the main loop
consumes them.
"""
def __init__(self) -> None:
self._lock = Lock()
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None
# Session-level flags
self.stop_recording = Event()
self.upload_requested = Event()
# -- Thread-safe phase access ------------------------------------------
@property
def phase(self) -> DAggerPhase:
"""Current phase of the DAgger state machine."""
with self._lock:
return self._phase
@phase.setter
def phase(self, value: DAggerPhase) -> None:
with self._lock:
self._phase = value
def request_transition(self, event: str) -> None:
"""Request a phase transition (called from keyboard/pedal threads).
Only enqueues the request if it corresponds to a valid transition
from the current phase, preventing impossible state changes.
"""
with self._lock:
if (self._phase, event) in _DAGGER_TRANSITIONS:
self._pending_transition = event
def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
"""Consume a pending transition (called from main loop)."""
with self._lock:
if self._pending_transition is None:
return None
key = (self._phase, self._pending_transition)
self._pending_transition = None
new_phase = _DAGGER_TRANSITIONS.get(key)
if new_phase is None:
return None
old_phase = self._phase
self._phase = new_phase
return old_phase, new_phase
def reset(self) -> None:
"""Reset all transient state for a fresh session."""
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
"""Smoothly move teleop to target position via linear interpolation.
Requires the teleoperator to support motor control methods
(``enable_torque``, ``write_goal_positions``, ``get_action``).
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
"""Initialise keyboard listener with DAgger 3-key controls.
Returns the pynput Listener (or ``None`` in headless mode or when
pynput is unavailable).
"""
if not PYNPUT_AVAILABLE or is_headless():
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
return None
# Map config key names to pynput Key objects for special keys
special_keys = {
"space": keyboard.Key.space,
"tab": keyboard.Key.tab,
"enter": keyboard.Key.enter,
}
def _resolve_key(key) -> str | None:
"""Resolve a pynput key event to a config-comparable string."""
if key == keyboard.Key.esc:
return "esc"
for name, pynput_key in special_keys.items():
if key == pynput_key:
return name
if hasattr(key, "char") and key.char:
return key.char
return None
# Build mapping: resolved key string -> DAgger event name
key_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(key):
try:
resolved = _resolve_key(key)
if resolved is None:
return
if resolved == "esc":
logger.info("Stop recording...")
events.stop_recording.set()
return
if resolved in key_to_event:
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
except Exception as e:
logger.debug("Key error: %s", e)
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
)
return listener
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
"""Initialise foot pedal listener with DAgger 3-pedal controls.
Returns the pedal listener thread (or ``None`` if evdev is unavailable).
"""
code_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(code: str) -> None:
if code in code_to_event:
events.request_transition(code_to_event[code])
if code == cfg.upload:
events.upload_requested.set()
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
return start_pedal_listener(on_press, device_path=cfg.device_path)
# ---------------------------------------------------------------------------
# DAgger Strategy
# ---------------------------------------------------------------------------
class DAggerStrategy(RolloutStrategy):
"""Human-in-the-Loop data collection with intervention tagging.
State machine::
AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
--(key1)--> AUTONOMOUS
Recording modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Intervention frames tagged True.
``record_autonomous=False``: Only correction windows recorded.
Each correction = one episode. Upload on demand via key3.
"""
config: DAggerStrategyConfig
def __init__(self, config: DAggerStrategyConfig):
super().__init__(config)
self._listener = None
self._pedal_thread = None
self._events = DAggerEvents()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._needs_push = Event()
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine and input device listener."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
if self.config.input_device == "keyboard":
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
else:
self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
logger.info(
"DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
self.config.input_device,
self.config.num_episodes,
record_mode,
self._episode_duration_s,
)
def run(self, ctx: RolloutContext) -> None:
"""Run DAgger episodes with human-in-the-loop intervention."""
if self.config.record_autonomous:
self._run_continuous(ctx)
else:
self._run_corrections_only(ctx)
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping DAgger recording")
log_say("Stopping DAgger recording", play_sounds)
if self._listener is not None and not is_headless():
logger.info("Stopping keyboard listener")
self._listener.stop()
# Flush any queued/running push cleanly
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
def _run_continuous(self, ctx: RolloutContext) -> None:
"""Sentry-like continuous recording with intervention tagging.
Episodes are auto-rotated every ``episode_time_s`` seconds and
uploaded in the background every ``upload_every_n_episodes`` episodes.
Both autonomous and correction frames are recorded; corrections are
tagged with ``intervention=True``.
"""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
events.reset()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
record_tick = 0
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
episode_duration_s = self._episode_duration_s
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# --- CORRECTING: human teleop control ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control ---
else:
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# Episode rotation derived from video file-size target.
# Do NOT save mid-correction — wait for the correction
# to finish so the episode boundary is clean.
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
episodes_since_push = 0
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# Corrections-only mode (record_autonomous=False)
# ------------------------------------------------------------------
def _run_corrections_only(self, ctx: RolloutContext) -> None:
"""Record only human correction windows. Each correction = one episode.
The policy runs autonomously without recording. When the user
pauses and starts a correction, frames are recorded with
``intervention=True``. Stopping the correction saves the episode.
The dataset can be uploaded on demand via the upload key/pedal.
"""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
events.reset()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
record_tick = 0
recorded = 0
logger.info(
"DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
)
with VideoEncodingManager(dataset):
try:
while (
recorded < self.config.num_episodes
and not events.stop_recording.is_set()
and not ctx.runtime.shutdown_event.is_set()
):
loop_start = time.perf_counter()
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
dataset.save_episode()
recorded += 1
self._needs_push.set()
logger.info(
"Correction %d/%d saved",
recorded,
self.config.num_episodes,
)
log_say(f"Correction {recorded} saved", play_sounds)
# On-demand upload
if events.upload_requested.is_set():
events.upload_requested.clear()
logger.info("Upload requested by user")
self._background_push(dataset, cfg)
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
# --- CORRECTING: human teleop control + recording ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
dataset.add_frame(
{
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
)
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control (no recording) ---
else:
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# State-machine transition side-effects
# ------------------------------------------------------------------
@staticmethod
def _apply_transition(
old_phase: DAggerPhase,
new_phase: DAggerPhase,
engine,
interpolator,
robot: ThreadSafeRobot,
teleop: Teleoperator,
) -> None:
"""Execute side-effects for a validated phase transition."""
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
logger.info("Pausing engine — robot holds position")
engine.pause()
obs = robot.get_observation()
_robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode — human teleop control")
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
logger.info("Resuming autonomous mode — resetting engine and interpolator")
interpolator.reset()
engine.reset()
engine.resume()
# ------------------------------------------------------------------
# Background push (shared by both modes)
# ------------------------------------------------------------------
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped. Pushes
are blocked while the operator is mid-correction to avoid
uploading a partially-recorded episode.
"""
if self._push_executor is None:
return
if self._events.phase == DAggerPhase.CORRECTING:
logger.info("Skipping push — correction in progress")
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
-45
View File
@@ -1,45 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy factory: config type-name → strategy class dispatch."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
if TYPE_CHECKING:
from lerobot.rollout import RolloutStrategyConfig
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
"""Instantiate the appropriate strategy from a config object.
Dispatches on ``config.type`` (the name registered via
``draccus.ChoiceRegistry``).
"""
if config.type == "base":
return BaseStrategy(config)
if config.type == "sentry":
return SentryStrategy(config)
if config.type == "highlight":
return HighlightStrategy(config)
if config.type == "dagger":
return DAggerStrategy(config)
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
-277
View File
@@ -1,277 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Highlight Reel strategy: on-demand recording via ring buffer."""
from __future__ import annotations
import contextlib
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
class HighlightStrategy(RolloutStrategy):
"""Autonomous rollout with on-demand recording via ring buffer.
The robot runs autonomously while a memory-bounded ring buffer
captures continuous telemetry. When the user presses the save key:
1. The ring buffer is flushed to the dataset (last *Z* seconds).
2. Live recording continues until the save key is pressed again.
3. The episode is saved and the ring buffer resumes capturing.
Requires ``streaming_encoding=True`` (enforced in config validation)
so that ``dataset.add_frame`` is a non-blocking queue put draining
900 frames stays sub-ms per frame.
"""
config: HighlightStrategyConfig
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
require_package("pynput", extra="pynput-dep")
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = ThreadingEvent()
self._recording_live = ThreadingEvent()
self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, ring buffer, and keyboard listener."""
self._init_engine(ctx)
self._ring = RolloutRingBuffer(
max_seconds=self.config.ring_buffer_seconds,
max_memory_mb=self.config.ring_buffer_max_memory_mb,
fps=ctx.runtime.cfg.fps,
)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
logger.info(
"Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)",
self.config.ring_buffer_seconds,
self.config.ring_buffer_max_memory_mb,
)
self._setup_keyboard(ctx.runtime.shutdown_event)
logger.info(
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
self.config.ring_buffer_seconds,
self.config.save_key,
self.config.push_key,
)
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous loop, buffering frames and recording on demand."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
dataset = ctx.data.dataset
ring = self._ring
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
start_time = time.perf_counter()
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key)
with VideoEncodingManager(dataset):
try:
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# NOTE: ``is_set()`` then ``clear()`` is not atomic
# against the keyboard thread setting the flag again
# in between — but that is benign: we lose at most one
# toggle, processed on the next iteration. The
# ``_recording_live`` branch below is reached in the
# SAME iteration after ``clear()`` runs, so a frame
# finalised by ``save_episode()`` is never re-added to
# the next episode.
if self._save_requested.is_set():
self._save_requested.clear()
if not self._recording_live.is_set():
logger.info(
"Flushing ring buffer (%d frames) + starting live recording",
len(ring),
)
for buffered_frame in ring.drain():
dataset.add_frame(buffered_frame)
self._recording_live.set()
else:
dataset.add_frame(frame)
dataset.save_episode()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say(
f"Episode {dataset.num_episodes} saved",
play_sounds,
)
self._recording_live.clear()
if self._push_requested.is_set():
self._push_requested.clear()
logger.info("Push requested by user")
self._background_push(dataset, cfg)
if self._recording_live.is_set():
dataset.add_frame(frame)
else:
ring.append(frame)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("Highlight control loop ended")
if self._recording_live.is_set():
logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping highlight recording")
log_say("Stopping highlight recording", play_sounds)
if self._listener is not None:
logger.info("Stopping keyboard listener")
self._listener.stop()
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Highlight strategy teardown complete")
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
"""Set up keyboard listener for save and push keys."""
if is_headless():
logger.warning("Headless environment — highlight keys unavailable")
return
try:
save_key = self.config.save_key
push_key = self.config.push_key
def on_press(key):
with contextlib.suppress(Exception):
if hasattr(key, "char") and key.char == save_key:
self._save_requested.set()
elif hasattr(key, "char") and key.char == push_key:
self._push_requested.set()
elif key == keyboard.Key.esc:
self._save_requested.clear()
shutdown_event.set()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor."""
if self._push_executor is None:
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
-225
View File
@@ -1,225 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sentry rollout strategy: continuous autonomous recording with auto-upload."""
from __future__ import annotations
import contextlib
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
class SentryStrategy(RolloutStrategy):
"""Continuous autonomous rollout with always-on recording.
Episode duration is derived from camera resolution, FPS, and
``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode
produces a video file that has crossed the chunk-size boundary.
This keeps ``push_to_hub`` efficient it uploads complete video
files rather than re-uploading a still-growing one.
The dataset is pushed to the Hub via a bounded single-worker executor
so no push is ever silently dropped and exactly one push runs at a
time.
Policy state (hidden state, RTC queue) intentionally persists across
episode boundaries Sentry slices one continuous rollout, the robot
does not reset between slices.
Requires ``streaming_encoding=True`` (enforced in config validation)
to prevent disk I/O from blocking the control loop.
"""
config: SentryStrategyConfig
def __init__(self, config: SentryStrategyConfig):
super().__init__(config)
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._needs_push = Event()
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine and background push executor."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
logger.info(
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
self._episode_duration_s,
self.config.upload_every_n_episodes,
)
def run(self, ctx: RolloutContext) -> None:
"""Run the continuous recording loop with automatic episode rotation."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
dataset = ctx.data.dataset
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
episode_duration_s = self._episode_duration_s
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# ``add_frame`` writes to the in-progress episode buffer; the
# background pusher only ever touches *finalised* episode
# artifacts on disk. The two operate on disjoint state, so
# ``add_frame`` does not need ``_episode_lock``.
dataset.add_frame(frame)
# Episode rotation derived from video file-size target.
# The duration is a conservative estimate so the actual
# video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now,
# keeping push_to_hub efficient (uploads complete files).
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s:
# ``save_episode`` finalises the in-progress episode and
# flushes it to disk; ``_episode_lock`` serialises this with
# ``push_to_hub`` (run in the background executor) so the
# pusher never reads a half-written episode.
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
episodes_since_push = 0
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("Sentry control loop ended — saving final episode")
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
def teardown(self, ctx: RolloutContext) -> None:
"""Flush pending pushes, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping sentry recording")
log_say("Stopping sentry recording", play_sounds)
# Flush any queued/running push cleanly.
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Sentry strategy teardown complete")
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped.
"""
if self._push_executor is None:
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")

Some files were not shown because too many files have changed in this diff Show More