mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f9b8f297b4 | |||
| 95527f6051 | |||
| 407ee867b9 | |||
| 26ff40ddd7 | |||
| a5e6409985 | |||
| 6d269b28c8 | |||
| b607c8458e | |||
| 9e83510c99 | |||
| 1c9fbba9a9 | |||
| 6a1b5ceb9d | |||
| daa4c4dd30 | |||
| 1f7b03f5f2 | |||
| ff992a7a1d | |||
| cb8edf17e6 | |||
| 5699f6cbf4 | |||
| 48269dddb3 | |||
| 8df8d3d866 | |||
| 0e6114ac36 | |||
| c8ce413d73 | |||
| 82dffde7fa | |||
| eaf0218bc8 | |||
| a0e52d52fe | |||
| e99c55af4b | |||
| 408e0ca763 | |||
| ce24063efd | |||
| 82934719db | |||
| 401a217597 | |||
| 40094b0464 | |||
| fdbfc015a2 | |||
| d656da8ccc | |||
| b5f65e5332 | |||
| cd6b43ea7a | |||
| 2236bbe7a3 | |||
| cb0a944941 | |||
| 8a3d64033f | |||
| 03ee50e08f | |||
| ca87ccd941 | |||
| 77352c495c | |||
| 05a5223885 | |||
| 580d818aa9 | |||
| 587aa82021 | |||
| 12b88fce02 | |||
| fc6c94c82a | |||
| 1add460678 | |||
| 4587c2b648 | |||
| 2236cdb302 | |||
| 7c2466979e | |||
| 39b966e20a | |||
| ba27aab79c |
@@ -382,6 +382,7 @@ jobs:
|
|||||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||||
--env.type=robotwin \
|
--env.type=robotwin \
|
||||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -482,6 +483,7 @@ jobs:
|
|||||||
--policy.path=lerobot/smolvla_robocasa \
|
--policy.path=lerobot/smolvla_robocasa \
|
||||||
--env.type=robocasa \
|
--env.type=robocasa \
|
||||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -693,6 +695,7 @@ jobs:
|
|||||||
--env.task=\"\$ROBOMME_TASKS\" \
|
--env.task=\"\$ROBOMME_TASKS\" \
|
||||||
--env.dataset_split=test \
|
--env.dataset_split=test \
|
||||||
--env.task_ids=[0] \
|
--env.task_ids=[0] \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -800,6 +803,7 @@ jobs:
|
|||||||
--env.type=libero_plus \
|
--env.type=libero_plus \
|
||||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -900,6 +904,8 @@ jobs:
|
|||||||
--policy.path=lerobot/smolvla_vlabench \
|
--policy.path=lerobot/smolvla_vlabench \
|
||||||
--env.type=vlabench \
|
--env.type=vlabench \
|
||||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||||
|
--env.episode_length=50 \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
github.event.workflow_run.event == 'pull_request' &&
|
github.event.workflow_run.event == 'pull_request' &&
|
||||||
github.event.workflow_run.conclusion == 'success' &&
|
github.event.workflow_run.conclusion == 'success' &&
|
||||||
github.repository == 'huggingface/lerobot'
|
github.repository == 'huggingface/lerobot'
|
||||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
package_name: lerobot
|
package_name: lerobot
|
||||||
secrets:
|
secrets:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
github.repository == 'huggingface/lerobot'
|
github.repository == 'huggingface/lerobot'
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
commit_sha: ${{ github.sha }}
|
commit_sha: ${{ github.sha }}
|
||||||
package: lerobot
|
package: lerobot
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
pull-requests: write
|
pull-requests: write
|
||||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||||
pr_number: ${{ github.event.number }}
|
pr_number: ${{ github.event.number }}
|
||||||
|
|||||||
@@ -152,13 +152,14 @@ jobs:
|
|||||||
BASE_VERSION="${VERSION%%-*}"
|
BASE_VERSION="${VERSION%%-*}"
|
||||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||||
uv pip install \
|
uv pip install \
|
||||||
|
--torch-backend cpu \
|
||||||
--index-url https://test.pypi.org/simple/ \
|
--index-url https://test.pypi.org/simple/ \
|
||||||
--extra-index-url https://pypi.org/simple \
|
--extra-index-url https://pypi.org/simple \
|
||||||
--index-strategy unsafe-best-match \
|
--index-strategy unsafe-best-match \
|
||||||
"lerobot[all]==$BASE_VERSION"
|
"lerobot[all]==$BASE_VERSION"
|
||||||
else
|
else
|
||||||
echo "Installing release version $VERSION from PyPI..."
|
echo "Installing release version $VERSION from PyPI..."
|
||||||
uv pip install "lerobot[all]==$VERSION"
|
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||||
fi
|
fi
|
||||||
- name: Check lerobot version
|
- name: Check lerobot version
|
||||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||||
|
|||||||
@@ -19,19 +19,19 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Runs at 02:00
|
# Runs at 02:00
|
||||||
schedule:
|
# schedule:
|
||||||
- cron: "0 2 * * *"
|
# - cron: "0 2 * * *"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
CLOSE_ISSUE_MESSAGE: >
|
CLOSE_ISSUE_MESSAGE: >
|
||||||
This issue was closed because it has been stalled for 14 days with no activity.
|
This issue was closed because it has been stalled for 30 days with no activity.
|
||||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||||
CLOSE_PR_MESSAGE: >
|
CLOSE_PR_MESSAGE: >
|
||||||
This PR was closed because it has been stalled for 21 days with no activity.
|
This PR was closed because it has been stalled for 30 days with no activity.
|
||||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||||
WARN_ISSUE_MESSAGE: >
|
WARN_ISSUE_MESSAGE: >
|
||||||
This issue has been automatically marked as stale because it has not had
|
This issue has been automatically marked as stale because it has not had
|
||||||
recent activity (6 months). It will be closed if no further activity occurs.
|
recent activity (1 year). It will be closed if no further activity occurs.
|
||||||
Any change, comment or update to this issue will reset this count.
|
Any change, comment or update to this issue will reset this count.
|
||||||
Thank you for your contributions.
|
Thank you for your contributions.
|
||||||
WARN_PR_MESSAGE: >
|
WARN_PR_MESSAGE: >
|
||||||
@@ -59,10 +59,10 @@ jobs:
|
|||||||
stale-pr-label: stale
|
stale-pr-label: stale
|
||||||
exempt-issue-labels: never-stale
|
exempt-issue-labels: never-stale
|
||||||
exempt-pr-labels: never-stale
|
exempt-pr-labels: never-stale
|
||||||
days-before-issue-stale: 180
|
days-before-issue-stale: 365
|
||||||
days-before-issue-close: 14
|
days-before-issue-close: 30
|
||||||
days-before-pr-stale: 365
|
days-before-pr-stale: 365
|
||||||
days-before-pr-close: 21
|
days-before-pr-close: 30
|
||||||
delete-branch: true
|
delete-branch: true
|
||||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
This file provides guidance to AI agents when working with code in this repository.
|
This file provides guidance to AI agents when working with code in this repository.
|
||||||
|
|
||||||
|
> **User-facing help → [`AGENT_GUIDE.md`](./AGENT_GUIDE.md)** (SO-101 setup, recording, picking a policy, training duration, eval — with copy-pasteable commands).
|
||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
LeRobot is a PyTorch-based library for real-world robotics, providing datasets, pretrained policies, and tools for training, evaluation, data collection, and robot control. It integrates with Hugging Face Hub for model/dataset sharing.
|
LeRobot is a PyTorch-based library for real-world robotics, providing datasets, pretrained policies, and tools for training, evaluation, data collection, and robot control. It integrates with Hugging Face Hub for model/dataset sharing.
|
||||||
|
|||||||
+412
@@ -0,0 +1,412 @@
|
|||||||
|
# AGENT_GUIDE.md — LeRobot Helper for AI Agents & Users
|
||||||
|
|
||||||
|
This file is a practical, copy-paste-friendly companion for any AI agent (Cursor, Claude, ChatGPT, Codex, etc.) helping a user work with LeRobot. It complements [`AGENTS.md`](./AGENTS.md) (dev/contributor context) with **user-facing guidance**: how to start, what to train, how long, how to record, and how to calibrate an SO-101.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Start here — ask the user first (MANDATORY)
|
||||||
|
|
||||||
|
Before suggesting any command, an agent MUST ask the user at least these questions and wait for answers:
|
||||||
|
|
||||||
|
1. **What's your goal?** (e.g. "teach my SO-101 to fold a cloth", "train a policy on an existing HF dataset", "contribute a PR", "understand the codebase")
|
||||||
|
2. **What hardware do you have?**
|
||||||
|
- Robot: none / SO-100 / SO-101 / Koch / LeKiwi / Reachy / other
|
||||||
|
- Teleop: leader arm / phone / keyboard / gamepad / none
|
||||||
|
- Cameras: how many, resolution, fixed or moving?
|
||||||
|
3. **What machine will you train on?**
|
||||||
|
- GPU model + VRAM (e.g. "laptop 3060 6 GB", "RTX 4090 24 GB", "A100 80 GB", "CPU only")
|
||||||
|
- OS: macOS / Linux / Windows
|
||||||
|
4. **Skill level & time budget?** First time, some ML, experienced? Hours, days, a weekend?
|
||||||
|
5. **Do you already have a dataset?** Yes (HF repo id?) / no / want to record one
|
||||||
|
6. **How can I help right now?** (pick one concrete next step)
|
||||||
|
|
||||||
|
Only after you have answers, propose a concrete path. If something is ambiguous, ask again rather than guessing. Bias toward **the simplest thing that works** for the user's hardware and goal.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. LeRobot in 60 seconds
|
||||||
|
|
||||||
|
LeRobot = **datasets + policies + envs + robot control**, unified by a small set of strong abstractions.
|
||||||
|
|
||||||
|
- **`LeRobotDataset`** — episode-aware dataset (video or images + actions + state), loadable from the Hub or disk.
|
||||||
|
- **Policies** (`ACT`, `Diffusion`, `SmolVLA`, `π0`, `π0.5`, `Wall-X`, `X-VLA`, `VQ-BeT`, `TD-MPC`, …) — all inherit `PreTrainedPolicy` and can be pushed/pulled from the Hub.
|
||||||
|
- **Processors** — small composable transforms between dataset → policy → robot.
|
||||||
|
- **Envs** (sim) and **Robots** (real) — same action/observation contract so code swaps cleanly.
|
||||||
|
- **CLI** — `lerobot-record`, `lerobot-train`, `lerobot-eval`, `lerobot-teleoperate`, `lerobot-calibrate`, `lerobot-find-port`, `lerobot-setup-motors`, `lerobot-replay`.
|
||||||
|
|
||||||
|
See [`AGENTS.md`](./AGENTS.md) for repo architecture.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Quickstart paths (pick one)
|
||||||
|
|
||||||
|
### Path A — "I have an SO-101 and want my first trained policy"
|
||||||
|
|
||||||
|
Go to §4 (SO-101 end-to-end), then §5 (data tips), then §6 (pick a policy — likely **ACT**), then §7 (how long), then §8 (eval).
|
||||||
|
|
||||||
|
### Path B — "No hardware, I want to train on an existing dataset"
|
||||||
|
|
||||||
|
Skip §4. Pick a policy in §6, pick a duration in §7, then run `lerobot-train` per §4.9 with a Hub `--dataset.repo_id` and an `--env.type` for eval. Finish with §8.
|
||||||
|
|
||||||
|
### Path C — "I just want to understand the codebase"
|
||||||
|
|
||||||
|
Read §2 above, then `AGENTS.md` "Architecture", then open `src/lerobot/policies/act/` and `src/lerobot/datasets/lerobot_dataset.py` as canonical examples.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. SO-101 end-to-end cheat-sheet
|
||||||
|
|
||||||
|
Full details in [`docs/source/so101.mdx`](./docs/source/so101.mdx) and [`docs/source/il_robots.mdx`](./docs/source/il_robots.mdx). Minimum commands in order. Confirm arms are assembled + powered before issuing.
|
||||||
|
|
||||||
|
**4.1 Install**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install 'lerobot[feetech]' # SO-100/SO-101 motor stack
|
||||||
|
# pip install 'lerobot[all]' # everything
|
||||||
|
# pip install 'lerobot[aloha,pusht]' # specific features
|
||||||
|
# pip install 'lerobot[smolvla]' # add SmolVLA deps
|
||||||
|
git lfs install && git lfs pull
|
||||||
|
hf auth login # required to push datasets/policies
|
||||||
|
```
|
||||||
|
|
||||||
|
Contributors can alternatively use `uv sync --locked --extra feetech` (see `AGENTS.md`).
|
||||||
|
|
||||||
|
**4.2 Find USB ports** — run once per arm, unplug when prompted.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-find-port
|
||||||
|
```
|
||||||
|
|
||||||
|
macOS: `/dev/tty.usbmodem...`; Linux: `/dev/ttyACM0` (may need `sudo chmod 666 /dev/ttyACM0`).
|
||||||
|
|
||||||
|
**4.3 Setup motor IDs & baudrate** (one-time, per arm)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
|
||||||
|
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.4 Calibrate** — center all joints, press Enter, sweep each joint through its full range. The `id` is the calibration key — reuse it everywhere.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-calibrate --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower
|
||||||
|
lerobot-calibrate --teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.5 Teleoperate** (sanity check, no recording)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-teleoperate \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Feetech timeout / comms error on SO-100 / SO-101?** Before touching software, check the **red motor LEDs** on the daisy chain.
|
||||||
|
>
|
||||||
|
> - **All steady red, gripper → base chain** → wiring OK.
|
||||||
|
> - **One or more motors dark / chain stops mid-way** → wiring issue: reseat the 3-pin cables, check the controller-board power supply, and make sure each motor is fully clicked in.
|
||||||
|
> - **LEDs blinking** → the motor is in an **error state**: usually overload (forcing a joint past its limit) **or wrong power supply voltage**. SO-100 / SO-101 ship in two variants — a **5 V / 7.4 V** build and a **12 V** build — they are NOT interchangeable. Using a 12 V PSU on a 5 V / 7.4 V arm (or vice-versa) will trip this error; confirm your motor variant before powering up.
|
||||||
|
>
|
||||||
|
> Most "timeout" errors are physical, not code.
|
||||||
|
|
||||||
|
**4.6 Record a dataset** — keys: **→** next, **←** redo, **ESC** finish & upload.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||||
|
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task \
|
||||||
|
--dataset.single_task="<describe the task in one sentence>" \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10 \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.7 Visualize** — **always** do this before training. Look for missing frames, camera blur, unreachable targets, inconsistent object positions.
|
||||||
|
After upload: https://huggingface.co/spaces/lerobot/visualize_dataset → paste `${HF_USER}/my_task`. Works for **any LeRobot-formatted Hub dataset** — use it to scout other datasets, inspect episode quality, or debug your own data before retraining.
|
||||||
|
|
||||||
|
**4.8 Replay an episode** (sanity check)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-replay --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task --dataset.episode=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. See §6/§7 for policy and duration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task \
|
||||||
|
--policy.type=act \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--output_dir=outputs/train/act_my_task \
|
||||||
|
--job_name=act_my_task \
|
||||||
|
--batch_size=8 \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--policy.repo_id=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.10 Evaluate on the real robot** — compare success rate to a teleoperated baseline.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/eval_my_task \
|
||||||
|
--dataset.single_task="<same task description as training>" \
|
||||||
|
--dataset.num_episodes=10 \
|
||||||
|
--policy.path=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Data collection tips (beginner → reliable policy)
|
||||||
|
|
||||||
|
Good data beats clever models. Adopt these defaults and deviate only with evidence.
|
||||||
|
|
||||||
|
### 5.1 Setup & ergonomics
|
||||||
|
|
||||||
|
- **Fix the rig and cameras** before touching the software. If the rig vibrates or the operator gets frustrated, fix that first — more bad data won't help.
|
||||||
|
- **Lighting matters more than resolution.** Diffuse, consistent light. Avoid moving shadows.
|
||||||
|
- **"Can you do the task from the camera view alone?"** If no, your cameras are wrong. Fix before recording.
|
||||||
|
- Enable **action interpolation** for rollouts when available for smoother trajectories.
|
||||||
|
|
||||||
|
### 5.2 Practice before you record
|
||||||
|
|
||||||
|
- Do 5–10 demos without recording. Build a deliberate, repeatable strategy.
|
||||||
|
- Hesitant or inconsistent demos teach the model hesitation.
|
||||||
|
|
||||||
|
### 5.3 Quality over speed
|
||||||
|
|
||||||
|
Deliberate, high-quality execution beats fast sloppy runs. Optimize for speed only **after** strategy is dialed in — never trade quality for it.
|
||||||
|
|
||||||
|
### 5.4 Consistency within and across episodes
|
||||||
|
|
||||||
|
Same grasp, approach vector, and timing. Coherent strategies are much easier to learn than wildly varying movements.
|
||||||
|
|
||||||
|
### 5.5 Start small, then extend (the golden rule)
|
||||||
|
|
||||||
|
- **First 50 episodes = constrained version** of the task: one object, fixed position, fixed camera setup, one operator.
|
||||||
|
- Train a quick ACT model. See what fails.
|
||||||
|
- **Then add diversity** along one axis at a time: more positions → more lighting → more objects → more operators.
|
||||||
|
- Don't try to collect the "perfect dataset" on day one. Iterate.
|
||||||
|
|
||||||
|
### 5.6 Policy choice for beginners
|
||||||
|
|
||||||
|
- **Laptop / first time / want results fast → ACT.** Works surprisingly well, trains fast even on a laptop GPU.
|
||||||
|
- **Bigger GPU / language-conditioned / multi-task → SmolVLA.** Unfreezing the vision encoder (see §7) is a big win here.
|
||||||
|
- Defer π0 / π0.5 / Wall-X / X-VLA until you have a proven ACT baseline and a 20+ GB GPU.
|
||||||
|
|
||||||
|
### 5.7 Recommended defaults for your first task
|
||||||
|
|
||||||
|
| Setting | Value |
|
||||||
|
| ---------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| Episodes | **50** to start, scale to 100–300 after first training |
|
||||||
|
| Episode length | 20–45 s (shorter is fine for grasp/place) |
|
||||||
|
| Reset time | 10 s |
|
||||||
|
| FPS | 30 |
|
||||||
|
| Cameras | **2 cameras recommended**: 1 fixed front + 1 wrist. Multi-view often outperforms single-view. A single fixed camera also works to keep things simple. |
|
||||||
|
| Task description | Short, specific, action-phrased sentence |
|
||||||
|
|
||||||
|
### 5.8 Troubleshooting signal
|
||||||
|
|
||||||
|
- Policy fails at one specific stage → record 10–20 more episodes **targeting that stage**.
|
||||||
|
- Policy flaps / oscillates → likely inconsistent demos, or need more training; re-record worst episodes (use **←** to redo).
|
||||||
|
- Policy ignores the object → camera framing or lighting issue, not a model issue.
|
||||||
|
|
||||||
|
See also: [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Which policy should I train?
|
||||||
|
|
||||||
|
Match the policy to the user's **GPU memory** and **time budget**. Numbers below come from an internal profiling run (one training update per policy). They are **indicative only** — see caveats.
|
||||||
|
|
||||||
|
### 6.1 Profiling snapshot (indicative)
|
||||||
|
|
||||||
|
All policies typically train for **5–10 epochs** (see §7).
|
||||||
|
|
||||||
|
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||||
|
|
||||||
|
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||||
|
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||||
|
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||||
|
| `diffusion` | 4 | 168.6 | 4.94 | Multi-modal action distributions; needs mid-range GPU. |
|
||||||
|
| `smolvla` | 1 | 357.8 | 3.93 | Language-conditioned, multi-task, small VLA. **Unfreeze vision encoder for big gains** (see §7). |
|
||||||
|
| `xvla` | 1 | 731.6 | 15.52 | Large VLA, multi-task. |
|
||||||
|
| `wall_x` | 1 | 716.5 | 15.95 | Large VLA with world-model objective. |
|
||||||
|
| `pi0` | 1 | 940.3 | 15.50 | Strong large VLA baseline (Physical Intelligence). |
|
||||||
|
| `pi05` | 1 | 1055.8 | 16.35 | Newer π policy; similar footprint to `pi0`. |
|
||||||
|
|
||||||
|
**Critical caveats:**
|
||||||
|
|
||||||
|
- **Optimizer:** measured with **SGD**. LeRobot's default is **AdamW**, which keeps extra optimizer state → **peak memory will be noticeably higher** with the default, especially for `pi0`, `pi05`, `wall_x`, `xvla`.
|
||||||
|
- **Batch size:** the large policies were profiled at batch 1. In practice use a **larger batch** for stable training (see §7.4). Memory scales roughly linearly with batch.
|
||||||
|
|
||||||
|
### 6.2 Decision rules
|
||||||
|
|
||||||
|
- **< 8 GB VRAM (laptop, 3060, M-series Mac):** → `act`. Maybe `diffusion` if you have ~6–8 GB free.
|
||||||
|
- **12–16 GB VRAM (4070/4080, A4000):** → `smolvla` with defaults, or `act`/`diffusion` with larger batch. `pi0`/`pi05`/`wall_x`/`xvla` feasible only with small batch + gradient accumulation.
|
||||||
|
- **24+ GB VRAM (3090/4090/A5000):** → any policy. Prefer `smolvla` (unfrozen) for multi-task; `act` for single-task grasp-and-place (still often the best ROI). Could experiment with `pi0` or `pi05` or `xvla`
|
||||||
|
- **80 GB (A100/H100):** → any, with healthy batch. `pi05`, `xvla`, `wall_x` become comfortable.
|
||||||
|
- **CPU only:** → don't train here. Use Google Colab (see [`docs/source/notebooks.mdx`](./docs/source/notebooks.mdx)) or a rented GPU.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. How long should I train?
|
||||||
|
|
||||||
|
Robotics imitation learning usually converges in a **few epochs over the dataset**, not hundreds of thousands of raw steps. Think **epochs first**, then translate to steps.
|
||||||
|
|
||||||
|
### 7.1 Rule of thumb
|
||||||
|
|
||||||
|
- **Typical total: 5–10 epochs.** Start at 5, eval, then decide if more helps.
|
||||||
|
- Very small datasets (< 30 episodes) may want slightly more epochs — but first, **collect more data**.
|
||||||
|
- VLAs with a pretrained vision backbone typically need **fewer** epochs than training from scratch.
|
||||||
|
|
||||||
|
### 7.2 Steps ↔ epochs conversion
|
||||||
|
|
||||||
|
```
|
||||||
|
total_frames = sum of frames over all episodes # e.g. 50 eps × 30 fps × 30 s ≈ 45,000
|
||||||
|
steps_per_epoch = ceil(total_frames / batch_size)
|
||||||
|
total_steps = epochs × steps_per_epoch
|
||||||
|
```
|
||||||
|
|
||||||
|
Examples for `--batch_size=8`:
|
||||||
|
|
||||||
|
| Dataset size | Frames | Steps / epoch | 5 epochs | 10 epochs |
|
||||||
|
| ----------------------- | ------: | ------------: | -------: | --------: |
|
||||||
|
| 50 eps × 30 s @ 30 fps | 45,000 | ~5,625 | 28k | 56k |
|
||||||
|
| 100 eps × 30 s @ 30 fps | 90,000 | ~11,250 | 56k | 113k |
|
||||||
|
| 300 eps × 30 s @ 30 fps | 270,000 | ~33,750 | 169k | 338k |
|
||||||
|
|
||||||
|
Pass the resulting total with `--steps=<N>`; eval at intermediate checkpoints (`outputs/train/.../checkpoints/`).
|
||||||
|
|
||||||
|
### 7.3 Per-policy starting points (single-task, ~50 episodes)
|
||||||
|
|
||||||
|
| Policy | Batch | Steps (first run) | Notes |
|
||||||
|
| -------------- | ----: | ----------------: | ----------------------------------------------------------------- |
|
||||||
|
| `act` | 8–16 | 30k–80k | Usually converges under 50k for single-task. |
|
||||||
|
| `diffusion` | 8–16 | 80k–150k | Benefits from longer training than ACT. |
|
||||||
|
| `smolvla` | 4–8 | 30k–80k | Pretrained VLM → converges fast. |
|
||||||
|
| `pi0` / `pi05` | 1–4 | 30k–80k | Memory-bound; use gradient accumulation for effective batch ≥ 16! |
|
||||||
|
|
||||||
|
### 7.4 Batch size guidance
|
||||||
|
|
||||||
|
- **Bigger batch is preferable** for stable gradients on teleop data.
|
||||||
|
- If GPU memory is the bottleneck, use **gradient accumulation** to raise _effective_ batch without raising peak memory.
|
||||||
|
- Scale **learning rate** gently with batch; most LeRobot defaults work fine for a 2–4× batch change.
|
||||||
|
|
||||||
|
### 7.5 Scale LR schedule & checkpoints with `--steps`
|
||||||
|
|
||||||
|
LeRobot's default schedulers (e.g. SmolVLA's cosine decay) use `scheduler_decay_steps=30_000`, which is sized for long training runs. When you shorten training (e.g. 5k–10k steps on a small dataset), **scale the scheduler down to match** — otherwise the LR stays near the peak and never decays. Same for checkpoint frequency.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train ... \
|
||||||
|
--steps=5000 \
|
||||||
|
--policy.scheduler_decay_steps=5000 \
|
||||||
|
--save_freq=5000
|
||||||
|
```
|
||||||
|
|
||||||
|
Rule of thumb: set `scheduler_decay_steps ≈ steps`, and `save_freq` to whatever granularity you want for eval (e.g. every 1k–5k steps). Match `scheduler_warmup_steps` proportionally if your run is very short.
|
||||||
|
|
||||||
|
### 7.6 SmolVLA: unfreeze the vision encoder for real gains
|
||||||
|
|
||||||
|
SmolVLA ships with `freeze_vision_encoder=True`. Unfreezing usually **improves performance substantially** on specialized tasks, at the cost of more VRAM and slower steps. Enable with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train ... --policy.type=smolvla \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.train_expert_only=false
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.7 Signals to stop / keep going
|
||||||
|
|
||||||
|
- Train loss plateaus → stop, save a Hub checkpoint.
|
||||||
|
- Train loss still dropping and you're under 10 epochs → keep going.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Evaluation & benchmarks
|
||||||
|
|
||||||
|
Two flavors of evaluation:
|
||||||
|
|
||||||
|
### 8.1 Real-robot eval (SO-101, etc.)
|
||||||
|
|
||||||
|
Reuse `lerobot-record` with `--policy.path` to run the trained policy on-robot and save the run as an eval dataset. Convention: prefix the dataset with `eval_`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/eval_my_task \
|
||||||
|
--dataset.single_task="<same task description used during training>" \
|
||||||
|
--dataset.num_episodes=10 \
|
||||||
|
--policy.path=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
Report success rate across episodes. Compare to a teleoperated baseline and to an earlier checkpoint to catch regressions.
|
||||||
|
|
||||||
|
### 8.2 Sim-benchmark eval
|
||||||
|
|
||||||
|
For policies trained on sim datasets (PushT, Aloha, LIBERO, MetaWorld, RoboCasa, …) use `lerobot-eval` against the matching `env.type`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=${HF_USER}/diffusion_pusht \
|
||||||
|
--env.type=pusht \
|
||||||
|
--eval.n_episodes=50 \
|
||||||
|
--eval.batch_size=10 \
|
||||||
|
--policy.device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
- Use `--policy.path=outputs/train/.../checkpoints/<step>/pretrained_model` for local checkpoints.
|
||||||
|
- `--eval.n_episodes` should be ≥ 50 for a stable success-rate estimate.
|
||||||
|
- Available envs live in `src/lerobot/envs/`. See [`docs/source/libero.mdx`](./docs/source/libero.mdx), [`metaworld.mdx`](./docs/source/metaworld.mdx), [`robocasa.mdx`](./docs/source/robocasa.mdx), [`vlabench.mdx`](./docs/source/vlabench.mdx) for specific benchmarks.
|
||||||
|
- To add a new benchmark, see [`docs/source/adding_benchmarks.mdx`](./docs/source/adding_benchmarks.mdx) and [`envhub.mdx`](./docs/source/envhub.mdx).
|
||||||
|
|
||||||
|
### 8.2b Dockerfiles for benchmark eval
|
||||||
|
|
||||||
|
Benchmark envs have native dependencies that are painful to install locally. The repo ships **pre-baked Dockerfiles** for each supported benchmark — use these to run `lerobot-eval` in a reproducible environment:
|
||||||
|
|
||||||
|
| Benchmark | Dockerfile |
|
||||||
|
| ----------- | -------------------------------------------------------------------------------------- |
|
||||||
|
| LIBERO | [`docker/Dockerfile.benchmark.libero`](./docker/Dockerfile.benchmark.libero) |
|
||||||
|
| LIBERO+ | [`docker/Dockerfile.benchmark.libero_plus`](./docker/Dockerfile.benchmark.libero_plus) |
|
||||||
|
| MetaWorld | [`docker/Dockerfile.benchmark.metaworld`](./docker/Dockerfile.benchmark.metaworld) |
|
||||||
|
| RoboCasa | [`docker/Dockerfile.benchmark.robocasa`](./docker/Dockerfile.benchmark.robocasa) |
|
||||||
|
| RoboCerebra | [`docker/Dockerfile.benchmark.robocerebra`](./docker/Dockerfile.benchmark.robocerebra) |
|
||||||
|
| RoboMME | [`docker/Dockerfile.benchmark.robomme`](./docker/Dockerfile.benchmark.robomme) |
|
||||||
|
| RoboTwin | [`docker/Dockerfile.benchmark.robotwin`](./docker/Dockerfile.benchmark.robotwin) |
|
||||||
|
| VLABench | [`docker/Dockerfile.benchmark.vlabench`](./docker/Dockerfile.benchmark.vlabench) |
|
||||||
|
|
||||||
|
Build and run (adapt to your benchmark):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f docker/Dockerfile.benchmark.robomme -t lerobot-bench-robomme .
|
||||||
|
docker run --gpus all --rm -it \
|
||||||
|
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
lerobot-bench-robomme \
|
||||||
|
lerobot-eval --policy.path=<your_policy> --env.type=<env> --eval.n_episodes=50
|
||||||
|
```
|
||||||
|
|
||||||
|
See [`docker/README.md`](./docker/README.md) for base-image details.
|
||||||
|
|
||||||
|
### 8.3 Target success rates
|
||||||
|
|
||||||
|
Single-task grasp-and-place with 50 clean episodes: ACT should reach **> 70% success** on the training configuration. Less → data problem (see §5), not model problem. Expect a drop when generalizing to new positions — scale episodes or diversity to recover.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Further reading & resources
|
||||||
|
|
||||||
|
- **Getting started:** [`installation.mdx`](./docs/source/installation.mdx) · [`il_robots.mdx`](./docs/source/il_robots.mdx) · [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets)
|
||||||
|
- **Per-policy docs:** browse [`docs/source/*.mdx`](./docs/source/) (policies, hardware, benchmarks, advanced training).
|
||||||
|
- **Community:** [Discord](https://discord.com/invite/s3KuuzsPFb) · [Hub `LeRobot` tag](https://huggingface.co/datasets?other=LeRobot) · [Dataset visualizer](https://huggingface.co/spaces/lerobot/visualize_dataset)
|
||||||
|
|
||||||
|
> Keep this file current. If you learn a rule that would prevent a class of user mistakes, add it here and in [`AGENTS.md`](./AGENTS.md).
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||||
|
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||||
include src/lerobot/datasets/card_template.md
|
include src/lerobot/datasets/card_template.md
|
||||||
include src/lerobot/envs/metaworld_config.json
|
include src/lerobot/envs/metaworld_config.json
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ lerobot-train \
|
|||||||
|
|
||||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||||
|
|
||||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||||
|
|
||||||
## Inference & Evaluation
|
## Inference & Evaluation
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ USER root
|
|||||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y --no-install-recommends \
|
&& apt-get install -y --no-install-recommends \
|
||||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||||
libvulkan1 vulkan-tools \
|
libvulkan1 vulkan-tools \
|
||||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||||
@@ -56,11 +56,11 @@ RUN uv pip install --no-cache --no-build-isolation \
|
|||||||
"git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
"git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
||||||
|
|
||||||
# CuRobo — NVlabs motion generator; TORCH_CUDA_ARCH_LIST must be set or the
|
# CuRobo — NVlabs motion generator; TORCH_CUDA_ARCH_LIST must be set or the
|
||||||
# build aborts on an empty arch list. Pinned SHA for reproducibility.
|
# build aborts on an empty arch list. RoboTwin's own installer pins v0.7.8,
|
||||||
ARG CUROBO_SHA=ca941586c33b8482ed9c0e74d60f23efd64b516a
|
# which still exposes the v1 API (`curobo.types.math`) that RoboTwin imports.
|
||||||
|
ARG CUROBO_REF=v0.7.8
|
||||||
RUN cd ${ROBOTWIN_ROOT}/envs \
|
RUN cd ${ROBOTWIN_ROOT}/envs \
|
||||||
&& git clone https://github.com/NVlabs/curobo.git \
|
&& git clone --branch ${CUROBO_REF} --depth 1 https://github.com/NVlabs/curobo.git \
|
||||||
&& git -C curobo checkout ${CUROBO_SHA} \
|
|
||||||
&& cd curobo \
|
&& cd curobo \
|
||||||
&& TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" \
|
&& TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" \
|
||||||
uv pip install -e . --no-build-isolation --no-cache
|
uv pip install -e . --no-build-isolation --no-cache
|
||||||
@@ -111,7 +111,23 @@ EOF
|
|||||||
WORKDIR ${ROBOTWIN_ROOT}
|
WORKDIR ${ROBOTWIN_ROOT}
|
||||||
RUN python script/update_embodiment_config_path.py
|
RUN python script/update_embodiment_config_path.py
|
||||||
|
|
||||||
ENV PYTHONPATH="${ROBOTWIN_ROOT}:${PYTHONPATH}"
|
ENV PYTHONPATH="${ROBOTWIN_ROOT}"
|
||||||
|
|
||||||
|
# Fail the image build early if the CuRobo package layout regresses. Importing
|
||||||
|
# RoboTwin's planner here is too eager because CuRobo constructs CUDA-backed
|
||||||
|
# defaults at import time, while Docker builds don't have access to an NVIDIA
|
||||||
|
# driver.
|
||||||
|
RUN python - <<'EOF'
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from curobo.types.math import Pose
|
||||||
|
|
||||||
|
planner_src = (Path("/opt/robotwin/envs/robot/planner.py")).read_text()
|
||||||
|
assert "from curobo.types.math import Pose as CuroboPose" in planner_src
|
||||||
|
|
||||||
|
print("CuRobo import OK:", Pose.__name__)
|
||||||
|
print("RoboTwin planner import references curobo.types.math")
|
||||||
|
EOF
|
||||||
|
|
||||||
# Return to the lerobot source directory (set by base image) before overlaying.
|
# Return to the lerobot source directory (set by base image) before overlaying.
|
||||||
WORKDIR /lerobot
|
WORKDIR /lerobot
|
||||||
|
|||||||
@@ -18,9 +18,8 @@
|
|||||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||||
|
|
||||||
# Configure the base image for CI with GPU access
|
# Configure the base image for CI with GPU access
|
||||||
# TODO(Steven): Bump these versions
|
ARG CUDA_VERSION=12.8.1
|
||||||
ARG CUDA_VERSION=12.4.1
|
ARG OS_VERSION=24.04
|
||||||
ARG OS_VERSION=22.04
|
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||||
|
|
||||||
# Define Python version argument
|
# Define Python version argument
|
||||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|||||||
|
|
||||||
# Install Python, system dependencies, and uv (as root)
|
# Install Python, system dependencies, and uv (as root)
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
software-properties-common build-essential git curl \
|
build-essential git curl \
|
||||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||||
cmake pkg-config ninja-build \
|
cmake pkg-config ninja-build \
|
||||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
python${PYTHON_VERSION} \
|
||||||
&& apt-get update \
|
python${PYTHON_VERSION}-venv \
|
||||||
&& apt-get install -y --no-install-recommends \
|
python${PYTHON_VERSION}-dev \
|
||||||
python${PYTHON_VERSION} \
|
|
||||||
python${PYTHON_VERSION}-venv \
|
|
||||||
python${PYTHON_VERSION}-dev \
|
|
||||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
- local: il_robots
|
- local: il_robots
|
||||||
title: Imitation Learning for Robots
|
title: Imitation Learning for Robots
|
||||||
- local: bring_your_own_policies
|
- local: bring_your_own_policies
|
||||||
title: Bring Your Own Policies
|
title: Adding a Policy
|
||||||
- local: integrate_hardware
|
- local: integrate_hardware
|
||||||
title: Bring Your Own Hardware
|
title: Bring Your Own Hardware
|
||||||
- local: hilserl
|
- local: hilserl
|
||||||
@@ -24,6 +24,12 @@
|
|||||||
- local: rename_map
|
- local: rename_map
|
||||||
title: Using Rename Map and Empty Cameras
|
title: Using Rename Map and Empty Cameras
|
||||||
title: "Tutorials"
|
title: "Tutorials"
|
||||||
|
- sections:
|
||||||
|
- local: hardware_guide
|
||||||
|
title: Compute Hardware Guide
|
||||||
|
- local: torch_accelerators
|
||||||
|
title: PyTorch accelerators
|
||||||
|
title: "Compute & Hardware"
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
title: Using LeRobotDataset
|
title: Using LeRobotDataset
|
||||||
@@ -47,6 +53,10 @@
|
|||||||
title: π₀-FAST (Pi0Fast)
|
title: π₀-FAST (Pi0Fast)
|
||||||
- local: pi05
|
- local: pi05
|
||||||
title: π₀.₅ (Pi05)
|
title: π₀.₅ (Pi05)
|
||||||
|
- local: eo1
|
||||||
|
title: EO-1
|
||||||
|
- local: evo1
|
||||||
|
title: EVO1
|
||||||
- local: groot
|
- local: groot
|
||||||
title: NVIDIA GR00T N1.5
|
title: NVIDIA GR00T N1.5
|
||||||
- local: xvla
|
- local: xvla
|
||||||
@@ -61,6 +71,8 @@
|
|||||||
title: SARM
|
title: SARM
|
||||||
title: "Reward Models"
|
title: "Reward Models"
|
||||||
- sections:
|
- sections:
|
||||||
|
- local: inference
|
||||||
|
title: Policy Deployment (lerobot-rollout)
|
||||||
- local: async
|
- local: async
|
||||||
title: Use Async Inference
|
title: Use Async Inference
|
||||||
- local: rtc
|
- local: rtc
|
||||||
@@ -138,10 +150,6 @@
|
|||||||
- local: cameras
|
- local: cameras
|
||||||
title: Cameras
|
title: Cameras
|
||||||
title: "Sensors"
|
title: "Sensors"
|
||||||
- sections:
|
|
||||||
- local: torch_accelerators
|
|
||||||
title: PyTorch accelerators
|
|
||||||
title: "Supported Hardware"
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: notebooks
|
- local: notebooks
|
||||||
title: Notebooks
|
title: Notebooks
|
||||||
|
|||||||
@@ -1,60 +1,37 @@
|
|||||||
# Bring Your Own Policies
|
# Adding a Policy
|
||||||
|
|
||||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||||
|
|
||||||
## Step 1: Create a Policy Package
|
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||||
|
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||||
|
|
||||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||||
|
|
||||||
### Package Structure
|
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||||
|
|
||||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||||
|
|
||||||
```bash
|
---
|
||||||
lerobot_policy_my_custom_policy/
|
|
||||||
├── pyproject.toml
|
|
||||||
└── src/
|
|
||||||
└── lerobot_policy_my_custom_policy/
|
|
||||||
├── __init__.py
|
|
||||||
├── configuration_my_custom_policy.py
|
|
||||||
├── modeling_my_custom_policy.py
|
|
||||||
└── processor_my_custom_policy.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Package Configuration
|
## Anatomy of a policy
|
||||||
|
|
||||||
Set up your `pyproject.toml`:
|
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||||
|
|
||||||
```toml
|
### Configuration class
|
||||||
[project]
|
|
||||||
name = "lerobot_policy_my_custom_policy"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
# your policy-specific dependencies
|
|
||||||
]
|
|
||||||
requires-python = ">= 3.12"
|
|
||||||
|
|
||||||
[build-system]
|
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||||
build-backend = # your-build-backend
|
|
||||||
requires = # your-build-system
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 2: Define the Policy Configuration
|
|
||||||
|
|
||||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
|
||||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# configuration_my_custom_policy.py
|
# configuration_my_policy.py
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from lerobot.configs import PreTrainedConfig
|
from lerobot.configs import PreTrainedConfig
|
||||||
from lerobot.optim import AdamWConfig
|
from lerobot.optim import AdamWConfig
|
||||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
@PreTrainedConfig.register_subclass("my_policy")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
class MyPolicyConfig(PreTrainedConfig):
|
||||||
"""Configuration class for MyCustomPolicy.
|
"""Configuration class for MyPolicy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of observation steps to use as input
|
n_obs_steps: Number of observation steps to use as input
|
||||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
raise ValueError("n_action_steps cannot exceed horizon")
|
raise ValueError("n_action_steps cannot exceed horizon")
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate input/output feature compatibility."""
|
"""Validate input/output feature compatibility.
|
||||||
|
|
||||||
|
Call this explicitly from your policy's __init__ — the base class does not.
|
||||||
|
"""
|
||||||
if not self.image_features:
|
if not self.image_features:
|
||||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
raise ValueError("MyPolicy requires at least one image feature.")
|
||||||
if self.action_feature is None:
|
if self.action_feature is None:
|
||||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamWConfig:
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||||
|
|
||||||
def get_scheduler_preset(self):
|
def get_scheduler_preset(self):
|
||||||
|
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list[int]:
|
def action_delta_indices(self) -> list[int]:
|
||||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||||
"""
|
|
||||||
return list(range(self.horizon))
|
return list(range(self.horizon))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
return None
|
return None
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 3: Implement the Policy Class
|
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||||
|
|
||||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
### Policy class
|
||||||
|
|
||||||
|
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# modeling_my_custom_policy.py
|
# modeling_my_policy.py
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.policies import PreTrainedPolicy
|
from lerobot.policies import PreTrainedPolicy
|
||||||
from lerobot.utils.constants import ACTION
|
from lerobot.utils.constants import ACTION
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_policy import MyPolicyConfig
|
||||||
|
|
||||||
class MyCustomPolicy(PreTrainedPolicy):
|
class MyPolicy(PreTrainedPolicy):
|
||||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||||
name = "my_custom_policy"
|
name = "my_policy"
|
||||||
|
|
||||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
config.validate_features() # not called automatically by the base class
|
config.validate_features() # not called automatically by the base class
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = ... # your nn.Module here
|
self.model = ... # your nn.Module here
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset episode state."""
|
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||||
"""Return a single action for the current timestep (called at inference)."""
|
"""Return a single action for the current timestep (called every step at inference)."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||||
"""Compute the training loss.
|
"""Compute the training loss.
|
||||||
|
|
||||||
|
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||||
|
logging-friendly Python natives (no tensors with gradients).
|
||||||
|
|
||||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||||
timesteps padded because the episode ended before `horizon` steps, you
|
timesteps padded because the episode ended before `horizon` steps; you
|
||||||
can exclude those from your loss.
|
can exclude those from your loss.
|
||||||
"""
|
"""
|
||||||
actions = batch[ACTION]
|
actions = batch[ACTION]
|
||||||
action_is_pad = batch.get("action_is_pad")
|
action_is_pad = batch.get("action_is_pad")
|
||||||
...
|
...
|
||||||
return {"loss": ...}
|
return loss, {"some_loss_component": some_loss_component.item()}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 4: Add Data Processors
|
The methods called by the train/eval loops:
|
||||||
|
|
||||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
| Method | Used by | What it does |
|
||||||
|
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||||
|
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||||
|
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||||
|
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||||
|
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||||
|
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||||
|
|
||||||
|
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||||
|
|
||||||
|
### Processor functions
|
||||||
|
|
||||||
|
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# processor_my_custom_policy.py
|
# processor_my_policy.py
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
|
|
||||||
|
|
||||||
def make_my_custom_policy_pre_post_processors(
|
def make_my_policy_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
|||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||||
|
|
||||||
## Step 5: Package Initialization
|
---
|
||||||
|
|
||||||
Expose your classes in the package's `__init__.py`:
|
## Path A: Out-of-tree plugin
|
||||||
|
|
||||||
|
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||||
|
|
||||||
|
### Package structure
|
||||||
|
|
||||||
|
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot_policy_my_policy/
|
||||||
|
├── pyproject.toml
|
||||||
|
└── src/
|
||||||
|
└── lerobot_policy_my_policy/
|
||||||
|
├── __init__.py
|
||||||
|
├── configuration_my_policy.py
|
||||||
|
├── modeling_my_policy.py
|
||||||
|
└── processor_my_policy.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### `pyproject.toml`
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "lerobot_policy_my_policy"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
# your policy-specific dependencies
|
||||||
|
]
|
||||||
|
requires-python = ">= 3.12"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
build-backend = # your-build-backend
|
||||||
|
requires = # your-build-system
|
||||||
|
```
|
||||||
|
|
||||||
|
### Package `__init__.py`
|
||||||
|
|
||||||
|
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# __init__.py
|
# __init__.py
|
||||||
@@ -204,44 +239,148 @@ except ImportError:
|
|||||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||||
)
|
)
|
||||||
|
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_policy import MyPolicyConfig
|
||||||
from .modeling_my_custom_policy import MyCustomPolicy
|
from .modeling_my_policy import MyPolicy
|
||||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MyCustomPolicyConfig",
|
"MyPolicyConfig",
|
||||||
"MyCustomPolicy",
|
"MyPolicy",
|
||||||
"make_my_custom_policy_pre_post_processors",
|
"make_my_policy_pre_post_processors",
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 6: Installation and Usage
|
### Install and use
|
||||||
|
|
||||||
### Install Your Policy Package
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd lerobot_policy_my_custom_policy
|
cd lerobot_policy_my_policy
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
# Or install from PyPI if published
|
# Or install from PyPI if published
|
||||||
pip install lerobot_policy_my_custom_policy
|
pip install lerobot_policy_my_policy
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Your Policy
|
|
||||||
|
|
||||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--policy.type my_custom_policy \
|
--policy.type my_policy \
|
||||||
--env.type pusht \
|
--env.type pusht \
|
||||||
--steps 200000
|
--steps 200000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Examples and Community Contributions
|
---
|
||||||
|
|
||||||
|
## Path B: Contributing in-tree
|
||||||
|
|
||||||
|
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||||
|
|
||||||
|
### In-tree layout
|
||||||
|
|
||||||
|
```
|
||||||
|
src/lerobot/policies/my_policy/
|
||||||
|
├── __init__.py # re-exports config + modeling + processor factory
|
||||||
|
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||||
|
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||||
|
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||||
|
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
Two notes:
|
||||||
|
|
||||||
|
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||||
|
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||||
|
|
||||||
|
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||||
|
|
||||||
|
### Wiring
|
||||||
|
|
||||||
|
Three places need to know about your policy. All by name.
|
||||||
|
|
||||||
|
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||||
|
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||||
|
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||||
|
|
||||||
|
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||||
|
|
||||||
|
### Heavy / optional dependencies
|
||||||
|
|
||||||
|
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _diffusers_available:
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
else:
|
||||||
|
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||||
|
|
||||||
|
class DiffusionPolicy(PreTrainedPolicy):
|
||||||
|
def __init__(self, config):
|
||||||
|
require_package("diffusers", extra="diffusion")
|
||||||
|
super().__init__(config)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This way:
|
||||||
|
|
||||||
|
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||||
|
- Type checkers see the real symbol.
|
||||||
|
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||||
|
|
||||||
|
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||||
|
|
||||||
|
### Benchmarks and a published checkpoint
|
||||||
|
|
||||||
|
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||||
|
|
||||||
|
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||||
|
|
||||||
|
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||||
|
|
||||||
|
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Results
|
||||||
|
|
||||||
|
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||||
|
|
||||||
|
| Suite | Success rate | n_episodes |
|
||||||
|
| -------------- | -----------: | ---------: |
|
||||||
|
| libero_spatial | 87.5% | 50 |
|
||||||
|
| libero_object | 93.0% | 50 |
|
||||||
|
| libero_goal | 81.5% | 50 |
|
||||||
|
| libero_10 | 62.0% | 50 |
|
||||||
|
| **average** | **81.0%** | 200 |
|
||||||
|
|
||||||
|
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||||
|
|
||||||
|
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||||
|
|
||||||
|
### PR checklist
|
||||||
|
|
||||||
|
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||||
|
|
||||||
|
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||||
|
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||||
|
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||||
|
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||||
|
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||||
|
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||||
|
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||||
|
|
||||||
|
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Examples and community contributions
|
||||||
|
|
||||||
Check out these example policy implementations:
|
Check out these example policy implementations:
|
||||||
|
|
||||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||||
|
|
||||||
Share your policy implementations with the community! 🤗
|
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
# EO-1
|
||||||
|
|
||||||
|
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||||
|
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||||
|
alt="An overview of EO-1"
|
||||||
|
width="85%"
|
||||||
|
/>
|
||||||
|
|
||||||
|
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||||
|
|
||||||
|
### What the LeRobot Integration Covers
|
||||||
|
|
||||||
|
- Standard `policy.type=eo1` configuration through LeRobot
|
||||||
|
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||||
|
- Continuous flow-matching action prediction
|
||||||
|
- Checkpoint save/load through LeRobot policy APIs
|
||||||
|
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||||
|
|
||||||
|
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||||
|
2. Install EO-1 dependencies by running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[eo1]"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[eo1,libero]"
|
||||||
|
```
|
||||||
|
|
||||||
|
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||||
|
|
||||||
|
## Data Requirements
|
||||||
|
|
||||||
|
EO-1 expects a LeRobot dataset with:
|
||||||
|
|
||||||
|
- At least one visual observation, for example `observation.images.image`
|
||||||
|
- `observation.state`
|
||||||
|
- `action`
|
||||||
|
- A language task instruction through the dataset `task` field
|
||||||
|
|
||||||
|
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.type=eo1
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, a new EO-1 policy initializes its backbone from:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.path=your-org/your-eo1-checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
### Training Command Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=your_org/your_dataset \
|
||||||
|
--policy.type=eo1 \
|
||||||
|
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.attn_implementation=sdpa \
|
||||||
|
--policy.gradient_checkpointing=false \
|
||||||
|
--output_dir=./outputs/eo1_training \
|
||||||
|
--job_name=eo1_training \
|
||||||
|
--steps=300000 \
|
||||||
|
--batch_size=16 \
|
||||||
|
--policy.device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Training Parameters
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||||
|
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||||
|
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||||
|
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||||
|
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||||
|
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||||
|
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||||
|
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||||
|
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||||
|
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||||
|
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||||
|
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||||
|
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=your-org/your-eo1-checkpoint \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=20
|
||||||
|
```
|
||||||
|
|
||||||
|
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=your-org/your-eo1-checkpoint \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Notes
|
||||||
|
|
||||||
|
### Image Processing
|
||||||
|
|
||||||
|
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||||
|
|
||||||
|
### State and Action Dimensions
|
||||||
|
|
||||||
|
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||||
|
|
||||||
|
### Attention Backend
|
||||||
|
|
||||||
|
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||||
|
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||||
|
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{eo1,
|
||||||
|
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||||
|
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||||
|
journal={arXiv preprint},
|
||||||
|
year={2025},
|
||||||
|
url={https://arxiv.org/abs/2508.21112}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
# EVO1
|
||||||
|
|
||||||
|
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
|
||||||
|
|
||||||
|
### What the LeRobot Integration Covers
|
||||||
|
|
||||||
|
- Standard `policy.type=evo1` configuration through LeRobot
|
||||||
|
- InternVL3 image/text embedding with optional FlashAttention fallback
|
||||||
|
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
|
||||||
|
- Continuous flow-matching action prediction
|
||||||
|
- Checkpoint save/load through LeRobot policy APIs
|
||||||
|
- Training with `lerobot-train` and evaluation with standard policy inference APIs
|
||||||
|
|
||||||
|
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||||
|
2. Install EVO1 dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[evo1]"
|
||||||
|
```
|
||||||
|
|
||||||
|
For LIBERO evaluation, install the LIBERO extra as well:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[evo1,libero]"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
|
||||||
|
|
||||||
|
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||||
|
|
||||||
|
## Data Requirements
|
||||||
|
|
||||||
|
EVO1 expects a LeRobot dataset with:
|
||||||
|
|
||||||
|
- One to `policy.max_views` visual observations, for example `observation.images.image`
|
||||||
|
- `observation.state`
|
||||||
|
- `action`
|
||||||
|
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
|
||||||
|
|
||||||
|
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use EVO1 in a LeRobot configuration, specify:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.type=evo1
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, a new EVO1 policy initializes its VLM from:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.vlm_model_name=OpenGVLab/InternVL3-1B
|
||||||
|
```
|
||||||
|
|
||||||
|
Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.path=your-org/your-evo1-checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
The converted LIBERO checkpoint used for this PR is available at:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.path=javadcc/evo1-libero-lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
### Stage 1
|
||||||
|
|
||||||
|
Stage 1 freezes the VLM and trains the action head:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=your_org/your_dataset \
|
||||||
|
--policy.type=evo1 \
|
||||||
|
--policy.training_stage=stage1 \
|
||||||
|
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.chunk_size=50 \
|
||||||
|
--policy.n_action_steps=50 \
|
||||||
|
--policy.max_state_dim=24 \
|
||||||
|
--policy.max_action_dim=24 \
|
||||||
|
--policy.optimizer_lr=1e-5 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--steps=5000 \
|
||||||
|
--output_dir=./outputs/evo1_stage1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Stage 2
|
||||||
|
|
||||||
|
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=your_org/your_dataset \
|
||||||
|
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
|
||||||
|
--policy.training_stage=stage2 \
|
||||||
|
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.chunk_size=50 \
|
||||||
|
--policy.n_action_steps=50 \
|
||||||
|
--policy.max_state_dim=24 \
|
||||||
|
--policy.max_action_dim=24 \
|
||||||
|
--policy.optimizer_lr=1e-5 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--steps=80000 \
|
||||||
|
--output_dir=./outputs/evo1_stage2
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
|
||||||
|
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
|
||||||
|
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
|
||||||
|
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
|
||||||
|
every finetuning flag.
|
||||||
|
|
||||||
|
### Key Training Parameters
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
|
||||||
|
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
|
||||||
|
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||||
|
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
|
||||||
|
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||||
|
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||||
|
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||||
|
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
|
||||||
|
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
|
||||||
|
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
|
||||||
|
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||||
|
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||||
|
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||||
|
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
### LIBERO Object Checkpoint Conversion
|
||||||
|
|
||||||
|
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
|
||||||
|
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
|
||||||
|
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
|
||||||
|
|
||||||
|
| Checkpoint | Suite | Episodes | Success Rate |
|
||||||
|
| ---------------------------- | --------------- | ---------------- | ------------ |
|
||||||
|
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
|
||||||
|
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
|
||||||
|
|
||||||
|
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
|
||||||
|
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
|
||||||
|
(`max_abs_diff=0.0`).
|
||||||
|
|
||||||
|
The published checkpoint expects the raw LIBERO camera feature names
|
||||||
|
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
|
||||||
|
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
|
||||||
|
instead of the default `image`/`image2` mapping:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=javadcc/evo1-libero-lerobot \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||||
|
--env.observation_height=448 \
|
||||||
|
--env.observation_width=448 \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=1
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||||
|
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# Compute HW Guide for LeRobot Training
|
||||||
|
|
||||||
|
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||||
|
|
||||||
|
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||||
|
|
||||||
|
## Memory by policy group
|
||||||
|
|
||||||
|
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||||
|
|
||||||
|
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||||
|
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||||
|
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||||
|
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||||
|
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||||
|
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||||
|
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||||
|
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||||
|
|
||||||
|
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||||
|
|
||||||
|
## Training time
|
||||||
|
|
||||||
|
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||||
|
|
||||||
|
```text
|
||||||
|
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||||
|
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||||
|
total_steps = epochs × steps_per_epoch
|
||||||
|
wall_clock ≈ total_steps × per_step_time
|
||||||
|
```
|
||||||
|
|
||||||
|
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||||
|
|
||||||
|
### Common scenarios
|
||||||
|
|
||||||
|
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||||
|
|
||||||
|
| Setup | Policy | Batch | Wall-clock |
|
||||||
|
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||||
|
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||||
|
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||||
|
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||||
|
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||||
|
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||||
|
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||||
|
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||||
|
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||||
|
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||||
|
|
||||||
|
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||||
|
|
||||||
|
### Multi-GPU matters a lot
|
||||||
|
|
||||||
|
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||||
|
|
||||||
|
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||||
|
|
||||||
|
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||||
|
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||||
|
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||||
|
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||||
|
|
||||||
|
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||||
|
|
||||||
|
### Schedule and checkpoints
|
||||||
|
|
||||||
|
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||||
|
|
||||||
|
## Where to run
|
||||||
|
|
||||||
|
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||||
|
|
||||||
|
| Class | VRAM | Tier | Comfortable for |
|
||||||
|
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||||
|
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||||
|
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||||
|
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||||
|
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||||
|
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||||
|
|
||||||
|
### Hugging Face Jobs
|
||||||
|
|
||||||
|
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||||
|
bash -c "nvidia-smi && lerobot-train \
|
||||||
|
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||||
|
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||||
|
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||||
|
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||||
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
|
|||||||
|
|
||||||
### Teleoperator Requirements
|
### Teleoperator Requirements
|
||||||
|
|
||||||
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
|
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
|
||||||
|
|
||||||
- Enable/disable torque programmatically
|
- Enable/disable torque programmatically
|
||||||
- Move to target positions (to mirror the robot state when pausing)
|
- Move to target positions (to mirror the robot state when pausing)
|
||||||
|
|
||||||
**Compatible teleoperators in the current `examples/hil` scripts:**
|
**Compatible teleoperators:**
|
||||||
|
|
||||||
- `openarm_mini` - OpenArm Mini
|
- `openarm_mini` - OpenArm Mini
|
||||||
- `so_leader` - SO100 / SO101 leader arm
|
- `so_leader` - SO100 / SO101 leader arm
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
|
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Script
|
## Script
|
||||||
|
|
||||||
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
|
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||||
|
|
||||||
| Mode | Flag | Models |
|
| Mode | Flag | Models |
|
||||||
| ------------------------ | -------------------- | --------------------- |
|
| ------------------------ | ---------------------- | --------------------- |
|
||||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
**Standard inference (ACT, Diffusion Policy):**
|
**Standard inference (ACT, Diffusion Policy):**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/hil/hil_data_collection.py \
|
lerobot-rollout --strategy.type=dagger \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--robot.left_arm_config.port=can1 \
|
--robot.left_arm_config.port=can1 \
|
||||||
--robot.left_arm_config.side=left \
|
--robot.left_arm_config.side=left \
|
||||||
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
|
|||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/hil-dataset \
|
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
--dataset.fps=30 \
|
--dataset.fps=30 \
|
||||||
--dataset.episode_time_s=1000 \
|
--strategy.num_episodes=50 \
|
||||||
--dataset.num_episodes=50 \
|
|
||||||
--interpolation_multiplier=2
|
--interpolation_multiplier=2
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
|
|||||||
For models with high inference latency, enable RTC for smooth execution:
|
For models with high inference latency, enable RTC for smooth execution:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/hil/hil_data_collection.py \
|
lerobot-rollout --strategy.type=dagger \
|
||||||
--rtc.enabled=true \
|
--inference.type=rtc \
|
||||||
--rtc.execution_horizon=20 \
|
--inference.rtc.execution_horizon=20 \
|
||||||
--rtc.max_guidance_weight=5.0 \
|
--inference.rtc.max_guidance_weight=5.0 \
|
||||||
--rtc.prefix_attention_schedule=LINEAR \
|
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--robot.left_arm_config.port=can1 \
|
--robot.left_arm_config.port=can1 \
|
||||||
--robot.left_arm_config.side=left \
|
--robot.left_arm_config.side=left \
|
||||||
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
|
|||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
--dataset.fps=30 \
|
--dataset.fps=30 \
|
||||||
--dataset.episode_time_s=1000 \
|
--strategy.num_episodes=50 \
|
||||||
--dataset.num_episodes=50 \
|
|
||||||
--interpolation_multiplier=3
|
--interpolation_multiplier=3
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
|
|||||||
|
|
||||||
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
||||||
|
|
||||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
|
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
|
||||||
|
|
||||||
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
||||||
|
|
||||||
|
|||||||
+26
-105
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
|||||||
|
|
||||||
## Run inference and evaluate your policy
|
## Run inference and evaluate your policy
|
||||||
|
|
||||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||||
|
|
||||||
<hfoptions id="eval">
|
<hfoptions id="eval">
|
||||||
<hfoption id="Command">
|
<hfoption id="Base mode (no recording)">
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
--robot.type=so100_follower \
|
--robot.type=so100_follower \
|
||||||
--robot.port=/dev/ttyACM1 \
|
--robot.port=/dev/ttyACM1 \
|
||||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||||
--robot.id=my_awesome_follower_arm \
|
--task="Put lego brick into the transparent box" \
|
||||||
--display_data=false \
|
--duration=60
|
||||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
|
||||||
--dataset.single_task="Put lego brick into the transparent box" \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
--dataset.encoder_threads=2 \
|
|
||||||
# --dataset.vcodec=auto \
|
|
||||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
|
||||||
# --teleop.type=so100_leader \
|
|
||||||
# --teleop.port=/dev/ttyACM0 \
|
|
||||||
# --teleop.id=my_awesome_leader_arm \
|
|
||||||
--policy.path=${HF_USER}/my_policy
|
|
||||||
```
|
```
|
||||||
</hfoption>
|
</hfoption>
|
||||||
<hfoption id="API example">
|
<hfoption id="Sentry mode (with recording)">
|
||||||
|
```bash
|
||||||
<!-- prettier-ignore-start -->
|
lerobot-rollout \
|
||||||
```python
|
--strategy.type=sentry \
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
--strategy.upload_every_n_episodes=5 \
|
||||||
from lerobot.datasets import LeRobotDataset
|
--policy.path=${HF_USER}/my_policy \
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
--robot.type=so100_follower \
|
||||||
from lerobot.policies.act import ACTPolicy
|
--robot.port=/dev/ttyACM1 \
|
||||||
from lerobot.policies import make_pre_post_processors
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
--dataset.single_task="Put lego brick into the transparent box" \
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
--duration=600
|
||||||
from lerobot.utils.utils import log_say
|
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
|
||||||
|
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
|
||||||
FPS = 30
|
|
||||||
EPISODE_TIME_SEC = 60
|
|
||||||
TASK_DESCRIPTION = "My task description"
|
|
||||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
|
||||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
|
||||||
|
|
||||||
# Create the robot configuration
|
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
|
||||||
robot_config = SO100FollowerConfig(
|
|
||||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the robot
|
|
||||||
robot = SO100Follower(robot_config)
|
|
||||||
|
|
||||||
# Initialize the policy
|
|
||||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
|
||||||
|
|
||||||
# Configure the dataset features
|
|
||||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
|
||||||
dataset_features = {**action_features, **obs_features}
|
|
||||||
|
|
||||||
# Create the dataset
|
|
||||||
dataset = LeRobotDataset.create(
|
|
||||||
repo_id=HF_DATASET_ID,
|
|
||||||
fps=FPS,
|
|
||||||
features=dataset_features,
|
|
||||||
robot_type=robot.name,
|
|
||||||
use_videos=True,
|
|
||||||
image_writer_threads=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
|
||||||
_, events = init_keyboard_listener()
|
|
||||||
init_rerun(session_name="recording")
|
|
||||||
|
|
||||||
# Connect the robot
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
|
||||||
policy_cfg=policy,
|
|
||||||
pretrained_path=HF_MODEL_ID,
|
|
||||||
dataset_stats=dataset.meta.stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
for episode_idx in range(NUM_EPISODES):
|
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
|
||||||
|
|
||||||
# Run the policy inference loop
|
|
||||||
record_loop(
|
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
dataset=dataset,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
robot.disconnect()
|
|
||||||
dataset.push_to_hub()
|
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
|
||||||
|
|
||||||
</hfoption>
|
</hfoption>
|
||||||
</hfoptions>
|
</hfoptions>
|
||||||
|
|
||||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
The `--strategy.type` flag selects the execution mode:
|
||||||
|
|
||||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
|
||||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||||
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
|
||||||
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
@@ -0,0 +1,261 @@
|
|||||||
|
# Policy Deployment (lerobot-rollout)
|
||||||
|
|
||||||
|
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
No extra dependencies are needed beyond your robot and policy extras.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=lerobot/act_koch_real \
|
||||||
|
--robot.type=koch_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--task="pick up cube" \
|
||||||
|
--duration=30
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs the policy for 30 seconds with no recording.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Strategies
|
||||||
|
|
||||||
|
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
|
||||||
|
|
||||||
|
### Base (`--strategy.type=base`)
|
||||||
|
|
||||||
|
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--task="Put lego brick into the box" \
|
||||||
|
--duration=60
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ---------------- | ------------------------------------------------------ |
|
||||||
|
| `--duration` | Run time in seconds (0 = infinite) |
|
||||||
|
| `--task` | Task description passed to the policy |
|
||||||
|
| `--display_data` | Stream observations/actions to Rerun for visualization |
|
||||||
|
|
||||||
|
### Sentry (`--strategy.type=sentry`)
|
||||||
|
|
||||||
|
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
|
||||||
|
|
||||||
|
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=sentry \
|
||||||
|
--strategy.upload_every_n_episodes=5 \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_eval_data \
|
||||||
|
--dataset.single_task="Put lego brick into the box" \
|
||||||
|
--duration=3600
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| -------------------------------------- | ----------------------------------------------------------- |
|
||||||
|
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||||
|
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
|
||||||
|
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
|
||||||
|
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
|
||||||
|
|
||||||
|
### Highlight (`--strategy.type=highlight`)
|
||||||
|
|
||||||
|
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=highlight \
|
||||||
|
--strategy.ring_buffer_seconds=30 \
|
||||||
|
--strategy.save_key=s \
|
||||||
|
--strategy.push_key=h \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=koch_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
|
||||||
|
--dataset.single_task="Pick up the red cube"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Keyboard controls:**
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ------------------ | -------------------------------------------------------- |
|
||||||
|
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
|
||||||
|
| `h` (configurable) | Push dataset to Hub |
|
||||||
|
| `ESC` | Stop the session |
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| -------------------------------------- | ---------------------------------------------- |
|
||||||
|
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
|
||||||
|
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
|
||||||
|
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
|
||||||
|
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
|
||||||
|
|
||||||
|
### DAgger (`--strategy.type=dagger`)
|
||||||
|
|
||||||
|
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
|
||||||
|
|
||||||
|
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
|
||||||
|
|
||||||
|
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=dagger \
|
||||||
|
--strategy.num_episodes=20 \
|
||||||
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
|
--robot.type=bi_openarm_follower \
|
||||||
|
--teleop.type=openarm_mini \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||||
|
--dataset.single_task="Fold the T-shirt"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=dagger \
|
||||||
|
--strategy.record_autonomous=true \
|
||||||
|
--strategy.num_episodes=50 \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--teleop.type=so101_leader \
|
||||||
|
--teleop.port=/dev/ttyACM1 \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
|
||||||
|
--dataset.single_task="Grasp the block"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Keyboard controls** (default input device):
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ------- | ------------------------------------------- |
|
||||||
|
| `Space` | Pause / resume policy execution |
|
||||||
|
| `Tab` | Start / stop human correction |
|
||||||
|
| `Enter` | Push dataset to Hub (corrections-only mode) |
|
||||||
|
| `ESC` | Stop the session |
|
||||||
|
|
||||||
|
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ------------------------------------ | ------------------------------------------------------- |
|
||||||
|
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
|
||||||
|
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
|
||||||
|
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||||
|
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||||
|
| `--teleop.type` | **Required.** Teleoperator type |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Inference Backends
|
||||||
|
|
||||||
|
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
|
||||||
|
|
||||||
|
### Sync (default)
|
||||||
|
|
||||||
|
One policy call per control tick. The main loop blocks until the action is computed.
|
||||||
|
|
||||||
|
Works with all policies. No extra flags needed.
|
||||||
|
|
||||||
|
### Real-Time Chunking (`--inference.type=rtc`)
|
||||||
|
|
||||||
|
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
|
||||||
|
|
||||||
|
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--inference.type=rtc \
|
||||||
|
--inference.rtc.execution_horizon=10 \
|
||||||
|
--inference.rtc.max_guidance_weight=10.0 \
|
||||||
|
--policy.path=${HF_USER}/pi0_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--task="Pick up the cube" \
|
||||||
|
--duration=60 \
|
||||||
|
--device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ------------------------------------------- | -------------------------------------------------------------- |
|
||||||
|
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
|
||||||
|
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
|
||||||
|
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
|
||||||
|
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
|
||||||
|
|
||||||
|
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Flags
|
||||||
|
|
||||||
|
| Flag | Description | Default |
|
||||||
|
| --------------------------------- | ----------------------------------------------------------------- | ------- |
|
||||||
|
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
|
||||||
|
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
|
||||||
|
| `--robot.port` | Serial port for the robot | -- |
|
||||||
|
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
|
||||||
|
| `--fps` | Control loop frequency | 30 |
|
||||||
|
| `--duration` | Run time in seconds (0 = infinite) | 0 |
|
||||||
|
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
|
||||||
|
| `--task` | Task description (used when no dataset is provided) | -- |
|
||||||
|
| `--display_data` | Stream telemetry to Rerun visualization | false |
|
||||||
|
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
|
||||||
|
| `--interpolation_multiplier` | Action interpolation factor | 1 |
|
||||||
|
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
|
||||||
|
| `--resume` | Resume a previous recording session | false |
|
||||||
|
| `--play_sounds` | Vocal synthesis for events | true |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Programmatic Usage
|
||||||
|
|
||||||
|
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=my_robot_config,
|
||||||
|
policy=my_policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=30,
|
||||||
|
duration=60,
|
||||||
|
task="my task",
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=my_custom_action_processor, # optional
|
||||||
|
robot_observation_processor=my_custom_obs_processor, # optional
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
|
||||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
|||||||
|
|
||||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||||
|
|
||||||
|
### PyTorch CUDA variant (Linux only)
|
||||||
|
|
||||||
|
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
|
||||||
|
<hfoptions id="cuda_variant">
|
||||||
|
<hfoption id="uv-source">
|
||||||
|
|
||||||
|
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||||
|
|
||||||
|
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||||
|
|
||||||
|
To override for a different CUDA variant:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install --force-reinstall torch torchvision \
|
||||||
|
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="pip-conda">
|
||||||
|
|
||||||
|
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||||
|
|
||||||
|
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||||
|
|
||||||
|
To pick a specific CUDA variant:
|
||||||
|
|
||||||
|
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||||
|
pip install -e ".[all]" # source
|
||||||
|
# — or —
|
||||||
|
pip install lerobot # from PyPI
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install --torch-backend cu128 lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
# EVO1
|
||||||
|
|
||||||
|
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
|
||||||
|
integration uses an InternVL3 vision-language backbone with a flow-matching
|
||||||
|
action head, and supports staged training through the standard LeRobot policy
|
||||||
|
APIs.
|
||||||
|
|
||||||
|
The upstream EVO1 project is available at
|
||||||
|
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{evo1,
|
||||||
|
title = {EVO1},
|
||||||
|
author = {{MINT-SJTU}},
|
||||||
|
year = {2026},
|
||||||
|
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -61,17 +61,6 @@ lerobot-eval \
|
|||||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Recording
|
|
||||||
|
|
||||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-record \ # When running inference
|
|
||||||
--policy.path="<user>/smolVLA_finetuned" \
|
|
||||||
... \
|
|
||||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
|
||||||
```
|
|
||||||
|
|
||||||
## Alternative: edit the policy config directly
|
## Alternative: edit the policy config directly
|
||||||
|
|
||||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||||
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
|
|||||||
|
|
||||||
## Quick reference
|
## Quick reference
|
||||||
|
|
||||||
| Goal | What to do |
|
| Goal | What to do |
|
||||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
| --------------------------------------- | --------------------------------------------------------------------------- |
|
||||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
|
||||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||||
|
|||||||
+7
-3
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
|||||||
|
|
||||||
### Using RTC with Pi0
|
### Using RTC with Pi0
|
||||||
|
|
||||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
|
|||||||
## Testing RTC with a Real Robot
|
## Testing RTC with a Real Robot
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||||
|
--inference.type=rtc \
|
||||||
|
--inference.rtc.execution_horizon=10 \
|
||||||
|
--inference.rtc.max_guidance_weight=10.0 \
|
||||||
--robot.type=so100_follower \
|
--robot.type=so100_follower \
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
|
|||||||
# ... create plots
|
# ... create plots
|
||||||
```
|
```
|
||||||
|
|
||||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
|
|||||||
+29
-28
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
|||||||
|
|
||||||
## Inputs and Targets (What the new code expects)
|
## Inputs and Targets (What the new code expects)
|
||||||
|
|
||||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||||
|
|
||||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
|||||||
<hfoption id="single_stage">
|
<hfoption id="single_stage">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
|||||||
<hfoption id="dense_only">
|
<hfoption id="dense_only">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
|||||||
<hfoption id="dual">
|
<hfoption id="dual">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
|||||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--head-mode sparse \
|
--head-mode sparse \
|
||||||
@@ -465,15 +465,15 @@ This script:
|
|||||||
|
|
||||||
### Step 5b: Train Policy with RA-BC
|
### Step 5b: Train Policy with RA-BC
|
||||||
|
|
||||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--dataset.repo_id=your-username/your-dataset \
|
--dataset.repo_id=your-username/your-dataset \
|
||||||
--policy.type=pi0 \
|
--policy.type=pi0 \
|
||||||
--use_rabc=true \
|
--sample_weighting.type=rabc \
|
||||||
--rabc_head_mode=sparse \
|
--sample_weighting.head_mode=sparse \
|
||||||
--rabc_kappa=0.01 \
|
--sample_weighting.kappa=0.01 \
|
||||||
--output_dir=outputs/train/policy_rabc \
|
--output_dir=outputs/train/policy_rabc \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--steps=40000
|
--steps=40000
|
||||||
@@ -488,12 +488,13 @@ The training script automatically:
|
|||||||
|
|
||||||
**RA-BC Arguments:**
|
**RA-BC Arguments:**
|
||||||
|
|
||||||
| Argument | Description | Default |
|
| Argument | Description | Default |
|
||||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||||
|
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||||
|
|
||||||
### Tuning RA-BC Kappa
|
### Tuning RA-BC Kappa
|
||||||
|
|
||||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
|||||||
|
|
||||||
Monitor these WandB metrics during training:
|
Monitor these WandB metrics during training:
|
||||||
|
|
||||||
| Metric | Healthy Range | Problem Indicator |
|
| Metric | Healthy Range | Problem Indicator |
|
||||||
| ------------------ | ------------- | ------------------------- |
|
| ----------------------------- | ------------- | ------------------------- |
|
||||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||||
|
|
||||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||||
|
|
||||||
**Setting kappa based on your data:**
|
**Setting kappa based on your data:**
|
||||||
|
|
||||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||||
|
|
||||||
```
|
```
|
||||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||||
# Most deltas fall in range [0.01, 0.05]
|
# Most deltas fall in range [0.01, 0.05]
|
||||||
|
|
||||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||||
--rabc_kappa=0.03
|
--sample_weighting.kappa=0.03
|
||||||
|
|
||||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||||
--rabc_kappa=0.05
|
--sample_weighting.kappa=0.05
|
||||||
|
|
||||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||||
--rabc_kappa=0.07
|
--sample_weighting.kappa=0.07
|
||||||
```
|
```
|
||||||
|
|
||||||
**When RA-BC may not help:**
|
**When RA-BC may not help:**
|
||||||
@@ -550,8 +551,8 @@ accelerate launch \
|
|||||||
src/lerobot/scripts/lerobot_train.py \
|
src/lerobot/scripts/lerobot_train.py \
|
||||||
--dataset.repo_id=your-username/your-dataset \
|
--dataset.repo_id=your-username/your-dataset \
|
||||||
--policy.type=pi0 \
|
--policy.type=pi0 \
|
||||||
--use_rabc=true \
|
--sample_weighting.type=rabc \
|
||||||
--rabc_kappa=0.01 \
|
--sample_weighting.kappa=0.01 \
|
||||||
--output_dir=outputs/train/policy_rabc \
|
--output_dir=outputs/train/policy_rabc \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--steps=40000
|
--steps=40000
|
||||||
@@ -576,7 +577,7 @@ accelerate launch \
|
|||||||
### RA-BC
|
### RA-BC
|
||||||
|
|
||||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
Once trained, we recommend deploying policies using inference-time RTC:
|
Once trained, we recommend deploying policies using inference-time RTC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
--policy.path=your-username/your-repo-id \
|
--policy.path=your-username/your-repo-id \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--robot.type=unitree_g1 \
|
--robot.type=unitree_g1 \
|
||||||
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
|
|||||||
--task="task_description" \
|
--task="task_description" \
|
||||||
--duration=1000 \
|
--duration=1000 \
|
||||||
--fps=30 \
|
--fps=30 \
|
||||||
--rtc.enabled=true
|
--inference.type=rtc
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ REAL_DIM = 12
|
|||||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||||
```
|
```
|
||||||
|
|
||||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||||
|
|
||||||
#### Auto Action Mode (Recommended)
|
#### Auto Action Mode (Recommended)
|
||||||
|
|
||||||
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
|
|||||||
|
|
||||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
|
||||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
|
||||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||||
generate_all_frame_indices,
|
generate_all_frame_indices,
|
||||||
interpolate_progress,
|
interpolate_progress,
|
||||||
load_sarm_resources,
|
load_sarm_resources,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,226 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from lerobot.common.control_utils import is_headless
|
|
||||||
from lerobot.processor import (
|
|
||||||
IdentityProcessorStep,
|
|
||||||
RobotAction,
|
|
||||||
RobotObservation,
|
|
||||||
RobotProcessorPipeline,
|
|
||||||
observation_to_transition,
|
|
||||||
robot_action_observation_to_transition,
|
|
||||||
transition_to_observation,
|
|
||||||
transition_to_robot_action,
|
|
||||||
)
|
|
||||||
from lerobot.robots import Robot
|
|
||||||
from lerobot.teleoperators import Teleoperator
|
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HILDatasetConfig:
|
|
||||||
repo_id: str
|
|
||||||
single_task: str
|
|
||||||
root: str | Path | None = None
|
|
||||||
fps: int = 30
|
|
||||||
episode_time_s: float = 120
|
|
||||||
num_episodes: int = 50
|
|
||||||
video: bool = True
|
|
||||||
push_to_hub: bool = True
|
|
||||||
private: bool = False
|
|
||||||
tags: list[str] | None = None
|
|
||||||
num_image_writer_processes: int = 0
|
|
||||||
num_image_writer_threads_per_camera: int = 4
|
|
||||||
video_encoding_batch_size: int = 1
|
|
||||||
vcodec: str = "auto"
|
|
||||||
streaming_encoding: bool = True
|
|
||||||
encoder_queue_maxsize: int = 30
|
|
||||||
encoder_threads: int | None = None
|
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
|
||||||
"""Check if teleoperator has motor control capabilities."""
|
|
||||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
|
||||||
"""Disable teleop torque if supported."""
|
|
||||||
if hasattr(teleop, "disable_torque"):
|
|
||||||
teleop.disable_torque()
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
|
||||||
"""Enable teleop torque if supported."""
|
|
||||||
if hasattr(teleop, "enable_torque"):
|
|
||||||
teleop.enable_torque()
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
|
||||||
"""Smoothly move teleop to target position if motor control is available."""
|
|
||||||
if not teleop_has_motor_control(teleop):
|
|
||||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
|
||||||
return
|
|
||||||
|
|
||||||
teleop_enable_torque(teleop)
|
|
||||||
current = teleop.get_action()
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {}
|
|
||||||
for k in current:
|
|
||||||
if k in target_pos:
|
|
||||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
|
||||||
else:
|
|
||||||
interp[k] = current[k]
|
|
||||||
teleop.write_goal_positions(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener():
|
|
||||||
"""Initialize keyboard listener with HIL controls."""
|
|
||||||
events = {
|
|
||||||
"exit_early": False,
|
|
||||||
"rerecord_episode": False,
|
|
||||||
"stop_recording": False,
|
|
||||||
"policy_paused": False,
|
|
||||||
"correction_active": False,
|
|
||||||
"resume_policy": False,
|
|
||||||
"in_reset": False,
|
|
||||||
"start_next_episode": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_headless():
|
|
||||||
logger.warning("Headless environment - keyboard controls unavailable")
|
|
||||||
return None, events
|
|
||||||
|
|
||||||
from pynput import keyboard
|
|
||||||
|
|
||||||
def on_press(key):
|
|
||||||
try:
|
|
||||||
if events["in_reset"]:
|
|
||||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
|
||||||
logger.info("[HIL] Starting next episode...")
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "c":
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif key == keyboard.Key.esc:
|
|
||||||
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
|
|
||||||
events["stop_recording"] = True
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
else:
|
|
||||||
if key == keyboard.Key.space:
|
|
||||||
if not events["policy_paused"] and not events["correction_active"]:
|
|
||||||
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
|
|
||||||
events["policy_paused"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "c":
|
|
||||||
if events["policy_paused"] and not events["correction_active"]:
|
|
||||||
logger.info("[HIL] Taking control...")
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "p":
|
|
||||||
if events["policy_paused"] or events["correction_active"]:
|
|
||||||
logger.info("[HIL] Resuming policy...")
|
|
||||||
events["resume_policy"] = True
|
|
||||||
elif key == keyboard.Key.right:
|
|
||||||
logger.info("[HIL] End episode")
|
|
||||||
events["exit_early"] = True
|
|
||||||
elif key == keyboard.Key.left:
|
|
||||||
logger.info("[HIL] Re-record episode")
|
|
||||||
events["rerecord_episode"] = True
|
|
||||||
events["exit_early"] = True
|
|
||||||
elif key == keyboard.Key.esc:
|
|
||||||
logger.info("[HIL] ESC - Stop recording...")
|
|
||||||
events["stop_recording"] = True
|
|
||||||
events["exit_early"] = True
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Key error: {e}")
|
|
||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
|
||||||
listener.start()
|
|
||||||
return listener, events
|
|
||||||
|
|
||||||
|
|
||||||
def make_identity_processors():
|
|
||||||
"""Create identity processors for recording."""
|
|
||||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
|
||||||
steps=[IdentityProcessorStep()],
|
|
||||||
to_transition=robot_action_observation_to_transition,
|
|
||||||
to_output=transition_to_robot_action,
|
|
||||||
)
|
|
||||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
|
||||||
steps=[IdentityProcessorStep()],
|
|
||||||
to_transition=observation_to_transition,
|
|
||||||
to_output=transition_to_observation,
|
|
||||||
)
|
|
||||||
return teleop_proc, obs_proc
|
|
||||||
|
|
||||||
|
|
||||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
|
||||||
"""Reset period where human repositions environment."""
|
|
||||||
logger.info("[HIL] RESET")
|
|
||||||
|
|
||||||
events["in_reset"] = True
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
|
|
||||||
obs = robot.get_observation()
|
|
||||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
|
||||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
|
||||||
|
|
||||||
logger.info("Press any key to enable teleoperation")
|
|
||||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
|
||||||
precise_sleep(0.05)
|
|
||||||
|
|
||||||
if events["stop_recording"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
teleop_disable_torque(teleop)
|
|
||||||
logger.info("Teleop enabled - press any key to start episode")
|
|
||||||
|
|
||||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
|
||||||
loop_start = time.perf_counter()
|
|
||||||
action = teleop.get_action()
|
|
||||||
robot.send_action(action)
|
|
||||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
|
||||||
|
|
||||||
events["in_reset"] = False
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
events["exit_early"] = False
|
|
||||||
events["policy_paused"] = False
|
|
||||||
events["correction_active"] = False
|
|
||||||
events["resume_policy"] = False
|
|
||||||
|
|
||||||
|
|
||||||
def print_controls(rtc: bool = False):
|
|
||||||
"""Print control instructions."""
|
|
||||||
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
|
|
||||||
logger.info(
|
|
||||||
"%s\n Controls:\n"
|
|
||||||
" SPACE - Pause policy\n"
|
|
||||||
" c - Take control\n"
|
|
||||||
" p - Resume policy after pause/correction\n"
|
|
||||||
" → - End episode\n"
|
|
||||||
" ESC - Stop and push to hub",
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
+62
-31
@@ -14,17 +14,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import make_default_processors
|
from lerobot.processor import make_default_processors
|
||||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 2
|
NUM_EPISODES = 2
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||||
|
|
||||||
@@ -83,43 +90,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
recorded_episodes = 0
|
recorded_episodes = 0
|
||||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=teleop_action_processor,
|
obs_processed = robot_observation_processor(obs)
|
||||||
robot_action_processor=robot_action_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot
|
||||||
|
robot_action_to_send = robot_action_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
|
|||||||
@@ -45,9 +45,6 @@ def main():
|
|||||||
leader_arm = SO100Leader(leader_arm_config)
|
leader_arm = SO100Leader(leader_arm_config)
|
||||||
keyboard = KeyboardTeleop(keyboard_config)
|
keyboard = KeyboardTeleop(keyboard_config)
|
||||||
|
|
||||||
# TODO(Steven): Update this example to use pipelines
|
|
||||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
|
||||||
|
|
||||||
# Configure the dataset features
|
# Configure the dataset features
|
||||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||||
@@ -77,6 +74,10 @@ def main():
|
|||||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||||
raise ValueError("Robot or teleop is not connected!")
|
raise ValueError("Robot or teleop is not connected!")
|
||||||
|
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = (
|
||||||
|
make_default_processors()
|
||||||
|
)
|
||||||
|
|
||||||
print("Starting record loop...")
|
print("Starting record loop...")
|
||||||
recorded_episodes = 0
|
recorded_episodes = 0
|
||||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||||
@@ -87,14 +88,14 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
teleop=[leader_arm, keyboard],
|
teleop=[leader_arm, keyboard],
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -106,13 +107,13 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
teleop=[leader_arm, keyboard],
|
teleop=[leader_arm, keyboard],
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained policy on LeKiwi without recording (base rollout).
|
||||||
|
|
||||||
|
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||||
|
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||||
|
control tick). For a CLI entry point with the same capabilities plus
|
||||||
|
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
|
||||||
|
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||||
|
|
||||||
|
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
|
||||||
|
# by ``build_rollout_context`` to reload the full model.
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
# Assemble the rollout config: base strategy (no recording) + sync inference.
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
# Build the context (connects robot, loads policy, wires the inference strategy).
|
||||||
|
# No custom processors here — LeKiwi runs on raw joint features.
|
||||||
|
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
# OMX Follower — Cube Pick And Place Example
|
||||||
|
|
||||||
|
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||||
|
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||||
|
|
||||||
|
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||||
|
|
||||||
|
## Hardware
|
||||||
|
|
||||||
|
| Component | Value |
|
||||||
|
| --------- | ------------------------------------ |
|
||||||
|
| Robot | OMX Follower |
|
||||||
|
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
| Script | Purpose |
|
||||||
|
| ---------------------- | --------------------------------------------------------------- |
|
||||||
|
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||||
|
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||||
|
|
||||||
|
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export ROBOT_PORT=/dev/ttyACM0
|
||||||
|
export TELEOP_PORT=/dev/ttyACM1
|
||||||
|
export HF_USERNAME=<your_hf_username>
|
||||||
|
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 1 — Collect Data
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--teleop.type=omx_leader \
|
||||||
|
--teleop.port=$TELEOP_PORT \
|
||||||
|
--teleop.id=omx_leader \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--dataset.root=data/omx_pickandplace \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||||
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.push_to_hub=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Bonus Auto-Collect script
|
||||||
|
|
||||||
|
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m examples.omx.record_grab \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--dataset.root=data/omx_pickandplace \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||||
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.push_to_hub=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Each episode:
|
||||||
|
|
||||||
|
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||||
|
2. The arm returns to HOME.
|
||||||
|
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||||
|
|
||||||
|
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||||
|
|
||||||
|
## Step 2 — Train
|
||||||
|
|
||||||
|
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--policy.type=act \
|
||||||
|
--output_dir=outputs/train/omx_pickandplace_act \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
--steps=20000 \
|
||||||
|
--wandb.enable=true
|
||||||
|
```
|
||||||
|
|
||||||
|
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||||
|
|
||||||
|
## Step 3 — Rollout
|
||||||
|
|
||||||
|
Use the `lerobot-rollout` CLI with base strategy:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
```
|
||||||
|
|
||||||
|
For continuous recording with automatic upload (sentry mode):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=sentry \
|
||||||
|
--strategy.upload_every_n_episodes=10 \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Reset Utility
|
||||||
|
|
||||||
|
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||||
|
|
||||||
|
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Grab cube + place it at a random position on the left side
|
||||||
|
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||||
|
```
|
||||||
|
|
||||||
|
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||||
@@ -0,0 +1,422 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Auto-record grab episodes for the OMX robot arm.
|
||||||
|
|
||||||
|
Each episode cycle:
|
||||||
|
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||||
|
2. HOME — return arm to home with gripper open
|
||||||
|
3. record_grab — execute a targeted grab to the stored position while recording
|
||||||
|
observations + actions to a LeRobotDataset
|
||||||
|
|
||||||
|
Usage (run from repo root):
|
||||||
|
python -m examples.omx.record_grab \\
|
||||||
|
--robot.type=omx_follower \\
|
||||||
|
--robot.port=/dev/ttyACM0 \\
|
||||||
|
--robot.id=omx_follower \\
|
||||||
|
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||||
|
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||||
|
--dataset.root=data/omx_grab \\
|
||||||
|
--dataset.num_episodes=50 \\
|
||||||
|
--dataset.single_task="Grab the cube" \\
|
||||||
|
--dataset.streaming_encoding=true
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig # noqa: F401
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.dataset import DatasetRecordConfig
|
||||||
|
from lerobot.datasets import (
|
||||||
|
LeRobotDataset,
|
||||||
|
VideoEncodingManager,
|
||||||
|
aggregate_pipeline_dataset_features,
|
||||||
|
create_initial_features,
|
||||||
|
)
|
||||||
|
from lerobot.processor import make_default_processors
|
||||||
|
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||||
|
from lerobot.robots.omx_follower import OmxFollower
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
from .reset_environment import (
|
||||||
|
APPROACH_SPEED,
|
||||||
|
GRIPPER_CLOSE_POS,
|
||||||
|
HOME_POSE,
|
||||||
|
PUSH_END_ELBOW_FLEX,
|
||||||
|
PUSH_END_SHOULDER_LIFT,
|
||||||
|
PUSH_START_ELBOW_FLEX,
|
||||||
|
PUSH_START_SHOULDER_LIFT,
|
||||||
|
array_to_pose,
|
||||||
|
grab_cube,
|
||||||
|
horizontal_wrist_flex,
|
||||||
|
move_to_pose,
|
||||||
|
place_cube,
|
||||||
|
pose_to_array,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||||
|
GRAB_RAISE_SL_OFFSET = 20.0
|
||||||
|
GRAB_LOWER_SPEED = 20.0
|
||||||
|
RECORD_SPEED = 30.0
|
||||||
|
|
||||||
|
# Pose the arm travels to after closing the gripper (cube held).
|
||||||
|
GRAB_CARRY_POSE = {
|
||||||
|
"shoulder_pan.pos": -23.0,
|
||||||
|
"shoulder_lift.pos": 5.0,
|
||||||
|
"elbow_flex.pos": 18.0,
|
||||||
|
"wrist_flex.pos": -14.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||||
|
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||||
|
_JITTER_LIMITS: dict[str, float] = {
|
||||||
|
"shoulder_pan.pos": 5.0,
|
||||||
|
"shoulder_lift.pos": 4.0,
|
||||||
|
"elbow_flex.pos": 4.0,
|
||||||
|
"wrist_flex.pos": 3.0,
|
||||||
|
"wrist_roll.pos": 2.0,
|
||||||
|
"gripper.pos": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||||
|
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||||
|
return {
|
||||||
|
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||||
|
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||||
|
|
||||||
|
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||||
|
table-safe envelope across the full sl range:
|
||||||
|
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||||
|
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||||
|
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||||
|
wrist_flex is randomly offset from the horizontal value.
|
||||||
|
"""
|
||||||
|
pan = float(rng.uniform(-5.0, 35.0))
|
||||||
|
sl = float(rng.uniform(-50.0, 30.0))
|
||||||
|
|
||||||
|
if sl <= 0.0:
|
||||||
|
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||||
|
ef_lo = alpha * -25.0 # 0 → -25
|
||||||
|
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||||
|
else:
|
||||||
|
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||||
|
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||||
|
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||||
|
|
||||||
|
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||||
|
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf,
|
||||||
|
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OmxRecordGrabConfig:
|
||||||
|
robot: RobotConfig
|
||||||
|
dataset: DatasetRecordConfig
|
||||||
|
# Resume recording on an existing dataset.
|
||||||
|
resume: bool = False
|
||||||
|
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||||
|
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||||
|
recovery_prob: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def record_episode_spline(
|
||||||
|
robot: OmxFollower,
|
||||||
|
waypoints: list[dict],
|
||||||
|
speeds: list[float],
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
task: str,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||||
|
|
||||||
|
Segment durations are parameterized from the maximum absolute joint delta
|
||||||
|
between consecutive waypoints divided by the requested segment speed,
|
||||||
|
producing non-uniform timing in joint space. Interior tangents are derived
|
||||||
|
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||||
|
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||||
|
Hermite, giving C1 continuity at every waypoint.
|
||||||
|
"""
|
||||||
|
pts = [pose_to_array(w) for w in waypoints]
|
||||||
|
n = len(pts)
|
||||||
|
|
||||||
|
# Steps and duration per segment
|
||||||
|
n_steps_list = []
|
||||||
|
timestamps = []
|
||||||
|
for i in range(n - 1):
|
||||||
|
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||||
|
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||||
|
n_steps_list.append(ns)
|
||||||
|
timestamps.append(ns / dataset.fps)
|
||||||
|
|
||||||
|
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||||
|
vels = [np.zeros_like(pts[0])]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||||
|
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||||
|
vels.append(0.5 * (v_prev + v_next))
|
||||||
|
vels.append(np.zeros_like(pts[0]))
|
||||||
|
|
||||||
|
dt = 1.0 / dataset.fps
|
||||||
|
for seg in range(n - 1):
|
||||||
|
ns = n_steps_list[seg]
|
||||||
|
if ns == 0:
|
||||||
|
continue
|
||||||
|
p0, p1 = pts[seg], pts[seg + 1]
|
||||||
|
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||||
|
m0 = vels[seg] * timestamps[seg]
|
||||||
|
m1 = vels[seg + 1] * timestamps[seg]
|
||||||
|
|
||||||
|
for step in range(1, ns + 1):
|
||||||
|
t = step / ns
|
||||||
|
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||||
|
h10 = t**3 - 2 * t**2 + t
|
||||||
|
h01 = -2 * t**3 + 3 * t**2
|
||||||
|
h11 = t**3 - t**2
|
||||||
|
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||||
|
|
||||||
|
action = array_to_pose(commanded)
|
||||||
|
robot.send_action(action)
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||||
|
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
def record_grab_episode(
|
||||||
|
robot: OmxFollower,
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
pan: float,
|
||||||
|
t: float,
|
||||||
|
task: str,
|
||||||
|
recovery_start: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||||
|
|
||||||
|
Normal sequence (initial HOME move is NOT recorded):
|
||||||
|
HOME → raised approach above cube → lower → close gripper
|
||||||
|
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||||
|
|
||||||
|
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||||
|
(gripper closed) without recording, then recording begins from there:
|
||||||
|
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||||
|
|
||||||
|
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||||
|
"""
|
||||||
|
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||||
|
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||||
|
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||||
|
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||||
|
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
|
||||||
|
if recovery_start:
|
||||||
|
stuck_pose = _random_stuck_pose(rng)
|
||||||
|
logger.info(f"Recovery start: {stuck_pose}")
|
||||||
|
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||||
|
first_waypoints = [stuck_pose]
|
||||||
|
first_speeds = []
|
||||||
|
else:
|
||||||
|
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||||
|
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||||
|
first_waypoints = [jittery_start]
|
||||||
|
first_speeds = []
|
||||||
|
|
||||||
|
waypoints = first_waypoints + [
|
||||||
|
{ # raised approach: arm above cube
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl_raised,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{ # lower onto cube — no jitter: precision needed
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf_horizontal,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{ # close gripper — no jitter: precision needed
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf_horizontal,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
_jitter_pose(
|
||||||
|
{ # raise with cube
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl_raised,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
rng,
|
||||||
|
),
|
||||||
|
_jitter_pose(
|
||||||
|
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||||
|
"shoulder_pan.pos": pan * 0.25,
|
||||||
|
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||||
|
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||||
|
"wrist_flex.pos": 0.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
rng,
|
||||||
|
),
|
||||||
|
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||||
|
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||||
|
HOME_POSE,
|
||||||
|
]
|
||||||
|
speeds = first_speeds + [
|
||||||
|
RECORD_SPEED, # (HOME →) raised approach
|
||||||
|
GRAB_LOWER_SPEED, # raised approach → lower
|
||||||
|
GRAB_LOWER_SPEED, # lower → close gripper
|
||||||
|
RECORD_SPEED, # close gripper → raise
|
||||||
|
RECORD_SPEED, # raise → retract
|
||||||
|
RECORD_SPEED, # retract → carry pose
|
||||||
|
RECORD_SPEED, # carry pose → drop
|
||||||
|
RECORD_SPEED, # drop → HOME
|
||||||
|
]
|
||||||
|
|
||||||
|
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||||
|
|
||||||
|
# Dwell at HOME for ~0.5 s before next episode
|
||||||
|
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||||
|
dt = 1.0 / dataset.fps
|
||||||
|
for _ in range(int(dataset.fps * 0.5)):
|
||||||
|
robot.send_action(HOME_POSE)
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||||
|
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
@parser.wrap()
|
||||||
|
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
logger.info(pformat(cfg))
|
||||||
|
|
||||||
|
robot = make_robot_from_config(cfg.robot)
|
||||||
|
use_videos = cfg.dataset.video
|
||||||
|
|
||||||
|
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||||
|
|
||||||
|
dataset_features = combine_feature_dicts(
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=teleop_action_processor,
|
||||||
|
initial_features=create_initial_features(action=robot.action_features),
|
||||||
|
use_videos=use_videos,
|
||||||
|
),
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=robot_obs_processor,
|
||||||
|
initial_features=create_initial_features(observation=robot.observation_features),
|
||||||
|
use_videos=use_videos,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||||
|
dataset = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if cfg.resume:
|
||||||
|
dataset = LeRobotDataset.resume(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
|
encoder_threads=cfg.dataset.encoder_threads,
|
||||||
|
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||||
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||||
|
if num_cameras > 0
|
||||||
|
else 0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg.dataset.stamp_repo_id()
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
cfg.dataset.fps,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
robot_type=robot.name,
|
||||||
|
features=dataset_features,
|
||||||
|
use_videos=use_videos,
|
||||||
|
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
|
encoder_threads=cfg.dataset.encoder_threads,
|
||||||
|
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||||
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||||
|
if num_cameras > 0
|
||||||
|
else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot.connect(calibrate=True)
|
||||||
|
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
with VideoEncodingManager(dataset):
|
||||||
|
for episode_idx in range(cfg.dataset.num_episodes):
|
||||||
|
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||||
|
|
||||||
|
logger.info("Step 1: grabbing and placing cube...")
|
||||||
|
grab_cube(robot)
|
||||||
|
pan, t = place_cube(robot)
|
||||||
|
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||||
|
|
||||||
|
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||||
|
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||||
|
record_grab_episode(
|
||||||
|
robot,
|
||||||
|
dataset,
|
||||||
|
pan,
|
||||||
|
t,
|
||||||
|
cfg.dataset.single_task,
|
||||||
|
recovery_start=recovery_start,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if dataset:
|
||||||
|
dataset.finalize()
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||||
|
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
record_grab()
|
||||||
@@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||||
|
- place_cube(robot): carry cube to a random position, release
|
||||||
|
|
||||||
|
Standalone usage (run from repo root):
|
||||||
|
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||||
|
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||||
|
|
||||||
|
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||||
|
|
||||||
|
To read current joint values for calibration, add after robot.connect():
|
||||||
|
obs = robot.get_observation()
|
||||||
|
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||||
|
robot.disconnect(); raise SystemExit
|
||||||
|
|
||||||
|
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||||
|
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||||
|
from lerobot.robots.robot import Robot
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
HOME_POSE = {
|
||||||
|
"shoulder_pan.pos": 0.0,
|
||||||
|
"shoulder_lift.pos": -50.0,
|
||||||
|
"elbow_flex.pos": 50.0,
|
||||||
|
"wrist_flex.pos": 0.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
SWEEP_WAYPOINTS = [
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": -60.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -60.0,
|
||||||
|
"wrist_flex.pos": -20.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": -30.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -60.0,
|
||||||
|
"wrist_flex.pos": -5.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": 20.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -55.0,
|
||||||
|
"wrist_flex.pos": -5.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
CONTROL_HZ = 30
|
||||||
|
APPROACH_SPEED = 50.0
|
||||||
|
SWEEP_SPEED = 40.0
|
||||||
|
|
||||||
|
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
GRAB_PAN = 0.0
|
||||||
|
SWEEP_LEFT_PAN = -60.0
|
||||||
|
SWEEP_RIGHT_PAN = 60.0
|
||||||
|
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||||
|
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||||
|
|
||||||
|
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||||
|
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||||
|
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||||
|
|
||||||
|
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||||
|
|
||||||
|
PUSH_START_SHOULDER_LIFT = 0.0
|
||||||
|
PUSH_START_ELBOW_FLEX = 45.0
|
||||||
|
PUSH_END_SHOULDER_LIFT = 50.0
|
||||||
|
PUSH_END_ELBOW_FLEX = -50.0
|
||||||
|
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||||
|
# Does not affect the grab-target interpolation in record_grab.py.
|
||||||
|
PUSH_RAISE_OFFSET = 5.0
|
||||||
|
|
||||||
|
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||||
|
GRIPPER_CLOSE_POS = 50.0
|
||||||
|
|
||||||
|
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||||
|
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||||
|
|
||||||
|
JOINT_NAMES = [
|
||||||
|
"shoulder_pan.pos",
|
||||||
|
"shoulder_lift.pos",
|
||||||
|
"elbow_flex.pos",
|
||||||
|
"wrist_flex.pos",
|
||||||
|
"wrist_roll.pos",
|
||||||
|
"gripper.pos",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def pose_to_array(pose: dict) -> np.ndarray:
|
||||||
|
return np.array([pose[k] for k in JOINT_NAMES])
|
||||||
|
|
||||||
|
|
||||||
|
def array_to_pose(arr: np.ndarray) -> dict:
|
||||||
|
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||||
|
|
||||||
|
|
||||||
|
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||||
|
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||||
|
|
||||||
|
|
||||||
|
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||||
|
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": elbow_flex,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _high_sweep_pose(pan: float) -> dict:
|
||||||
|
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||||
|
|
||||||
|
|
||||||
|
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": shoulder_lift,
|
||||||
|
"elbow_flex.pos": elbow_flex,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": gripper,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||||
|
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||||
|
obs = robot.get_observation()
|
||||||
|
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||||
|
goal = pose_to_array(target)
|
||||||
|
|
||||||
|
max_distance = float(np.max(np.abs(goal - current)))
|
||||||
|
if max_distance < 0.5:
|
||||||
|
return
|
||||||
|
|
||||||
|
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||||
|
dt = 1.0 / CONTROL_HZ
|
||||||
|
for step in range(1, n_steps + 1):
|
||||||
|
t = step / n_steps
|
||||||
|
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def grab_cube(robot: Robot) -> None:
|
||||||
|
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
|
||||||
|
for pan, end_pan in [
|
||||||
|
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||||
|
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||||
|
]:
|
||||||
|
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||||
|
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||||
|
move_to_pose(
|
||||||
|
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||||
|
)
|
||||||
|
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
|
||||||
|
logger.info("Extending to push cube into gripper...")
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||||
|
APPROACH_SPEED,
|
||||||
|
)
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||||
|
SWEEP_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Closing gripper...")
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||||
|
APPROACH_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Grab complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||||
|
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||||
|
"""
|
||||||
|
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||||
|
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||||
|
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||||
|
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||||
|
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||||
|
|
||||||
|
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||||
|
move_to_pose(
|
||||||
|
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||||
|
)
|
||||||
|
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||||
|
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
logger.info("Place complete.")
|
||||||
|
return pan, t
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||||
|
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||||
|
parser.add_argument("--robot_id", default="omx_follower")
|
||||||
|
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
|
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||||
|
robot.connect(calibrate=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.mode == "grab":
|
||||||
|
grab_cube(robot)
|
||||||
|
elif args.mode == "grab_and_place":
|
||||||
|
grab_cube(robot)
|
||||||
|
place_cube(robot)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -14,13 +14,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.configs import FeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
make_default_teleop_action_processor,
|
make_default_teleop_action_processor,
|
||||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
NUM_EPISODES = 5
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
robot_config = SO100FollowerConfig(
|
robot_config = SO100FollowerConfig(
|
||||||
@@ -143,43 +151,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
for episode_idx in range(NUM_EPISODES):
|
for episode_idx in range(NUM_EPISODES):
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot (EE -> joints via IK)
|
||||||
|
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
@@ -190,7 +222,6 @@ def main():
|
|||||||
|
|
||||||
# Save episode
|
# Save episode
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
log_say("Stop recording")
|
log_say("Stop recording")
|
||||||
|
|||||||
@@ -65,14 +65,15 @@ def main():
|
|||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
phone = Phone(teleop_config)
|
phone = Phone(teleop_config)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
kinematics_solver = RobotKinematics(
|
kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(robot.bus.motors.keys()),
|
joint_names=list(robot.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert phone action to EE action
|
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
|
||||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||||
tuple[RobotAction, RobotObservation], RobotAction
|
tuple[RobotAction, RobotObservation], RobotAction
|
||||||
](
|
](
|
||||||
@@ -94,7 +95,7 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert EE action to joints action
|
# Build pipeline to convert EE action to joints action (IK).
|
||||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
steps=[
|
steps=[
|
||||||
InverseKinematicsEEToJoints(
|
InverseKinematicsEEToJoints(
|
||||||
@@ -107,7 +108,7 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert joint observation to EE observation
|
# Build pipeline to convert joint observation to EE observation (FK).
|
||||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -118,13 +119,12 @@ def main():
|
|||||||
to_output=transition_to_observation,
|
to_output=transition_to_observation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||||
|
# matches exactly what the pipelines produce at runtime.
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id=HF_REPO_ID,
|
repo_id=HF_REPO_ID,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=combine_feature_dicts(
|
features=combine_feature_dicts(
|
||||||
# Run the feature contract of the pipelines
|
|
||||||
# This tells you how the features would look like after the pipeline steps
|
|
||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=phone_to_robot_ee_pose_processor,
|
pipeline=phone_to_robot_ee_pose_processor,
|
||||||
initial_features=create_initial_features(action=phone.action_features),
|
initial_features=create_initial_features(action=phone.action_features),
|
||||||
@@ -163,14 +163,14 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
teleop=phone,
|
teleop=phone,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -182,13 +182,13 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
teleop=phone,
|
teleop=phone,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
@@ -0,0 +1,126 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
|
||||||
|
|
||||||
|
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
|
||||||
|
with phone teleoperation in EE space, so at deployment we only need the
|
||||||
|
joint↔EE conversion on the robot side; the phone is not used.
|
||||||
|
|
||||||
|
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
|
||||||
|
(inline policy call). For recording during rollout, switch to Sentry,
|
||||||
|
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.processor import (
|
||||||
|
RobotProcessorPipeline,
|
||||||
|
observation_to_transition,
|
||||||
|
robot_action_observation_to_transition,
|
||||||
|
transition_to_observation,
|
||||||
|
transition_to_robot_action,
|
||||||
|
)
|
||||||
|
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||||
|
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||||
|
ForwardKinematicsJointsToEE,
|
||||||
|
InverseKinematicsEEToJoints,
|
||||||
|
)
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
|
robot_config = SO100FollowerConfig(
|
||||||
|
port="/dev/tty.usbmodem58760434471",
|
||||||
|
id="my_awesome_follower_arm",
|
||||||
|
cameras=camera_config,
|
||||||
|
use_degrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Peek at motor names once to build the kinematic solver.
|
||||||
|
temp_robot = SO100Follower(robot_config)
|
||||||
|
motor_names = list(temp_robot.bus.motors.keys())
|
||||||
|
|
||||||
|
kinematics_solver = RobotKinematics(
|
||||||
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
|
target_frame_name="gripper_frame_link",
|
||||||
|
joint_names=motor_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
|
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||||
|
to_transition=observation_to_transition,
|
||||||
|
to_output=transition_to_observation,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
|
steps=[
|
||||||
|
InverseKinematicsEEToJoints(
|
||||||
|
kinematics=kinematics_solver,
|
||||||
|
motor_names=motor_names,
|
||||||
|
initial_guess_current_joints=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
to_transition=robot_action_observation_to_transition,
|
||||||
|
to_output=transition_to_robot_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,673 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
|
||||||
|
|
||||||
This script demonstrates:
|
|
||||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
|
||||||
2. Consuming actions from the policy while the robot executes
|
|
||||||
3. Periodically requesting new action chunks in the background using threads
|
|
||||||
4. Managing action buffers and timing for real-time operation
|
|
||||||
|
|
||||||
For simulation environments, see eval_with_simulation.py
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Run RTC with Real robot with RTC
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with Real robot without RTC
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=false \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with Real robot with pi0.5 policy
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/pi05_check_rtc \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=lerobot-data-collection/folding_final \
|
|
||||||
--robot.type=bi_openarm_follower \
|
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
|
||||||
--robot.left_arm_config.port=can0 \
|
|
||||||
--robot.left_arm_config.side=left \
|
|
||||||
--robot.left_arm_config.can_interface=socketcan \
|
|
||||||
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
|
||||||
--robot.left_arm_config.max_relative_target=8.0 \
|
|
||||||
--robot.right_arm_config.port=can1 \
|
|
||||||
--robot.right_arm_config.side=right \
|
|
||||||
--robot.right_arm_config.can_interface=socketcan \
|
|
||||||
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
|
||||||
--robot.right_arm_config.max_relative_target=8.0 \
|
|
||||||
--task="Fold the T-shirt properly" \
|
|
||||||
--fps=30 \
|
|
||||||
--duration=2000 \
|
|
||||||
--interpolation_multiplier=3 \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--rtc.max_guidance_weight=5.0 \
|
|
||||||
--rtc.prefix_attention_schedule=LINEAR \
|
|
||||||
--device=cuda
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from threading import Event, Lock, Thread
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
|
||||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
|
||||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
|
||||||
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
|
|
||||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
|
||||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
|
||||||
from lerobot.processor import (
|
|
||||||
NormalizerProcessorStep,
|
|
||||||
RelativeActionsProcessorStep,
|
|
||||||
TransitionKey,
|
|
||||||
create_transition,
|
|
||||||
make_default_robot_action_processor,
|
|
||||||
make_default_robot_observation_processor,
|
|
||||||
to_relative_actions,
|
|
||||||
)
|
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
|
||||||
from lerobot.robots import ( # noqa: F401
|
|
||||||
Robot,
|
|
||||||
RobotConfig,
|
|
||||||
bi_openarm_follower,
|
|
||||||
bi_so_follower,
|
|
||||||
koch_follower,
|
|
||||||
so_follower,
|
|
||||||
unitree_g1,
|
|
||||||
)
|
|
||||||
from lerobot.robots.utils import make_robot_from_config
|
|
||||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
|
||||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
|
||||||
from lerobot.utils.hub import HubMixin
|
|
||||||
from lerobot.utils.utils import init_logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RobotWrapper:
|
|
||||||
def __init__(self, robot: Robot):
|
|
||||||
self.robot = robot
|
|
||||||
self.lock = Lock()
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Tensor]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.get_observation()
|
|
||||||
|
|
||||||
def send_action(self, action: Tensor):
|
|
||||||
with self.lock:
|
|
||||||
self.robot.send_action(action)
|
|
||||||
|
|
||||||
def observation_features(self) -> list[str]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.observation_features
|
|
||||||
|
|
||||||
def action_features(self) -> list[str]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.action_features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RTCDemoConfig(HubMixin):
|
|
||||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
|
||||||
|
|
||||||
# Policy configuration
|
|
||||||
policy: PreTrainedConfig | None = None
|
|
||||||
|
|
||||||
# Robot configuration
|
|
||||||
robot: RobotConfig | None = None
|
|
||||||
|
|
||||||
# RTC configuration
|
|
||||||
rtc: RTCConfig = field(
|
|
||||||
default_factory=lambda: RTCConfig(
|
|
||||||
execution_horizon=10,
|
|
||||||
max_guidance_weight=1.0,
|
|
||||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Demo parameters
|
|
||||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
|
||||||
fps: float = 10.0 # Action execution frequency (Hz)
|
|
||||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
|
||||||
|
|
||||||
# Compute device
|
|
||||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
|
||||||
|
|
||||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
|
||||||
# It should be higher than inference delay + execution horizon.
|
|
||||||
action_queue_size_to_get_new_actions: int = 30
|
|
||||||
|
|
||||||
# Task to execute
|
|
||||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
|
||||||
|
|
||||||
# Torch compile configuration
|
|
||||||
use_torch_compile: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_backend: str = field(
|
|
||||||
default="inductor",
|
|
||||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_mode: str = field(
|
|
||||||
default="default",
|
|
||||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_disable_cudagraphs: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={
|
|
||||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
|
||||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
|
||||||
policy_path = parser.get_path_arg("policy")
|
|
||||||
if policy_path:
|
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
|
||||||
self.policy.pretrained_path = policy_path
|
|
||||||
else:
|
|
||||||
raise ValueError("Policy path is required")
|
|
||||||
|
|
||||||
# Validate that robot configuration is provided
|
|
||||||
if self.robot is None:
|
|
||||||
raise ValueError("Robot configuration must be provided")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
|
||||||
return ["policy"]
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_key(k: str) -> bool:
|
|
||||||
return k.startswith(OBS_IMAGES)
|
|
||||||
|
|
||||||
|
|
||||||
def _reanchor_relative_rtc_prefix(
|
|
||||||
prev_actions_absolute: Tensor,
|
|
||||||
current_state: Tensor,
|
|
||||||
relative_step: RelativeActionsProcessorStep,
|
|
||||||
normalizer_step: NormalizerProcessorStep | None,
|
|
||||||
policy_device: torch.device | str,
|
|
||||||
) -> Tensor:
|
|
||||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
|
||||||
|
|
||||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
|
||||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
|
||||||
the policy we need to re-express it relative to the *current* robot state
|
|
||||||
and then re-normalize.
|
|
||||||
"""
|
|
||||||
state = current_state.detach().cpu()
|
|
||||||
if state.dim() == 1:
|
|
||||||
state = state.unsqueeze(0)
|
|
||||||
|
|
||||||
action_cpu = prev_actions_absolute.detach().cpu()
|
|
||||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
|
||||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
|
||||||
|
|
||||||
transition = create_transition(action=relative_actions)
|
|
||||||
if normalizer_step is not None:
|
|
||||||
transition = normalizer_step(transition)
|
|
||||||
|
|
||||||
return transition[TransitionKey.ACTION].to(policy_device)
|
|
||||||
|
|
||||||
|
|
||||||
def get_actions(
|
|
||||||
policy,
|
|
||||||
robot: RobotWrapper,
|
|
||||||
robot_observation_processor,
|
|
||||||
action_queue: ActionQueue,
|
|
||||||
shutdown_event: Event,
|
|
||||||
cfg: RTCDemoConfig,
|
|
||||||
):
|
|
||||||
"""Thread function to request action chunks from the policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
|
||||||
robot: The robot instance for getting observations
|
|
||||||
robot_observation_processor: Processor for raw robot observations
|
|
||||||
action_queue: Queue to put new action chunks
|
|
||||||
shutdown_event: Event to signal shutdown
|
|
||||||
cfg: Demo configuration
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
|
||||||
|
|
||||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
|
||||||
fps = cfg.fps
|
|
||||||
time_per_chunk = 1.0 / fps
|
|
||||||
|
|
||||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
|
||||||
# not the full pos/vel/torque state the robot exposes.
|
|
||||||
observation_features_hw = {
|
|
||||||
key: value
|
|
||||||
for key, value in robot.observation_features().items()
|
|
||||||
if key.endswith(".pos") or isinstance(value, tuple)
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
|
||||||
policy_device = policy.config.device
|
|
||||||
|
|
||||||
# Load preprocessor and postprocessor from pretrained files
|
|
||||||
# The stats are embedded in the processor .safetensors files
|
|
||||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
|
||||||
policy_cfg=cfg.policy,
|
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
|
||||||
dataset_stats=None, # Will load from pretrained processor files
|
|
||||||
preprocessor_overrides={
|
|
||||||
"device_processor": {"device": cfg.policy.device},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
|
||||||
|
|
||||||
relative_step = next(
|
|
||||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
normalizer_step = next(
|
|
||||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if relative_step is not None:
|
|
||||||
if relative_step.action_names is None:
|
|
||||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
|
||||||
if cfg_names:
|
|
||||||
relative_step.action_names = list(cfg_names)
|
|
||||||
else:
|
|
||||||
relative_step.action_names = [
|
|
||||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
|
||||||
]
|
|
||||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
|
||||||
|
|
||||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
|
||||||
|
|
||||||
if not cfg.rtc.enabled:
|
|
||||||
get_actions_threshold = 0
|
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
if action_queue.qsize() <= get_actions_threshold:
|
|
||||||
current_time = time.perf_counter()
|
|
||||||
action_index_before_inference = action_queue.get_action_index()
|
|
||||||
prev_actions = action_queue.get_left_over()
|
|
||||||
|
|
||||||
inference_latency = latency_tracker.max()
|
|
||||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
|
||||||
|
|
||||||
obs = robot.get_observation()
|
|
||||||
|
|
||||||
# Apply robot observation processor
|
|
||||||
obs_processed = robot_observation_processor(obs)
|
|
||||||
|
|
||||||
obs_with_policy_features = build_dataset_frame(
|
|
||||||
dataset_features, obs_processed, prefix="observation"
|
|
||||||
)
|
|
||||||
|
|
||||||
for name in obs_with_policy_features:
|
|
||||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
|
||||||
if "image" in name:
|
|
||||||
obs_with_policy_features[name] = (
|
|
||||||
obs_with_policy_features[name].type(torch.float32) / 255
|
|
||||||
)
|
|
||||||
obs_with_policy_features[name] = (
|
|
||||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
|
||||||
)
|
|
||||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
|
||||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
|
||||||
|
|
||||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
|
||||||
obs_with_policy_features["robot_type"] = (
|
|
||||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
|
||||||
|
|
||||||
# Re-anchor leftover actions for relative-action policies.
|
|
||||||
# We need the *postprocessed* (absolute) leftover, not the original
|
|
||||||
# (normalized/relative) one that get_left_over() returns.
|
|
||||||
if (
|
|
||||||
prev_actions is not None
|
|
||||||
and relative_step is not None
|
|
||||||
and OBS_STATE in obs_with_policy_features
|
|
||||||
):
|
|
||||||
with action_queue.lock:
|
|
||||||
if action_queue.queue is not None:
|
|
||||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
|
||||||
else:
|
|
||||||
prev_actions_abs = None
|
|
||||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
|
||||||
prev_actions = _reanchor_relative_rtc_prefix(
|
|
||||||
prev_actions_absolute=prev_actions_abs,
|
|
||||||
current_state=obs_with_policy_features[OBS_STATE],
|
|
||||||
relative_step=relative_step,
|
|
||||||
normalizer_step=normalizer_step,
|
|
||||||
policy_device=policy_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate actions WITH RTC
|
|
||||||
actions = policy.predict_action_chunk(
|
|
||||||
preproceseded_obs,
|
|
||||||
inference_delay=inference_delay,
|
|
||||||
prev_chunk_left_over=prev_actions,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store original actions (before postprocessing) for RTC
|
|
||||||
original_actions = actions.squeeze(0).clone()
|
|
||||||
|
|
||||||
postprocessed_actions = postprocessor(actions)
|
|
||||||
|
|
||||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
|
||||||
|
|
||||||
new_latency = time.perf_counter() - current_time
|
|
||||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
|
||||||
latency_tracker.add(new_latency)
|
|
||||||
|
|
||||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
|
||||||
logger.warning(
|
|
||||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
|
||||||
)
|
|
||||||
|
|
||||||
action_queue.merge(
|
|
||||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Small sleep to prevent busy waiting
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def actor_control(
|
|
||||||
robot: RobotWrapper,
|
|
||||||
robot_action_processor,
|
|
||||||
action_queue: ActionQueue,
|
|
||||||
shutdown_event: Event,
|
|
||||||
cfg: RTCDemoConfig,
|
|
||||||
):
|
|
||||||
"""Thread function to execute actions on the robot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
robot: The robot instance
|
|
||||||
action_queue: Queue to get actions from
|
|
||||||
shutdown_event: Event to signal shutdown
|
|
||||||
cfg: Demo configuration
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("[ACTOR] Starting actor thread")
|
|
||||||
|
|
||||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
|
||||||
|
|
||||||
action_count = 0
|
|
||||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
|
||||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
if interpolator.needs_new_action():
|
|
||||||
new_action = action_queue.get()
|
|
||||||
if new_action is not None:
|
|
||||||
interpolator.add(new_action.cpu())
|
|
||||||
|
|
||||||
action = interpolator.get()
|
|
||||||
if action is not None:
|
|
||||||
action = action.cpu()
|
|
||||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
|
||||||
action_processed = robot_action_processor((action_dict, None))
|
|
||||||
robot.send_action(action_processed)
|
|
||||||
action_count += 1
|
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_time
|
|
||||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
|
||||||
|
|
||||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
|
||||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
policy: Policy instance to compile
|
|
||||||
cfg: Configuration containing torch compile settings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Policy with compiled predict_action_chunk method
|
|
||||||
"""
|
|
||||||
|
|
||||||
# PI models handle their own compilation
|
|
||||||
if policy.type == "pi05" or policy.type == "pi0":
|
|
||||||
return policy
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if torch.compile is available (PyTorch 2.0+)
|
|
||||||
if not hasattr(torch, "compile"):
|
|
||||||
logger.warning(
|
|
||||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
|
||||||
f"Current version: {torch.__version__}. Skipping compilation."
|
|
||||||
)
|
|
||||||
return policy
|
|
||||||
|
|
||||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
|
||||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
|
||||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
|
||||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
|
||||||
|
|
||||||
# Compile the predict_action_chunk method
|
|
||||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
|
||||||
compile_kwargs = {
|
|
||||||
"backend": cfg.torch_compile_backend,
|
|
||||||
"mode": cfg.torch_compile_mode,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
|
||||||
if cfg.torch_compile_disable_cudagraphs:
|
|
||||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
|
||||||
|
|
||||||
original_method = policy.predict_action_chunk
|
|
||||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
|
||||||
policy.predict_action_chunk = compiled_method
|
|
||||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to apply torch.compile: {e}")
|
|
||||||
logger.warning("Continuing without torch.compile")
|
|
||||||
|
|
||||||
return policy
|
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
|
||||||
def demo_cli(cfg: RTCDemoConfig):
|
|
||||||
"""Main entry point for RTC demo with draccus configuration."""
|
|
||||||
|
|
||||||
# Initialize logging
|
|
||||||
init_logging()
|
|
||||||
|
|
||||||
logger.info(f"Using device: {cfg.device}")
|
|
||||||
|
|
||||||
# Setup signal handler for graceful shutdown
|
|
||||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
|
||||||
shutdown_event = signal_handler.shutdown_event
|
|
||||||
|
|
||||||
policy = None
|
|
||||||
robot = None
|
|
||||||
get_actions_thread = None
|
|
||||||
actor_thread = None
|
|
||||||
|
|
||||||
policy_class = get_policy_class(cfg.policy.type)
|
|
||||||
|
|
||||||
# Load config and set compile_model for pi0/pi05 models
|
|
||||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
|
||||||
|
|
||||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
|
||||||
config.compile_model = cfg.use_torch_compile
|
|
||||||
|
|
||||||
if config.use_peft:
|
|
||||||
from peft import PeftConfig, PeftModel
|
|
||||||
|
|
||||||
peft_pretrained_path = cfg.policy.pretrained_path
|
|
||||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
|
||||||
|
|
||||||
policy = policy_class.from_pretrained(
|
|
||||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
|
||||||
)
|
|
||||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
|
||||||
else:
|
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
|
||||||
|
|
||||||
# Turn on RTC
|
|
||||||
policy.config.rtc_config = cfg.rtc
|
|
||||||
|
|
||||||
# Init RTC processort, as by default if RTC disabled in the config
|
|
||||||
# The processor won't be created
|
|
||||||
policy.init_rtc_processor()
|
|
||||||
|
|
||||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
|
||||||
|
|
||||||
policy = policy.to(cfg.device)
|
|
||||||
policy.eval()
|
|
||||||
|
|
||||||
# Apply torch.compile to predict_action_chunk method if enabled
|
|
||||||
if cfg.use_torch_compile:
|
|
||||||
policy = _apply_torch_compile(policy, cfg)
|
|
||||||
|
|
||||||
# Create robot
|
|
||||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
|
||||||
robot.connect()
|
|
||||||
robot_wrapper = RobotWrapper(robot)
|
|
||||||
|
|
||||||
# Create robot observation processor
|
|
||||||
robot_observation_processor = make_default_robot_observation_processor()
|
|
||||||
robot_action_processor = make_default_robot_action_processor()
|
|
||||||
|
|
||||||
# Create action queue for communication between threads
|
|
||||||
action_queue = ActionQueue(cfg.rtc)
|
|
||||||
|
|
||||||
# Start chunk requester thread
|
|
||||||
get_actions_thread = Thread(
|
|
||||||
target=get_actions,
|
|
||||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
|
||||||
daemon=True,
|
|
||||||
name="GetActions",
|
|
||||||
)
|
|
||||||
get_actions_thread.start()
|
|
||||||
logger.info("Started get actions thread")
|
|
||||||
|
|
||||||
# Start action executor thread
|
|
||||||
actor_thread = Thread(
|
|
||||||
target=actor_control,
|
|
||||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
|
||||||
daemon=True,
|
|
||||||
name="Actor",
|
|
||||||
)
|
|
||||||
actor_thread.start()
|
|
||||||
logger.info("Started actor thread")
|
|
||||||
|
|
||||||
logger.info("Started stop by duration thread")
|
|
||||||
|
|
||||||
# Main thread monitors for duration or shutdown
|
|
||||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
# Log queue status periodically
|
|
||||||
if int(time.time() - start_time) % 5 == 0:
|
|
||||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
|
||||||
|
|
||||||
if time.time() - start_time > cfg.duration:
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info("Demo duration reached or shutdown requested")
|
|
||||||
|
|
||||||
# Signal shutdown
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
# Wait for threads to finish
|
|
||||||
if get_actions_thread and get_actions_thread.is_alive():
|
|
||||||
logger.info("Waiting for chunk requester thread to finish...")
|
|
||||||
get_actions_thread.join()
|
|
||||||
|
|
||||||
if actor_thread and actor_thread.is_alive():
|
|
||||||
logger.info("Waiting for action executor thread to finish...")
|
|
||||||
actor_thread.join()
|
|
||||||
|
|
||||||
# Cleanup robot
|
|
||||||
if robot:
|
|
||||||
robot.disconnect()
|
|
||||||
logger.info("Robot disconnected")
|
|
||||||
|
|
||||||
logger.info("Cleanup completed")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
demo_cli()
|
|
||||||
logging.info("RTC demo finished")
|
|
||||||
@@ -14,13 +14,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.configs import FeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
make_default_teleop_action_processor,
|
make_default_teleop_action_processor,
|
||||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
NUM_EPISODES = 5
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
robot_config = SO100FollowerConfig(
|
robot_config = SO100FollowerConfig(
|
||||||
@@ -143,43 +151,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
for episode_idx in range(NUM_EPISODES):
|
for episode_idx in range(NUM_EPISODES):
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot (EE -> joints via IK)
|
||||||
|
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
@@ -190,7 +222,6 @@ def main():
|
|||||||
|
|
||||||
# Save episode
|
# Save episode
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
log_say("Stop recording")
|
log_say("Stop recording")
|
||||||
|
|||||||
@@ -62,21 +62,20 @@ def main():
|
|||||||
follower = SO100Follower(follower_config)
|
follower = SO100Follower(follower_config)
|
||||||
leader = SO100Leader(leader_config)
|
leader = SO100Leader(leader_config)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
follower_kinematics_solver = RobotKinematics(
|
follower_kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(follower.bus.motors.keys()),
|
joint_names=list(follower.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
|
||||||
leader_kinematics_solver = RobotKinematics(
|
leader_kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(leader.bus.motors.keys()),
|
joint_names=list(leader.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert follower joints to EE observation
|
# Build pipeline to convert follower joints to EE observation.
|
||||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -87,7 +86,7 @@ def main():
|
|||||||
to_output=transition_to_observation,
|
to_output=transition_to_observation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert leader joints to EE action
|
# Build pipeline to convert leader joints to EE action.
|
||||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -98,9 +97,9 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert EE action to follower joints
|
# Build pipeline to convert EE action to follower joints (with safety bounds).
|
||||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
[
|
steps=[
|
||||||
EEBoundsAndSafety(
|
EEBoundsAndSafety(
|
||||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||||
max_ee_step_m=0.10,
|
max_ee_step_m=0.10,
|
||||||
@@ -115,13 +114,12 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||||
|
# matches exactly what the pipelines produce at runtime.
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id=HF_REPO_ID,
|
repo_id=HF_REPO_ID,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=combine_feature_dicts(
|
features=combine_feature_dicts(
|
||||||
# Run the feature contract of the pipelines
|
|
||||||
# This tells you how the features would look like after the pipeline steps
|
|
||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=leader_joints_to_ee,
|
pipeline=leader_joints_to_ee,
|
||||||
initial_features=create_initial_features(action=leader.action_features),
|
initial_features=create_initial_features(action=leader.action_features),
|
||||||
@@ -144,7 +142,7 @@ def main():
|
|||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
# Initialize the keyboard listener and rerun visualization
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
init_rerun(session_name="recording_phone")
|
init_rerun(session_name="recording_so100_ee")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not leader.is_connected or not follower.is_connected:
|
if not leader.is_connected or not follower.is_connected:
|
||||||
@@ -160,14 +158,14 @@ def main():
|
|||||||
robot=follower,
|
robot=follower,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=leader_joints_to_ee,
|
||||||
|
robot_action_processor=ee_to_follower_joints,
|
||||||
|
robot_observation_processor=follower_joints_to_ee,
|
||||||
teleop=leader,
|
teleop=leader,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=leader_joints_to_ee,
|
|
||||||
robot_action_processor=ee_to_follower_joints,
|
|
||||||
robot_observation_processor=follower_joints_to_ee,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -179,13 +177,13 @@ def main():
|
|||||||
robot=follower,
|
robot=follower,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=leader_joints_to_ee,
|
||||||
|
robot_action_processor=ee_to_follower_joints,
|
||||||
|
robot_observation_processor=follower_joints_to_ee,
|
||||||
teleop=leader,
|
teleop=leader,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=leader_joints_to_ee,
|
|
||||||
robot_action_processor=ee_to_follower_joints,
|
|
||||||
robot_observation_processor=follower_joints_to_ee,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained EE-space policy on SO100 without recording (base rollout).
|
||||||
|
|
||||||
|
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||||
|
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||||
|
control tick). The custom observation/action processors convert between
|
||||||
|
joint space (robot hardware) and end-effector space (policy I/O) via
|
||||||
|
forward/inverse kinematics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.processor import (
|
||||||
|
RobotProcessorPipeline,
|
||||||
|
observation_to_transition,
|
||||||
|
robot_action_observation_to_transition,
|
||||||
|
transition_to_observation,
|
||||||
|
transition_to_robot_action,
|
||||||
|
)
|
||||||
|
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||||
|
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||||
|
ForwardKinematicsJointsToEE,
|
||||||
|
InverseKinematicsEEToJoints,
|
||||||
|
)
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
|
||||||
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
|
robot_config = SO100FollowerConfig(
|
||||||
|
port="/dev/tty.usbmodem5A460814411",
|
||||||
|
id="my_awesome_follower_arm",
|
||||||
|
cameras=camera_config,
|
||||||
|
use_degrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Kinematic solver: we need the motor-name list, so peek at the robot once.
|
||||||
|
# (The rollout engine owns the connected instance; we only use this for introspection.)
|
||||||
|
temp_robot = SO100Follower(robot_config)
|
||||||
|
motor_names = list(temp_robot.bus.motors.keys())
|
||||||
|
|
||||||
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
|
kinematics_solver = RobotKinematics(
|
||||||
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
|
target_frame_name="gripper_frame_link",
|
||||||
|
joint_names=motor_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Joint-space observation → EE-space observation (consumed by the policy).
|
||||||
|
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
|
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||||
|
to_transition=observation_to_transition,
|
||||||
|
to_output=transition_to_observation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# EE-space action (produced by the policy) → joint-space action (sent to robot).
|
||||||
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
|
steps=[
|
||||||
|
InverseKinematicsEEToJoints(
|
||||||
|
kinematics=kinematics_solver,
|
||||||
|
motor_names=motor_names,
|
||||||
|
initial_guess_current_joints=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
to_transition=robot_action_observation_to_transition,
|
||||||
|
to_output=transition_to_robot_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Policy config (full model is loaded inside build_rollout_context).
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
|
||||||
|
# otherwise skip the joint↔EE conversion and the policy would receive the
|
||||||
|
# wrong observation/action space.
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
|
|||||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||||
from lerobot.policies import SACConfig
|
from lerobot.policies import SACConfig
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||||
from lerobot.rl.buffer import ReplayBuffer
|
from lerobot.rl.buffer import ReplayBuffer
|
||||||
from lerobot.rl.gym_manipulator import make_robot_env
|
from lerobot.rl.gym_manipulator import make_robot_env
|
||||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -22,10 +22,10 @@ def main():
|
|||||||
model_name="microsoft/resnet-18",
|
model_name="microsoft/resnet-18",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make policy, preprocessor, and optimizer
|
# Make reward model, preprocessor, and optimizer
|
||||||
policy = make_policy(config, ds_meta=dataset.meta)
|
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||||
|
|
||||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ def main():
|
|||||||
batch = preprocessor(batch)
|
batch = preprocessor(batch)
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = reward_model.forward(batch)
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -58,8 +58,8 @@ def main():
|
|||||||
|
|
||||||
print("Training finished!")
|
print("Training finished!")
|
||||||
|
|
||||||
# You can now save the trained policy.
|
# You can now save the trained reward model.
|
||||||
policy.push_to_hub(classifier_id)
|
reward_model.push_to_hub(classifier_id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+23
-4
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core ML
|
# Core ML
|
||||||
"torch>=2.7,<2.11.0",
|
"torch>=2.7,<2.12.0",
|
||||||
"torchvision>=0.22.0,<0.26.0",
|
"torchvision>=0.22.0,<0.27.0",
|
||||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||||
"Pillow>=10.0.0,<13.0.0",
|
"Pillow>=10.0.0,<13.0.0",
|
||||||
@@ -99,7 +99,7 @@ dataset = [
|
|||||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"lerobot[av-dep]",
|
"lerobot[av-dep]",
|
||||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
"torchcodec>=0.3.0,<0.12.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
|
||||||
"jsonlines>=4.0.0,<5.0.0",
|
"jsonlines>=4.0.0,<5.0.0",
|
||||||
]
|
]
|
||||||
training = [
|
training = [
|
||||||
@@ -128,7 +128,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
|||||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
@@ -194,6 +194,8 @@ groot = [
|
|||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
@@ -257,6 +259,7 @@ all = [
|
|||||||
"lerobot[smolvla]",
|
"lerobot[smolvla]",
|
||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
"lerobot[xvla]",
|
"lerobot[xvla]",
|
||||||
|
"lerobot[evo1]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
@@ -289,8 +292,23 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
|||||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
|
|
||||||
|
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||||
|
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||||
|
# uv pip install --force-reinstall torch torchvision \
|
||||||
|
# --index-url https://download.pytorch.org/whl/cu130
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu128"
|
||||||
|
url = "https://download.pytorch.org/whl/cu128"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
lerobot = ["envs/*.json"]
|
lerobot = ["envs/*.json"]
|
||||||
|
|
||||||
@@ -332,6 +350,7 @@ ignore = [
|
|||||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||||
|
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
combine-as-imports = true
|
combine-as-imports = true
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
|
|||||||
from .configuration_realsense import RealSenseCameraConfig
|
from .configuration_realsense import RealSenseCameraConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
|
||||||
|
|
||||||
|
|
||||||
class RealSenseCamera(Camera):
|
class RealSenseCamera(Camera):
|
||||||
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
|
|||||||
Args:
|
Args:
|
||||||
config: The configuration settings for the camera.
|
config: The configuration settings for the camera.
|
||||||
"""
|
"""
|
||||||
require_package("pyrealsense2", extra="intelrealsense")
|
require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -41,8 +41,12 @@ def cfg_to_group(
|
|||||||
return tag
|
return tag
|
||||||
return tag[:max_tag_length]
|
return tag[:max_tag_length]
|
||||||
|
|
||||||
|
if cfg.is_reward_model_training:
|
||||||
|
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||||
|
else:
|
||||||
|
trainable_tag = f"policy:{cfg.policy.type}"
|
||||||
lst = [
|
lst = [
|
||||||
f"policy:{cfg.policy.type}",
|
trainable_tag,
|
||||||
f"seed:{cfg.seed}",
|
f"seed:{cfg.seed}",
|
||||||
]
|
]
|
||||||
if cfg.dataset is not None:
|
if cfg.dataset is not None:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
|||||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .dataset import DatasetRecordConfig
|
||||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||||
from .policies import PreTrainedConfig
|
from .policies import PreTrainedConfig
|
||||||
from .types import (
|
from .types import (
|
||||||
@@ -39,6 +40,7 @@ __all__ = [
|
|||||||
"PolicyFeature",
|
"PolicyFeature",
|
||||||
"RTCAttentionSchedule",
|
"RTCAttentionSchedule",
|
||||||
# Config classes
|
# Config classes
|
||||||
|
"DatasetRecordConfig",
|
||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
"EvalConfig",
|
"EvalConfig",
|
||||||
"PeftConfig",
|
"PeftConfig",
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetRecordConfig:
|
||||||
|
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||||
|
repo_id: str = ""
|
||||||
|
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||||
|
single_task: str = ""
|
||||||
|
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||||
|
root: str | Path | None = None
|
||||||
|
# Limit the frames per second.
|
||||||
|
fps: int = 30
|
||||||
|
# Number of seconds for data recording for each episode.
|
||||||
|
episode_time_s: int | float = 60
|
||||||
|
# Number of seconds for resetting the environment after each episode.
|
||||||
|
reset_time_s: int | float = 60
|
||||||
|
# Number of episodes to record.
|
||||||
|
num_episodes: int = 50
|
||||||
|
# Encode frames in the dataset into video
|
||||||
|
video: bool = True
|
||||||
|
# Upload dataset to Hugging Face hub.
|
||||||
|
push_to_hub: bool = True
|
||||||
|
# Upload on private repository on the Hugging Face hub.
|
||||||
|
private: bool = False
|
||||||
|
# Add tags to your dataset on the hub.
|
||||||
|
tags: list[str] | None = None
|
||||||
|
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||||
|
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||||
|
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||||
|
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||||
|
num_image_writer_processes: int = 0
|
||||||
|
# Number of threads writing the frames as png images on disk, per camera.
|
||||||
|
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||||
|
# Not enough threads might cause low camera fps.
|
||||||
|
num_image_writer_threads_per_camera: int = 4
|
||||||
|
# Number of episodes to record before batch encoding videos
|
||||||
|
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||||
|
video_encoding_batch_size: int = 1
|
||||||
|
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||||
|
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||||
|
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||||
|
vcodec: str = "libsvtav1"
|
||||||
|
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||||
|
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||||
|
streaming_encoding: bool = False
|
||||||
|
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||||
|
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||||
|
encoder_queue_maxsize: int = 30
|
||||||
|
# Number of threads per encoder instance. None = auto (codec default).
|
||||||
|
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||||
|
encoder_threads: int | None = None
|
||||||
|
|
||||||
|
def stamp_repo_id(self) -> None:
|
||||||
|
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||||
|
|
||||||
|
Must be called explicitly at dataset *creation* time — not on resume,
|
||||||
|
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||||
|
"""
|
||||||
|
if self.repo_id:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import builtins
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
import draccus
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from huggingface_hub.constants import CONFIG_NAME
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import OptimizerConfig
|
||||||
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
|
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||||
|
from lerobot.utils.hub import HubMixin
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="RewardModelConfig")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||||
|
"""Base configuration for reward models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||||
|
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||||
|
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||||
|
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Reuses PolicyFeature
|
||||||
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
device: str | None = None
|
||||||
|
|
||||||
|
pretrained_path: str | None = None
|
||||||
|
|
||||||
|
push_to_hub: bool = False
|
||||||
|
repo_id: str | None = None
|
||||||
|
|
||||||
|
# Hub metadata
|
||||||
|
license: str | None = None
|
||||||
|
tags: list[str] | None = None
|
||||||
|
private: bool | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
|
auto_device = auto_select_torch_device()
|
||||||
|
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||||
|
self.device = auto_device.type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
choice_name = self.get_choice_name(self.__class__)
|
||||||
|
if not isinstance(choice_name, str):
|
||||||
|
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||||
|
return choice_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
|
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||||
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls: builtins.type[T],
|
||||||
|
pretrained_name_or_path: str | Path,
|
||||||
|
*,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict[Any, Any] | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
cache_dir: str | Path | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
**reward_kwargs: Any,
|
||||||
|
) -> T:
|
||||||
|
model_id = str(pretrained_name_or_path)
|
||||||
|
config_file: str | None = None
|
||||||
|
if Path(model_id).is_dir():
|
||||||
|
if CONFIG_NAME in os.listdir(model_id):
|
||||||
|
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||||
|
else:
|
||||||
|
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
config_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename=CONFIG_NAME,
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
except HfHubHTTPError as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if config_file is None:
|
||||||
|
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||||
|
|
||||||
|
# HACK: Parse the original config to get the config subclass, so that we can
|
||||||
|
# apply cli overrides.
|
||||||
|
with draccus.config_type("json"):
|
||||||
|
orig_config = draccus.parse(cls, config_file, args=[])
|
||||||
|
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
config.pop("type", None)
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
|
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||||
|
with draccus.config_type("json"):
|
||||||
|
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import builtins
|
import builtins
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -26,18 +28,57 @@ from lerobot import envs
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
|
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||||
|
|
||||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||||
from .policies import PreTrainedConfig
|
from .policies import PreTrainedConfig
|
||||||
|
from .rewards import RewardModelConfig
|
||||||
|
|
||||||
TRAIN_CONFIG_NAME = "train_config.json"
|
TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||||
|
legacy_fields = (
|
||||||
|
"use_rabc",
|
||||||
|
"rabc_progress_path",
|
||||||
|
"rabc_kappa",
|
||||||
|
"rabc_epsilon",
|
||||||
|
"rabc_head_mode",
|
||||||
|
)
|
||||||
|
if not any(key in config for key in legacy_fields):
|
||||||
|
return None
|
||||||
|
|
||||||
|
migrated_config = dict(config)
|
||||||
|
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||||
|
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||||
|
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||||
|
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||||
|
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||||
|
|
||||||
|
# New configs may already define sample_weighting explicitly. In that case,
|
||||||
|
# legacy fields are ignored after being stripped from the payload.
|
||||||
|
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||||
|
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||||
|
if rabc_progress_path is not None:
|
||||||
|
sample_weighting["progress_path"] = rabc_progress_path
|
||||||
|
if rabc_kappa is not None:
|
||||||
|
sample_weighting["kappa"] = rabc_kappa
|
||||||
|
if rabc_epsilon is not None:
|
||||||
|
sample_weighting["epsilon"] = rabc_epsilon
|
||||||
|
if rabc_head_mode is not None:
|
||||||
|
sample_weighting["head_mode"] = rabc_head_mode
|
||||||
|
migrated_config["sample_weighting"] = sample_weighting
|
||||||
|
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig
|
dataset: DatasetConfig
|
||||||
env: envs.EnvConfig | None = None
|
env: envs.EnvConfig | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
|
reward_model: RewardModelConfig | None = None
|
||||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||||
output_dir: Path | None = None
|
output_dir: Path | None = None
|
||||||
@@ -72,27 +113,41 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
peft: PeftConfig | None = None
|
peft: PeftConfig | None = None
|
||||||
|
|
||||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
# Sample weighting configuration (e.g., for RA-BC training)
|
||||||
use_rabc: bool = False # Enable reward-weighted training
|
sample_weighting: SampleWeightingConfig | None = None
|
||||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
|
||||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
|
||||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
|
||||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
|
||||||
|
|
||||||
# Rename map for the observation to override the image and state keys
|
# Rename map for the observation to override the image and state keys
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
checkpoint_path: Path | None = field(init=False, default=None)
|
checkpoint_path: Path | None = field(init=False, default=None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_reward_model_training(self) -> bool:
|
||||||
|
"""True when the config targets a reward model rather than a policy."""
|
||||||
|
return self.reward_model is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||||
|
"""Return whichever config (policy or reward_model) is active."""
|
||||||
|
if self.is_reward_model_training:
|
||||||
|
return self.reward_model # type: ignore[return-value]
|
||||||
|
return self.policy # type: ignore[return-value]
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
reward_model_path = parser.get_path_arg("reward_model")
|
||||||
# Only load the policy config
|
|
||||||
|
if reward_model_path:
|
||||||
|
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||||
|
self.reward_model = RewardModelConfig.from_pretrained(
|
||||||
|
reward_model_path, cli_overrides=cli_overrides
|
||||||
|
)
|
||||||
|
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||||
|
elif policy_path:
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = Path(policy_path)
|
self.policy.pretrained_path = Path(policy_path)
|
||||||
elif self.resume:
|
elif self.resume:
|
||||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
|
||||||
config_path = parser.parse_arg("config_path")
|
config_path = parser.parse_arg("config_path")
|
||||||
if not config_path:
|
if not config_path:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -108,18 +163,22 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
policy_dir = Path(config_path).parent
|
policy_dir = Path(config_path).parent
|
||||||
if self.policy is not None:
|
if self.policy is not None:
|
||||||
self.policy.pretrained_path = policy_dir
|
self.policy.pretrained_path = policy_dir
|
||||||
|
if self.reward_model is not None:
|
||||||
|
self.reward_model.pretrained_path = str(policy_dir)
|
||||||
self.checkpoint_path = policy_dir.parent
|
self.checkpoint_path = policy_dir.parent
|
||||||
|
|
||||||
if self.policy is None:
|
if self.policy is None and self.reward_model is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
"Neither policy nor reward_model is configured. "
|
||||||
|
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
active_cfg = self.trainable_config
|
||||||
if not self.job_name:
|
if not self.job_name:
|
||||||
if self.env is None:
|
if self.env is None:
|
||||||
self.job_name = f"{self.policy.type}"
|
self.job_name = f"{active_cfg.type}"
|
||||||
else:
|
else:
|
||||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||||
|
|
||||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||||
raise FileExistsError(
|
raise FileExistsError(
|
||||||
@@ -137,26 +196,16 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||||
elif self.use_policy_training_preset and not self.resume:
|
elif self.use_policy_training_preset and not self.resume:
|
||||||
self.optimizer = self.policy.get_optimizer_preset()
|
self.optimizer = active_cfg.get_optimizer_preset()
|
||||||
self.scheduler = self.policy.get_scheduler_preset()
|
self.scheduler = active_cfg.get_scheduler_preset()
|
||||||
|
|
||||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||||
raise ValueError(
|
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_rabc and not self.rabc_progress_path:
|
|
||||||
# Auto-detect from dataset path
|
|
||||||
repo_id = self.dataset.repo_id
|
|
||||||
if self.dataset.root:
|
|
||||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
|
||||||
else:
|
|
||||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
"""Keys for draccus pretrained-path loading."""
|
||||||
return ["policy"]
|
return ["policy", "reward_model"]
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||||
@@ -207,6 +256,17 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
|
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||||
|
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||||
|
if config_file is not None and config_file.endswith(".json"):
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||||
|
if migrated_config is not None:
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(migrated_config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_args)
|
return draccus.parse(cls, config_file, args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
|||||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||||
|
|
||||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
|||||||
# Clean up temporary columns
|
# Clean up temporary columns
|
||||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||||
|
|
||||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
|||||||
aggr_repo_id: str,
|
aggr_repo_id: str,
|
||||||
roots: list[Path] | None = None,
|
roots: list[Path] | None = None,
|
||||||
aggr_root: Path | None = None,
|
aggr_root: Path | None = None,
|
||||||
data_files_size_in_mb: float | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: float | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
):
|
):
|
||||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
|||||||
# to avoid interference between different source datasets
|
# to avoid interference between different source datasets
|
||||||
data_idx.pop("src_to_dst", None)
|
data_idx.pop("src_to_dst", None)
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info.total_frames += src_meta.total_frames
|
||||||
|
|
||||||
finalize_aggregation(dst_meta, all_metadata)
|
finalize_aggregation(dst_meta, all_metadata)
|
||||||
logging.info("Aggregation complete.")
|
logging.info("Aggregation complete.")
|
||||||
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
|||||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||||
|
|
||||||
logging.info("write info")
|
logging.info("write info")
|
||||||
aggr_meta.info.update(
|
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||||
{
|
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||||
"total_tasks": len(aggr_meta.tasks),
|
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
|
||||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
write_info(aggr_meta.info, aggr_meta.root)
|
write_info(aggr_meta.info, aggr_meta.root)
|
||||||
|
|
||||||
logging.info("write stats")
|
logging.info("write stats")
|
||||||
|
|||||||
@@ -37,13 +37,11 @@ from .io_utils import (
|
|||||||
load_subtasks,
|
load_subtasks,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
write_info,
|
write_info,
|
||||||
write_json,
|
|
||||||
write_stats,
|
write_stats,
|
||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
INFO_PATH,
|
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
get_safe_version,
|
get_safe_version,
|
||||||
has_legacy_hub_download_metadata,
|
has_legacy_hub_download_metadata,
|
||||||
@@ -228,7 +226,7 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def _version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info.codebase_version)
|
||||||
|
|
||||||
def get_data_file_path(self, ep_index: int) -> Path:
|
def get_data_file_path(self, ep_index: int) -> Path:
|
||||||
"""Return the relative parquet file path for the given episode index.
|
"""Return the relative parquet file path for the given episode index.
|
||||||
@@ -283,27 +281,27 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
"""Formattable string for the parquet files."""
|
"""Formattable string for the parquet files."""
|
||||||
return self.info["data_path"]
|
return self.info.data_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video_path(self) -> str | None:
|
def video_path(self) -> str | None:
|
||||||
"""Formattable string for the video files."""
|
"""Formattable string for the video files."""
|
||||||
return self.info["video_path"]
|
return self.info.video_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def robot_type(self) -> str | None:
|
def robot_type(self) -> str | None:
|
||||||
"""Robot type used in recording this dataset."""
|
"""Robot type used in recording this dataset."""
|
||||||
return self.info["robot_type"]
|
return self.info.robot_type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self) -> int:
|
def fps(self) -> int:
|
||||||
"""Frames per second used during data collection."""
|
"""Frames per second used during data collection."""
|
||||||
return self.info["fps"]
|
return self.info.fps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> dict[str, dict]:
|
def features(self) -> dict[str, dict]:
|
||||||
"""All features contained in the dataset."""
|
"""All features contained in the dataset."""
|
||||||
return self.info["features"]
|
return self.info.features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_keys(self) -> list[str]:
|
def image_keys(self) -> list[str]:
|
||||||
@@ -333,32 +331,32 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def total_episodes(self) -> int:
|
def total_episodes(self) -> int:
|
||||||
"""Total number of episodes available."""
|
"""Total number of episodes available."""
|
||||||
return self.info["total_episodes"]
|
return self.info.total_episodes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_frames(self) -> int:
|
def total_frames(self) -> int:
|
||||||
"""Total number of frames saved in this dataset."""
|
"""Total number of frames saved in this dataset."""
|
||||||
return self.info["total_frames"]
|
return self.info.total_frames
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_tasks(self) -> int:
|
def total_tasks(self) -> int:
|
||||||
"""Total number of different tasks performed in this dataset."""
|
"""Total number of different tasks performed in this dataset."""
|
||||||
return self.info["total_tasks"]
|
return self.info.total_tasks
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chunks_size(self) -> int:
|
def chunks_size(self) -> int:
|
||||||
"""Max number of files per chunk."""
|
"""Max number of files per chunk."""
|
||||||
return self.info["chunks_size"]
|
return self.info.chunks_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_files_size_in_mb(self) -> int:
|
def data_files_size_in_mb(self) -> int:
|
||||||
"""Max size of data file in mega bytes."""
|
"""Max size of data file in mega bytes."""
|
||||||
return self.info["data_files_size_in_mb"]
|
return self.info.data_files_size_in_mb
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video_files_size_in_mb(self) -> int:
|
def video_files_size_in_mb(self) -> int:
|
||||||
"""Max size of video file in mega bytes."""
|
"""Max size of video file in mega bytes."""
|
||||||
return self.info["video_files_size_in_mb"]
|
return self.info.video_files_size_in_mb
|
||||||
|
|
||||||
def get_task_index(self, task: str) -> int | None:
|
def get_task_index(self, task: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
@@ -502,10 +500,10 @@ class LeRobotDatasetMetadata:
|
|||||||
self._save_episode_metadata(episode_dict)
|
self._save_episode_metadata(episode_dict)
|
||||||
|
|
||||||
# Update info
|
# Update info
|
||||||
self.info["total_episodes"] += 1
|
self.info.total_episodes += 1
|
||||||
self.info["total_frames"] += episode_length
|
self.info.total_frames += episode_length
|
||||||
self.info["total_tasks"] = len(self.tasks)
|
self.info.total_tasks = len(self.tasks)
|
||||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||||
|
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
@@ -524,7 +522,7 @@ class LeRobotDatasetMetadata:
|
|||||||
for key in video_keys:
|
for key in video_keys:
|
||||||
if not self.features[key].get("info", None):
|
if not self.features[key].get("info", None):
|
||||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
self.info.features[key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
def update_chunk_settings(
|
def update_chunk_settings(
|
||||||
self,
|
self,
|
||||||
@@ -546,17 +544,17 @@ class LeRobotDatasetMetadata:
|
|||||||
if chunks_size is not None:
|
if chunks_size is not None:
|
||||||
if chunks_size <= 0:
|
if chunks_size <= 0:
|
||||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||||
self.info["chunks_size"] = chunks_size
|
self.info.chunks_size = chunks_size
|
||||||
|
|
||||||
if data_files_size_in_mb is not None:
|
if data_files_size_in_mb is not None:
|
||||||
if data_files_size_in_mb <= 0:
|
if data_files_size_in_mb <= 0:
|
||||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||||
|
|
||||||
if video_files_size_in_mb is not None:
|
if video_files_size_in_mb is not None:
|
||||||
if video_files_size_in_mb <= 0:
|
if video_files_size_in_mb <= 0:
|
||||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||||
|
|
||||||
# Update the info file on disk
|
# Update the info file on disk
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
@@ -653,7 +651,7 @@ class LeRobotDatasetMetadata:
|
|||||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||||
)
|
)
|
||||||
write_json(obj.info, obj.root / INFO_PATH)
|
write_info(obj.info, obj.root)
|
||||||
obj.revision = None
|
obj.revision = None
|
||||||
obj._pq_writer = None
|
obj._pq_writer = None
|
||||||
obj.latest_episode = None
|
obj.latest_episode = None
|
||||||
|
|||||||
@@ -897,14 +897,10 @@ def _copy_and_reindex_episodes_metadata(
|
|||||||
|
|
||||||
dst_meta.finalize()
|
dst_meta.finalize()
|
||||||
|
|
||||||
dst_meta.info.update(
|
dst_meta.info.total_episodes = len(episode_mapping)
|
||||||
{
|
dst_meta.info.total_frames = total_frames
|
||||||
"total_episodes": len(episode_mapping),
|
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||||
"total_frames": total_frames,
|
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
|
||||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
if not all_stats:
|
if not all_stats:
|
||||||
@@ -1069,21 +1065,20 @@ def _copy_episodes_metadata_and_stats(
|
|||||||
if episodes_dir.exists():
|
if episodes_dir.exists():
|
||||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||||
|
|
||||||
dst_meta.info.update(
|
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||||
{
|
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||||
"total_episodes": src_dataset.meta.total_episodes,
|
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||||
"total_frames": src_dataset.meta.total_frames,
|
# Preserve original splits if available, otherwise create default
|
||||||
"total_tasks": src_dataset.meta.total_tasks,
|
dst_meta.info.splits = (
|
||||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
src_dataset.meta.info.splits
|
||||||
}
|
if src_dataset.meta.info.splits
|
||||||
|
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||||
)
|
)
|
||||||
|
|
||||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||||
for key in dst_meta.video_keys:
|
for key in dst_meta.video_keys:
|
||||||
if key in src_dataset.meta.features:
|
if key in src_dataset.meta.features:
|
||||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||||
"info", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
@@ -1525,7 +1520,7 @@ def modify_tasks(
|
|||||||
write_tasks(new_task_df, root)
|
write_tasks(new_task_df, root)
|
||||||
|
|
||||||
# Update info.json
|
# Update info.json
|
||||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||||
write_info(dataset.meta.info, root)
|
write_info(dataset.meta.info, root)
|
||||||
|
|
||||||
# Reload metadata to reflect changes
|
# Reload metadata to reflect changes
|
||||||
@@ -1858,10 +1853,10 @@ def convert_image_to_video_dataset(
|
|||||||
episodes_df.to_parquet(episodes_path, index=False)
|
episodes_df.to_parquet(episodes_path, index=False)
|
||||||
|
|
||||||
# Update metadata info
|
# Update metadata info
|
||||||
new_meta.info["total_episodes"] = len(episode_indices)
|
new_meta.info.total_episodes = len(episode_indices)
|
||||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||||
|
|
||||||
# Update video info for all image keys (now videos)
|
# Update video info for all image keys (now videos)
|
||||||
# We need to manually set video info since update_video_info() checks video_keys first
|
# We need to manually set video info since update_video_info() checks video_keys first
|
||||||
@@ -1870,7 +1865,7 @@ def convert_image_to_video_dataset(
|
|||||||
video_path = new_meta.root / new_meta.video_path.format(
|
video_path = new_meta.root / new_meta.video_path.format(
|
||||||
video_key=img_key, chunk_index=0, file_index=0
|
video_key=img_key, chunk_index=0, file_index=0
|
||||||
)
|
)
|
||||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
write_info(new_meta.info, new_meta.root)
|
write_info(new_meta.info, new_meta.root)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from pprint import pformat
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs import PreTrainedConfig
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.transforms import ImageTransforms
|
from lerobot.transforms import ImageTransforms
|
||||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||||
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
|
|||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(
|
def resolve_delta_timestamps(
|
||||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||||
) -> dict[str, list] | None:
|
) -> dict[str, list] | None:
|
||||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||||
|
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||||
|
``{observation,action,reward}_delta_indices`` properties used below.
|
||||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||||
delta_timestamps against.
|
delta_timestamps against.
|
||||||
|
|
||||||
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
ds_meta = LeRobotDatasetMetadata(
|
ds_meta = LeRobotDatasetMetadata(
|
||||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||||
)
|
)
|
||||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||||
if not cfg.dataset.streaming:
|
if not cfg.dataset.streaming:
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .utils import (
|
|||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
|
DatasetInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,8 +79,8 @@ def create_empty_dataset_info(
|
|||||||
chunks_size: int | None = None,
|
chunks_size: int | None = None,
|
||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
) -> dict:
|
) -> DatasetInfo:
|
||||||
"""Create a template dictionary for a new dataset's `info.json`.
|
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
codebase_version (str): The version of the LeRobot codebase.
|
codebase_version (str): The version of the LeRobot codebase.
|
||||||
@@ -87,25 +88,24 @@ def create_empty_dataset_info(
|
|||||||
features (dict): The LeRobot features dictionary for the dataset.
|
features (dict): The LeRobot features dictionary for the dataset.
|
||||||
use_videos (bool): Whether the dataset will store videos.
|
use_videos (bool): Whether the dataset will store videos.
|
||||||
robot_type (str | None): The type of robot used, if any.
|
robot_type (str | None): The type of robot used, if any.
|
||||||
|
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||||
|
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||||
|
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary with the initial dataset metadata.
|
DatasetInfo: A typed dataset information object with initial metadata.
|
||||||
"""
|
"""
|
||||||
return {
|
return DatasetInfo(
|
||||||
"codebase_version": codebase_version,
|
codebase_version=codebase_version,
|
||||||
"robot_type": robot_type,
|
fps=fps,
|
||||||
"total_episodes": 0,
|
features=features,
|
||||||
"total_frames": 0,
|
robot_type=robot_type,
|
||||||
"total_tasks": 0,
|
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
data_path=DEFAULT_DATA_PATH,
|
||||||
"fps": fps,
|
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||||
"splits": {},
|
)
|
||||||
"data_path": DEFAULT_DATA_PATH,
|
|
||||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
|
||||||
"features": features,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def check_delta_timestamps(
|
def check_delta_timestamps(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .utils import (
|
|||||||
EPISODES_DIR,
|
EPISODES_DIR,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
STATS_PATH,
|
STATS_PATH,
|
||||||
|
DatasetInfo,
|
||||||
serialize_dict,
|
serialize_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def write_info(info: dict, local_dir: Path) -> None:
|
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||||
write_json(info, local_dir / INFO_PATH)
|
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||||
|
|
||||||
|
|
||||||
def load_info(local_dir: Path) -> dict:
|
def load_info(local_dir: Path) -> DatasetInfo:
|
||||||
"""Load dataset info metadata from its standard file path.
|
"""Load dataset info metadata from its standard file path.
|
||||||
|
|
||||||
Also converts shape lists to tuples for consistency.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_dir (Path): The root directory of the dataset.
|
local_dir (Path): The root directory of the dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The dataset information dictionary.
|
DatasetInfo: The typed dataset information object.
|
||||||
"""
|
"""
|
||||||
info = load_json(local_dir / INFO_PATH)
|
raw = load_json(local_dir / INFO_PATH)
|
||||||
for ft in info["features"].values():
|
return DatasetInfo.from_dict(raw)
|
||||||
ft["shape"] = tuple(ft["shape"])
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||||
|
|||||||
@@ -630,6 +630,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
streaming_encoding: bool = False,
|
streaming_encoding: bool = False,
|
||||||
encoder_queue_maxsize: int = 30,
|
encoder_queue_maxsize: int = 30,
|
||||||
encoder_threads: int | None = None,
|
encoder_threads: int | None = None,
|
||||||
|
video_files_size_in_mb: int | None = None,
|
||||||
|
data_files_size_in_mb: int | None = None,
|
||||||
) -> "LeRobotDataset":
|
) -> "LeRobotDataset":
|
||||||
"""Create a new LeRobotDataset from scratch for recording data.
|
"""Create a new LeRobotDataset from scratch for recording data.
|
||||||
|
|
||||||
@@ -677,6 +679,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
root=root,
|
root=root,
|
||||||
use_videos=use_videos,
|
use_videos=use_videos,
|
||||||
metadata_buffer_size=metadata_buffer_size,
|
metadata_buffer_size=metadata_buffer_size,
|
||||||
|
video_files_size_in_mb=video_files_size_in_mb,
|
||||||
|
data_files_size_in_mb=data_files_size_in_mb,
|
||||||
)
|
)
|
||||||
obj.repo_id = obj.meta.repo_id
|
obj.repo_id = obj.meta.repo_id
|
||||||
obj._requested_root = obj.meta.root
|
obj._requested_root = obj.meta.root
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||||
"""
|
"""
|
||||||
return self._datasets[0].meta.info["fps"]
|
return self._datasets[0].meta.info.fps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video(self) -> bool:
|
def video(self) -> bool:
|
||||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||||
"""
|
"""
|
||||||
return self._datasets[0].meta.info.get("video", False)
|
return len(self._datasets[0].meta.video_keys) > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> datasets.Features:
|
def features(self) -> datasets.Features:
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
|
|
||||||
def _make_padding_camera_frame(self, camera_key: str):
|
def _make_padding_camera_frame(self, camera_key: str):
|
||||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||||
|
|
||||||
def _get_video_frame_padding_mask(
|
def _get_video_frame_padding_mask(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -14,9 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import dataclasses
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||||
@@ -94,6 +99,123 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
|||||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetInfo:
|
||||||
|
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||||
|
|
||||||
|
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||||
|
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||||
|
explicit field definitions, IDE auto-completion, and validation at
|
||||||
|
construction time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
codebase_version: str
|
||||||
|
fps: int
|
||||||
|
features: dict[str, dict]
|
||||||
|
|
||||||
|
# Episode / frame counters — start at zero for new datasets
|
||||||
|
total_episodes: int = 0
|
||||||
|
total_frames: int = 0
|
||||||
|
total_tasks: int = 0
|
||||||
|
|
||||||
|
# Storage settings
|
||||||
|
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||||
|
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||||
|
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||||
|
|
||||||
|
# File path templates
|
||||||
|
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||||
|
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||||
|
|
||||||
|
# Optional metadata
|
||||||
|
robot_type: str | None = None
|
||||||
|
splits: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||||
|
# returns lists, but the rest of the codebase expects tuples.
|
||||||
|
for ft in self.features.values():
|
||||||
|
if isinstance(ft.get("shape"), list):
|
||||||
|
ft["shape"] = tuple(ft["shape"])
|
||||||
|
|
||||||
|
if self.fps <= 0:
|
||||||
|
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||||
|
if self.chunks_size <= 0:
|
||||||
|
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||||
|
if self.data_files_size_in_mb <= 0:
|
||||||
|
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||||
|
if self.video_files_size_in_mb <= 0:
|
||||||
|
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Return a JSON-serialisable dict.
|
||||||
|
|
||||||
|
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||||
|
"""
|
||||||
|
d = dataclasses.asdict(self)
|
||||||
|
for ft in d["features"].values():
|
||||||
|
if isinstance(ft.get("shape"), tuple):
|
||||||
|
ft["shape"] = list(ft["shape"])
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||||
|
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||||
|
|
||||||
|
Unknown keys are ignored for forward compatibility with datasets that
|
||||||
|
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||||
|
logged when such fields are present.
|
||||||
|
"""
|
||||||
|
known = {f.name for f in dataclasses.fields(cls)}
|
||||||
|
unknown = sorted(k for k in data if k not in known)
|
||||||
|
if unknown:
|
||||||
|
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||||
|
return cls(**{k: v for k, v in data.items() if k in known})
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Temporary dict-style compatibility layer
|
||||||
|
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||||
|
# Once all callers have been migrated to attribute access, remove these.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||||
|
f"Use attribute access info.{key} instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return getattr(self, key)
|
||||||
|
except AttributeError as err:
|
||||||
|
raise KeyError(key) from err
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value) -> None:
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||||
|
f"Use attribute assignment info.{key} = ... instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if not hasattr(self, key):
|
||||||
|
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""Check if a field exists (dict-like interface)."""
|
||||||
|
return hasattr(self, key)
|
||||||
|
|
||||||
|
def get(self, key: str, default=None):
|
||||||
|
"""Get attribute value with default fallback (dict-like interface)."""
|
||||||
|
try:
|
||||||
|
return getattr(self, key)
|
||||||
|
except AttributeError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||||
|
|
||||||
@@ -294,7 +416,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
|||||||
|
|
||||||
def create_lerobot_dataset_card(
|
def create_lerobot_dataset_card(
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
dataset_info: dict | None = None,
|
dataset_info: DatasetInfo | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> DatasetCard:
|
) -> DatasetCard:
|
||||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||||
@@ -305,7 +427,7 @@ def create_lerobot_dataset_card(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tags (list | None): A list of tags to add to the dataset card.
|
tags (list | None): A list of tags to add to the dataset card.
|
||||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||||
be displayed on the card.
|
be displayed on the card.
|
||||||
**kwargs: Additional keyword arguments to populate the card template.
|
**kwargs: Additional keyword arguments to populate the card template.
|
||||||
|
|
||||||
@@ -318,7 +440,7 @@ def create_lerobot_dataset_card(
|
|||||||
card_tags += tags
|
card_tags += tags
|
||||||
if dataset_info:
|
if dataset_info:
|
||||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||||
card_data = DatasetCardData(
|
card_data = DatasetCardData(
|
||||||
license=kwargs.get("license"),
|
license=kwargs.get("license"),
|
||||||
|
|||||||
@@ -282,7 +282,11 @@ class VideoDecoderCache:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
if video_path not in self._cache:
|
if video_path not in self._cache:
|
||||||
file_handle = fsspec.open(video_path).__enter__()
|
file_handle = fsspec.open(video_path).__enter__()
|
||||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
try:
|
||||||
|
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||||
|
except Exception:
|
||||||
|
file_handle.close()
|
||||||
|
raise
|
||||||
self._cache[video_path] = (decoder, file_handle)
|
self._cache[video_path] = (decoder, file_handle)
|
||||||
|
|
||||||
return self._cache[video_path][0]
|
return self._cache[video_path][0]
|
||||||
|
|||||||
@@ -24,7 +24,12 @@ import gymnasium as gym
|
|||||||
from gymnasium.envs.registration import registry as gym_registry
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
from lerobot.configs import FeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PolicyFeature
|
||||||
from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline
|
from lerobot.processor import (
|
||||||
|
IsaaclabArenaProcessorStep,
|
||||||
|
LiberoActionProcessorStep,
|
||||||
|
LiberoProcessorStep,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
)
|
||||||
from lerobot.robots import RobotConfig
|
from lerobot.robots import RobotConfig
|
||||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -123,7 +128,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
|
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
|
||||||
return {self.type: {0: vec}}
|
return {self.type: {0: vec}}
|
||||||
|
|
||||||
def get_env_processors(self):
|
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||||
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||||
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||||
|
|
||||||
@@ -436,10 +441,13 @@ class LiberoEnv(EnvConfig):
|
|||||||
is_libero_plus=self.is_libero_plus,
|
is_libero_plus=self.is_libero_plus,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_env_processors(self):
|
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||||
|
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
|
||||||
|
action_feature = self.features.get(ACTION)
|
||||||
|
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
|
||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
|
||||||
PolicyProcessorPipeline(steps=[]),
|
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -705,7 +713,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
|||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def get_env_processors(self):
|
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||||
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||||
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||||
if not state_keys and not camera_keys:
|
if not state_keys and not camera_keys:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
@@ -52,7 +53,14 @@ def make_env_pre_post_processors(
|
|||||||
|
|
||||||
return make_xvla_libero_pre_post_processors()
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
return env_cfg.get_env_processors()
|
get_processors = env_cfg.get_env_processors
|
||||||
|
signature = inspect.signature(get_processors)
|
||||||
|
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
|
||||||
|
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
|
||||||
|
)
|
||||||
|
if supports_policy_cfg:
|
||||||
|
return get_processors(policy_cfg=policy_cfg)
|
||||||
|
return get_processors()
|
||||||
|
|
||||||
|
|
||||||
def make_env(
|
def make_env(
|
||||||
|
|||||||
@@ -12,8 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
|
||||||
|
|
||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
|
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||||
|
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
|
||||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||||
@@ -21,10 +25,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
|
|||||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||||
from .rtc import ActionInterpolator as ActionInterpolator
|
|
||||||
from .sac.configuration_sac import SACConfig as SACConfig
|
from .sac.configuration_sac import SACConfig as SACConfig
|
||||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
|
||||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
|
||||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
from .utils import make_robot_action, prepare_observation_for_inference
|
from .utils import make_robot_action, prepare_observation_for_inference
|
||||||
@@ -40,14 +41,14 @@ __all__ = [
|
|||||||
# Configuration classes
|
# Configuration classes
|
||||||
"ACTConfig",
|
"ACTConfig",
|
||||||
"DiffusionConfig",
|
"DiffusionConfig",
|
||||||
|
"Evo1Config",
|
||||||
"GrootConfig",
|
"GrootConfig",
|
||||||
"MultiTaskDiTConfig",
|
"MultiTaskDiTConfig",
|
||||||
|
"EO1Config",
|
||||||
"PI0Config",
|
"PI0Config",
|
||||||
"PI0FastConfig",
|
"PI0FastConfig",
|
||||||
"PI05Config",
|
"PI05Config",
|
||||||
"RewardClassifierConfig",
|
|
||||||
"SACConfig",
|
"SACConfig",
|
||||||
"SARMConfig",
|
|
||||||
"SmolVLAConfig",
|
"SmolVLAConfig",
|
||||||
"TDMPCConfig",
|
"TDMPCConfig",
|
||||||
"VQBeTConfig",
|
"VQBeTConfig",
|
||||||
|
|||||||
@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
|
||||||
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
num_valid = valid_mask.sum() * abs_err.shape[-1]
|
||||||
|
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 2
|
n_obs_steps: int = 2
|
||||||
horizon: int = 16
|
horizon: int = 64
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 32
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
crop_ratio: float = 1.0
|
crop_ratio: float = 1.0
|
||||||
crop_shape: tuple[int, int] | None = None
|
crop_shape: tuple[int, int] | None = None
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
pretrained_backbone_weights: str | None = None
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = False
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
use_separate_rgb_encoder_per_camera: bool = False
|
use_separate_rgb_encoder_per_camera: bool = True
|
||||||
# Unet.
|
# Unet.
|
||||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||||
kernel_size: int = 5
|
kernel_size: int = 5
|
||||||
|
|||||||
@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
|
|||||||
f"{self.config.do_mask_loss_for_padding=}."
|
f"{self.config.do_mask_loss_for_padding=}."
|
||||||
)
|
)
|
||||||
in_episode_bound = ~batch["action_is_pad"]
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
mask = in_episode_bound.unsqueeze(-1)
|
||||||
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
../../../../docs/source/eo1.mdx
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from .configuration_eo1 import EO1Config
|
||||||
|
from .modeling_eo1 import EO1Policy
|
||||||
|
from .processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLConfig,
|
||||||
|
Qwen2_5_VLTextConfig,
|
||||||
|
Qwen2_5_VLVisionConfig,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
Qwen2_5_VLConfig = None
|
||||||
|
Qwen2_5_VLTextConfig = None
|
||||||
|
Qwen2_5_VLVisionConfig = None
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("eo1")
|
||||||
|
@dataclass
|
||||||
|
class EO1Config(PreTrainedConfig):
|
||||||
|
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||||
|
|
||||||
|
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
vlm_config: dict | None = None
|
||||||
|
|
||||||
|
# Vision processor settings.
|
||||||
|
image_min_pixels: int | None = 64 * 28 * 28
|
||||||
|
image_max_pixels: int | None = 128 * 28 * 28
|
||||||
|
use_fast_processor: bool = False
|
||||||
|
|
||||||
|
# Execution and action horizon.
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 8
|
||||||
|
n_action_steps: int = 8
|
||||||
|
|
||||||
|
# State/action padding to match EO1 flow head dimensionality.
|
||||||
|
max_state_dim: int = 32
|
||||||
|
max_action_dim: int = 32
|
||||||
|
|
||||||
|
# Flow matching sampling.
|
||||||
|
num_denoise_steps: int = 10
|
||||||
|
num_action_layers: int = 2
|
||||||
|
action_act: str = "linear"
|
||||||
|
time_sampling_beta_alpha: float = 1.5
|
||||||
|
time_sampling_beta_beta: float = 1.0
|
||||||
|
time_sampling_scale: float = 0.999
|
||||||
|
time_sampling_offset: float = 0.001
|
||||||
|
min_period: float = 4e-3
|
||||||
|
max_period: float = 4.0
|
||||||
|
supervise_padding_action_dims: bool = True
|
||||||
|
supervise_padding_actions: bool = True
|
||||||
|
|
||||||
|
# Policy-level dtype request for the Qwen backbone.
|
||||||
|
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||||
|
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||||
|
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||||
|
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||||
|
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||||
|
force_fp32_autocast: bool = True
|
||||||
|
|
||||||
|
# Optional attention backend request passed through to the Qwen backbone.
|
||||||
|
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||||
|
attn_implementation: str | None = None
|
||||||
|
|
||||||
|
# Training settings.
|
||||||
|
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.MEAN_STD,
|
||||||
|
"ACTION": NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||||
|
optimizer_lr: float = 1e-4
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||||
|
optimizer_eps: float = 1e-8
|
||||||
|
optimizer_weight_decay: float = 0.1
|
||||||
|
optimizer_grad_clip_norm: float = 1.0
|
||||||
|
|
||||||
|
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||||
|
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||||
|
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||||
|
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||||
|
scheduler_decay_steps: int = 30_000
|
||||||
|
scheduler_decay_lr: float = 0.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Populate the serialized backbone config only when the caller did not provide one.
|
||||||
|
if self.vlm_config is None:
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
config_dict = deepcopy(self.vlm_config)
|
||||||
|
if self.attn_implementation is not None:
|
||||||
|
config_dict["attn_implementation"] = self.attn_implementation
|
||||||
|
return Qwen2_5_VLConfig(**config_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||||
|
return self.vlm_backbone_config.text_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||||
|
return self.vlm_backbone_config.vision_config
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
"""Validate and set up EO1 input and output features."""
|
||||||
|
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||||
|
if not image_features:
|
||||||
|
raise ValueError(
|
||||||
|
"EO1 policy requires at least one visual input feature. "
|
||||||
|
"No features of type FeatureType.VISUAL found in input_features."
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_STATE not in self.input_features:
|
||||||
|
state_feature = PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(self.max_state_dim,),
|
||||||
|
)
|
||||||
|
self.input_features[OBS_STATE] = state_feature
|
||||||
|
|
||||||
|
if ACTION not in self.output_features:
|
||||||
|
action_feature = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(self.max_action_dim,),
|
||||||
|
)
|
||||||
|
self.output_features[ACTION] = action_feature
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
|
return AdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self):
|
||||||
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.optimizer_lr,
|
||||||
|
decay_lr=self.scheduler_decay_lr,
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
num_decay_steps=self.scheduler_decay_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
@@ -0,0 +1,620 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import deque
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils import torch_compilable_check
|
||||||
|
else:
|
||||||
|
ACT2FN = None
|
||||||
|
Qwen2_5_VLForConditionalGeneration = None
|
||||||
|
torch_compilable_check = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_vector(vector, new_dim):
|
||||||
|
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||||
|
|
||||||
|
Can be (batch_size x sequence_length x features_dimension)
|
||||||
|
or (batch_size x features_dimension)
|
||||||
|
"""
|
||||||
|
if vector.shape[-1] >= new_dim:
|
||||||
|
return vector
|
||||||
|
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||||
|
|
||||||
|
|
||||||
|
class EO1Policy(PreTrainedPolicy):
|
||||||
|
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||||
|
|
||||||
|
config_class = EO1Config
|
||||||
|
name = "eo1"
|
||||||
|
|
||||||
|
def __init__(self, config: EO1Config, **kwargs):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if config.pretrained_path is None:
|
||||||
|
# Initialize from pretrained VLM
|
||||||
|
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
config.vlm_base,
|
||||||
|
dtype=config.dtype,
|
||||||
|
attn_implementation=config.attn_implementation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||||
|
config.vlm_backbone_config,
|
||||||
|
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||||
|
if config.gradient_checkpointing:
|
||||||
|
self.model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
self.model.to(config.device)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||||
|
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
|
state = self.prepare_state(batch[OBS_STATE])
|
||||||
|
actions = self.prepare_action(batch[ACTION])
|
||||||
|
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||||
|
loss = self.model(states=state, action=actions, **model_inputs)
|
||||||
|
|
||||||
|
loss_dict = {"loss": loss.item()}
|
||||||
|
return loss, loss_dict
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
states = self.prepare_state(batch[OBS_STATE])
|
||||||
|
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||||
|
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||||
|
|
||||||
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
return actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
|
def prepare_state(self, state: Tensor) -> Tensor:
|
||||||
|
return pad_vector(state, self.config.max_state_dim)
|
||||||
|
|
||||||
|
def prepare_action(self, action: Tensor) -> Tensor:
|
||||||
|
return pad_vector(action, self.config.max_action_dim)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
if len(self._action_queue) == 0:
|
||||||
|
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||||
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
return self.parameters()
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_dtype(target_dtype, device_type):
|
||||||
|
"""Get a safe dtype for the given device type."""
|
||||||
|
if device_type == "mps" and target_dtype == torch.float64:
|
||||||
|
return torch.float32
|
||||||
|
if device_type == "cpu":
|
||||||
|
# CPU doesn't support bfloat16, use float32 instead
|
||||||
|
if target_dtype == torch.bfloat16:
|
||||||
|
return torch.float32
|
||||||
|
if target_dtype == torch.float64:
|
||||||
|
return torch.float64
|
||||||
|
return target_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||||
|
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
|
) -> Tensor:
|
||||||
|
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||||
|
if dimension % 2 != 0:
|
||||||
|
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||||
|
|
||||||
|
if time.ndim != 1:
|
||||||
|
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||||
|
|
||||||
|
dtype = get_safe_dtype(torch.float64, device.type)
|
||||||
|
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||||
|
period = min_period * (max_period / min_period) ** fraction
|
||||||
|
|
||||||
|
# Compute the outer product
|
||||||
|
scaling_factor = 1.0 / period * 2 * math.pi
|
||||||
|
sin_input = scaling_factor[None, :] * time[:, None]
|
||||||
|
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||||
|
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||||
|
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||||
|
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||||
|
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||||
|
return dist.sample((bsize,)).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||||
|
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_layers: int = 2,
|
||||||
|
activation_layer: str = "linear",
|
||||||
|
bias: bool = True,
|
||||||
|
device: Any = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
):
|
||||||
|
layers = []
|
||||||
|
in_dim = in_channels
|
||||||
|
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||||
|
for hidden_dim in hidden_channels[:-1]:
|
||||||
|
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||||
|
layers.append(ACT2FN[activation_layer])
|
||||||
|
in_dim = hidden_dim
|
||||||
|
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||||
|
super().__init__(*layers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self[0].weight.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class EO1VisionFlowMatchingModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: EO1Config,
|
||||||
|
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||||
|
):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||||
|
self.vlm_backbone = vlm_backbone
|
||||||
|
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||||
|
max_state_dim = config.max_state_dim
|
||||||
|
max_action_dim = config.max_action_dim
|
||||||
|
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_out_proj = EO1VisionActionProjector(
|
||||||
|
self.hidden_size,
|
||||||
|
max_action_dim,
|
||||||
|
config.num_action_layers,
|
||||||
|
config.action_act,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.gradient_checkpointing_enabled = False
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.vlm_backbone.get_input_embeddings()
|
||||||
|
|
||||||
|
def flow_head_autocast_context(self):
|
||||||
|
if self.config.force_fp32_autocast:
|
||||||
|
return torch.autocast(
|
||||||
|
device_type=self.state_proj.weight.device.type,
|
||||||
|
enabled=False,
|
||||||
|
)
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def gradient_checkpointing_enable(self):
|
||||||
|
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||||
|
self.gradient_checkpointing_enabled = True
|
||||||
|
self.vlm_backbone.gradient_checkpointing_enable(
|
||||||
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
|
)
|
||||||
|
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||||
|
|
||||||
|
def gradient_checkpointing_disable(self):
|
||||||
|
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||||
|
self.gradient_checkpointing_enabled = False
|
||||||
|
self.vlm_backbone.gradient_checkpointing_disable()
|
||||||
|
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||||
|
|
||||||
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
|
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||||
|
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
def sample_noise(self, shape, device):
|
||||||
|
noise = torch.normal(
|
||||||
|
mean=0.0,
|
||||||
|
std=1.0,
|
||||||
|
size=shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
return noise
|
||||||
|
|
||||||
|
def sample_time(self, bsize, device):
|
||||||
|
time_beta = sample_beta(
|
||||||
|
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||||
|
)
|
||||||
|
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||||
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
def get_placeholder_mask(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None,
|
||||||
|
inputs_embeds: torch.FloatTensor | None,
|
||||||
|
state_features: torch.FloatTensor | None = None,
|
||||||
|
action_features: torch.FloatTensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||||
|
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||||
|
if input_ids is None:
|
||||||
|
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
|
)
|
||||||
|
special_state_mask = special_state_mask.all(-1)
|
||||||
|
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
|
)
|
||||||
|
special_action_mask = special_action_mask.all(-1)
|
||||||
|
else:
|
||||||
|
special_state_mask = input_ids == state_token_id
|
||||||
|
special_action_mask = input_ids == action_token_id
|
||||||
|
|
||||||
|
n_state_tokens = special_state_mask.sum()
|
||||||
|
special_state_mask = (
|
||||||
|
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
if state_features is not None:
|
||||||
|
torch_compilable_check(
|
||||||
|
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||||
|
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
n_action_tokens = special_action_mask.sum()
|
||||||
|
special_action_mask = (
|
||||||
|
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
if action_features is not None:
|
||||||
|
torch_compilable_check(
|
||||||
|
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||||
|
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return special_state_mask, special_action_mask
|
||||||
|
|
||||||
|
def embed_prefix(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
states: torch.Tensor,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||||
|
|
||||||
|
# Get the input embeddings for the input IDs
|
||||||
|
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||||
|
return self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||||
|
|
||||||
|
# Project the states to the hidden size
|
||||||
|
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||||
|
return self.state_proj(states)
|
||||||
|
|
||||||
|
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||||
|
state_mask, _ = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
state_features=state_embs,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def embed_suffix(
|
||||||
|
self,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
noisy_actions: torch.Tensor,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Embed the suffix"""
|
||||||
|
|
||||||
|
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||||
|
return self.action_in_proj(noisy_actions)
|
||||||
|
|
||||||
|
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||||
|
time_embs = create_sinusoidal_pos_embedding(
|
||||||
|
timestep,
|
||||||
|
self.hidden_size,
|
||||||
|
min_period=self.config.min_period,
|
||||||
|
max_period=self.config.max_period,
|
||||||
|
device=action_embs.device,
|
||||||
|
)
|
||||||
|
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||||
|
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||||
|
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||||
|
|
||||||
|
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||||
|
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||||
|
action_time_embs = F.silu(action_time_embs)
|
||||||
|
return self.action_time_mlp_out(action_time_embs)
|
||||||
|
|
||||||
|
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||||
|
return action_time_embs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None = None,
|
||||||
|
states: torch.FloatTensor | None = None,
|
||||||
|
action: torch.FloatTensor | None = None,
|
||||||
|
action_is_pad: torch.BoolTensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||||
|
|
||||||
|
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||||
|
inputs_embeds = self.embed_prefix(
|
||||||
|
input_ids,
|
||||||
|
states=states,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Sample the diffusion target and replace the action placeholders.
|
||||||
|
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||||
|
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||||
|
|
||||||
|
time_expanded = time[:, None, None]
|
||||||
|
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||||
|
u_t = noise - action
|
||||||
|
action_time_embs = self.embed_suffix(time, x_t)
|
||||||
|
_, action_mask = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
action_features=action_time_embs,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||||
|
|
||||||
|
# 3. Optionally drop padded action tokens from backbone attention.
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
if not self.config.supervise_padding_actions:
|
||||||
|
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||||
|
action_token_mask = action_mask[..., 0]
|
||||||
|
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||||
|
action_padding_mask = action_padding_mask.masked_scatter(
|
||||||
|
action_token_mask,
|
||||||
|
action_is_pad.reshape(-1),
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||||
|
|
||||||
|
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||||
|
def vlm_forward_func(
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
attention_mask: torch.Tensor | None,
|
||||||
|
inputs_embeds: torch.FloatTensor,
|
||||||
|
pixel_values: torch.Tensor | None,
|
||||||
|
image_grid_thw: torch.LongTensor | None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
mm_token_type_ids=mm_token_type_ids,
|
||||||
|
use_cache=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
return outputs.last_hidden_state
|
||||||
|
|
||||||
|
hidden_states = self._apply_checkpoint(
|
||||||
|
vlm_forward_func,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds,
|
||||||
|
pixel_values,
|
||||||
|
image_grid_thw,
|
||||||
|
mm_token_type_ids,
|
||||||
|
)
|
||||||
|
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||||
|
|
||||||
|
# 5. Project the action-token hidden states back to the flow target space.
|
||||||
|
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||||
|
return self.action_out_proj(action_hidden_states)
|
||||||
|
|
||||||
|
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||||
|
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||||
|
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
|
# 6. Apply the configured supervision mask and reduce the loss.
|
||||||
|
if not self.config.supervise_padding_action_dims:
|
||||||
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
losses = losses[..., :original_action_dim]
|
||||||
|
|
||||||
|
if not self.config.supervise_padding_actions:
|
||||||
|
losses = losses[~action_is_pad]
|
||||||
|
|
||||||
|
return losses.mean()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_actions(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
pixel_values: torch.Tensor | None = None,
|
||||||
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None = None,
|
||||||
|
states: torch.Tensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Sample actions from the model."""
|
||||||
|
if states is None:
|
||||||
|
raise ValueError("states are required for EO1 action sampling.")
|
||||||
|
if mm_token_type_ids is None:
|
||||||
|
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||||
|
|
||||||
|
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||||
|
chunk_size = self.config.chunk_size
|
||||||
|
|
||||||
|
inputs_embeds = self.embed_prefix(
|
||||||
|
input_ids,
|
||||||
|
states=states,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
).clone()
|
||||||
|
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
action_mask = action_placeholder_mask[..., 0]
|
||||||
|
token_counts = action_mask.sum(dim=1)
|
||||||
|
if not torch.all(token_counts == chunk_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||||
|
)
|
||||||
|
if action_mask.ne(action_mask[:1]).any():
|
||||||
|
raise ValueError(
|
||||||
|
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||||
|
)
|
||||||
|
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||||
|
act_end = act_start + self.config.chunk_size
|
||||||
|
if not torch.all(action_mask[:, act_start:act_end]):
|
||||||
|
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||||
|
act_slice = slice(act_start, act_end)
|
||||||
|
|
||||||
|
# 2. Encode the fixed prefix once and cache its KV state.
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
device = inputs_embeds.device
|
||||||
|
attention_mask = attention_mask.to(device)
|
||||||
|
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||||
|
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||||
|
input_ids,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
mm_token_type_ids=mm_token_type_ids,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.to(device)
|
||||||
|
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
input_ids=input_ids[:, :act_start],
|
||||||
|
attention_mask=attention_mask[:, :act_start],
|
||||||
|
position_ids=position_ids[..., :act_start],
|
||||||
|
inputs_embeds=inputs_embeds[:, :act_start],
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||||
|
use_cache=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_t = self.sample_noise(
|
||||||
|
(batch_size, chunk_size, self.config.max_action_dim),
|
||||||
|
device,
|
||||||
|
).to(dtype=self.action_in_proj.weight.dtype)
|
||||||
|
dt = -1.0 / self.config.num_denoise_steps
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||||
|
for step in range(self.config.num_denoise_steps):
|
||||||
|
time = torch.full(
|
||||||
|
(batch_size,),
|
||||||
|
1.0 + step * dt,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
action_time_embs = self.embed_suffix(time, x_t)
|
||||||
|
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||||
|
|
||||||
|
# Keep the prefix KV cache invariant across denoising steps.
|
||||||
|
past_key_values.crop(act_start)
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
attention_mask=attention_mask[:, :act_end],
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds[:, act_slice],
|
||||||
|
position_ids=position_ids[..., act_slice],
|
||||||
|
use_cache=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||||
|
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||||
|
v_t = self.action_out_proj(hidden_states)
|
||||||
|
|
||||||
|
x_t += dt * v_t.reshape(x_t.shape)
|
||||||
|
|
||||||
|
return x_t
|
||||||
@@ -0,0 +1,282 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
ComplementaryDataProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.types import TransitionKey
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
OBS_STATE,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
)
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||||
|
else:
|
||||||
|
Qwen2_5_VLProcessor = None
|
||||||
|
|
||||||
|
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||||
|
|
||||||
|
# EO-1 special tokens
|
||||||
|
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||||
|
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||||
|
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||||
|
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||||
|
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||||
|
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||||
|
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||||
|
|
||||||
|
EO1_SPECIAL_TOKENS = [
|
||||||
|
ACTION_START_TOKEN,
|
||||||
|
DEFAULT_ACTION_TOKEN,
|
||||||
|
ACTION_END_TOKEN,
|
||||||
|
STATE_START_TOKEN,
|
||||||
|
DEFAULT_STATE_TOKEN,
|
||||||
|
STATE_END_TOKEN,
|
||||||
|
TASK_VLA_TOKEN,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||||
|
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||||
|
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||||
|
chunk_size: int
|
||||||
|
|
||||||
|
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Robust JSON deserialization handling (guard empty maps).
|
||||||
|
if self.input_features:
|
||||||
|
first_val = next(iter(self.input_features.values()))
|
||||||
|
if isinstance(first_val, dict):
|
||||||
|
reconstructed = {}
|
||||||
|
for key, ft_dict in self.input_features.items():
|
||||||
|
reconstructed[key] = PolicyFeature(
|
||||||
|
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||||
|
)
|
||||||
|
self.input_features = reconstructed
|
||||||
|
|
||||||
|
self._image_keys = [
|
||||||
|
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||||
|
]
|
||||||
|
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
tasks = complementary_data.get("task")
|
||||||
|
if tasks is None:
|
||||||
|
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||||
|
|
||||||
|
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||||
|
|
||||||
|
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||||
|
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||||
|
|
||||||
|
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||||
|
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||||
|
images = {
|
||||||
|
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||||
|
}
|
||||||
|
messages = []
|
||||||
|
for i in range(len(tasks)):
|
||||||
|
content = [
|
||||||
|
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": (
|
||||||
|
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
messages.append(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
complementary_data["messages"] = messages
|
||||||
|
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
This step only materializes EO1-specific message objects in complementary_data.
|
||||||
|
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||||
|
feature contract change to record here.
|
||||||
|
"""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"input_features": {
|
||||||
|
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||||
|
},
|
||||||
|
"chunk_size": self.chunk_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||||
|
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||||
|
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
image_min_pixels: int | None = 64 * 28 * 28
|
||||||
|
image_max_pixels: int | None = 128 * 28 * 28
|
||||||
|
use_fast_processor: bool = False
|
||||||
|
|
||||||
|
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||||
|
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||||
|
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||||
|
self.processor_name,
|
||||||
|
use_fast=self.use_fast_processor,
|
||||||
|
)
|
||||||
|
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||||
|
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||||
|
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||||
|
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
messages = complementary_data.pop("messages", None)
|
||||||
|
if messages is None:
|
||||||
|
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||||
|
|
||||||
|
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||||
|
# Supervised batches use right padding to match standard training collation.
|
||||||
|
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||||
|
|
||||||
|
inputs = self._processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
padding=True,
|
||||||
|
padding_side=padding_side,
|
||||||
|
min_pixels=self.image_min_pixels,
|
||||||
|
max_pixels=self.image_max_pixels,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
complementary_data["input_ids"] = inputs["input_ids"]
|
||||||
|
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||||
|
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||||
|
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||||
|
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||||
|
complementary_data["state_token_id"] = self._state_token_id
|
||||||
|
complementary_data["action_token_id"] = self._action_token_id
|
||||||
|
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"processor_name": self.processor_name,
|
||||||
|
"image_min_pixels": self.image_min_pixels,
|
||||||
|
"image_max_pixels": self.image_max_pixels,
|
||||||
|
"use_fast_processor": self.use_fast_processor,
|
||||||
|
}
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
This step only converts the messages to the model input format.
|
||||||
|
"""
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def make_eo1_pre_post_processors(
|
||||||
|
config: EO1Config,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""Build pre/post processor pipelines for EO1."""
|
||||||
|
|
||||||
|
input_steps: list[ProcessorStep] = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features={**config.input_features, **config.output_features},
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||||
|
EO1QwenProcessorStep(
|
||||||
|
processor_name=config.vlm_base,
|
||||||
|
image_min_pixels=config.image_min_pixels,
|
||||||
|
image_max_pixels=config.image_max_pixels,
|
||||||
|
use_fast_processor=config.use_fast_processor,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
]
|
||||||
|
|
||||||
|
output_steps: list[ProcessorStep] = [
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=config.output_features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device="cpu"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
../../../../docs/source/policy_evo1_README.md
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_evo1 import Evo1Config
|
||||||
|
from .modeling_evo1 import EVO1Policy
|
||||||
|
from .processor_evo1 import make_evo1_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]
|
||||||
@@ -0,0 +1,225 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
|
@LRSchedulerConfig.register_subclass("evo1_exact")
|
||||||
|
@dataclass
|
||||||
|
class Evo1SchedulerConfig(LRSchedulerConfig):
|
||||||
|
num_warmup_steps: int
|
||||||
|
|
||||||
|
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||||
|
def lr_lambda(current_step: int) -> float:
|
||||||
|
if current_step < self.num_warmup_steps:
|
||||||
|
return current_step / max(1, self.num_warmup_steps)
|
||||||
|
progress = (current_step - self.num_warmup_steps) / max(
|
||||||
|
1, num_training_steps - self.num_warmup_steps
|
||||||
|
)
|
||||||
|
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||||
|
|
||||||
|
return LambdaLR(optimizer, lr_lambda, -1)
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("evo1")
|
||||||
|
@dataclass
|
||||||
|
class Evo1Config(PreTrainedConfig):
|
||||||
|
training_stage: str = "stage1"
|
||||||
|
use_amp: bool = True
|
||||||
|
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 50
|
||||||
|
n_action_steps: int = 50
|
||||||
|
|
||||||
|
max_state_dim: int = 24
|
||||||
|
max_action_dim: int = 24
|
||||||
|
max_views: int = 3
|
||||||
|
image_resolution: tuple[int, int] = (448, 448)
|
||||||
|
empty_cameras: int = 0
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.MIN_MAX,
|
||||||
|
"ACTION": NormalizationMode.MIN_MAX,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
|
||||||
|
vlm_num_layers: int | None = 14
|
||||||
|
vlm_dtype: str = "bfloat16"
|
||||||
|
use_flash_attn: bool = True
|
||||||
|
action_head: str = "flowmatching"
|
||||||
|
embed_dim: int = 896
|
||||||
|
hidden_dim: int = 1024
|
||||||
|
state_hidden_dim: int = 1024
|
||||||
|
num_heads: int = 8
|
||||||
|
num_layers: int = 8
|
||||||
|
dropout: float = 0.0
|
||||||
|
num_inference_timesteps: int = 32
|
||||||
|
num_categories: int = 1
|
||||||
|
return_cls_only: bool = False
|
||||||
|
enable_gradient_checkpointing: bool = True
|
||||||
|
gradient_checkpointing_use_reentrant: bool = False
|
||||||
|
|
||||||
|
finetune_vlm: bool | None = None
|
||||||
|
finetune_language_model: bool | None = None
|
||||||
|
finetune_vision_model: bool | None = None
|
||||||
|
finetune_action_head: bool | None = None
|
||||||
|
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
|
||||||
|
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
|
||||||
|
apply_training_stage_defaults: bool = True
|
||||||
|
|
||||||
|
task_field: str = "task"
|
||||||
|
embodiment_id_field: str | None = None
|
||||||
|
default_embodiment_id: int = 0
|
||||||
|
|
||||||
|
optimizer_lr: float = 1e-5
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||||
|
optimizer_eps: float = 1e-8
|
||||||
|
optimizer_weight_decay: float = 1e-5
|
||||||
|
optimizer_grad_clip_norm: float = 1.0
|
||||||
|
|
||||||
|
scheduler_warmup_steps: int = 300
|
||||||
|
drop_last: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.training_stage not in {"stage1", "stage2"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.apply_training_stage_defaults:
|
||||||
|
if self.training_stage == "stage1":
|
||||||
|
self.finetune_vlm = False
|
||||||
|
self.finetune_language_model = False
|
||||||
|
self.finetune_vision_model = False
|
||||||
|
self.finetune_action_head = True
|
||||||
|
elif self.training_stage == "stage2":
|
||||||
|
self.finetune_vlm = True
|
||||||
|
self.finetune_language_model = True
|
||||||
|
self.finetune_vision_model = True
|
||||||
|
self.finetune_action_head = True
|
||||||
|
elif self.training_stage == "stage1":
|
||||||
|
if self.finetune_vlm is None:
|
||||||
|
self.finetune_vlm = False
|
||||||
|
if self.finetune_language_model is None:
|
||||||
|
self.finetune_language_model = False
|
||||||
|
if self.finetune_vision_model is None:
|
||||||
|
self.finetune_vision_model = False
|
||||||
|
if self.finetune_action_head is None:
|
||||||
|
self.finetune_action_head = True
|
||||||
|
elif self.training_stage == "stage2":
|
||||||
|
has_explicit_branch_flags = any(
|
||||||
|
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
|
||||||
|
)
|
||||||
|
if not has_explicit_branch_flags:
|
||||||
|
if self.finetune_vlm is None:
|
||||||
|
self.finetune_vlm = True
|
||||||
|
if self.finetune_language_model is None:
|
||||||
|
self.finetune_language_model = True
|
||||||
|
if self.finetune_vision_model is None:
|
||||||
|
self.finetune_vision_model = True
|
||||||
|
elif self.finetune_vlm is None:
|
||||||
|
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
|
||||||
|
if self.finetune_action_head is None:
|
||||||
|
self.finetune_action_head = True
|
||||||
|
|
||||||
|
if self.finetune_vlm is None:
|
||||||
|
self.finetune_vlm = False
|
||||||
|
if self.finetune_language_model is None:
|
||||||
|
self.finetune_language_model = False
|
||||||
|
if self.finetune_vision_model is None:
|
||||||
|
self.finetune_vision_model = False
|
||||||
|
if self.finetune_action_head is None:
|
||||||
|
self.finetune_action_head = False
|
||||||
|
|
||||||
|
branch_vlm = self.finetune_language_model or self.finetune_vision_model
|
||||||
|
if self.finetune_vlm != branch_vlm:
|
||||||
|
raise ValueError(
|
||||||
|
"Inconsistent EVO1 finetune config: "
|
||||||
|
f"finetune_vlm={self.finetune_vlm} but "
|
||||||
|
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
|
||||||
|
"When branch-level flags are used, finetune_vlm must match their effective union."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
if self.input_features is None:
|
||||||
|
self.input_features = {}
|
||||||
|
if self.output_features is None:
|
||||||
|
self.output_features = {}
|
||||||
|
|
||||||
|
for i in range(self.empty_cameras):
|
||||||
|
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||||
|
if key not in self.input_features:
|
||||||
|
self.input_features[key] = PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, *self.image_resolution),
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_STATE not in self.input_features:
|
||||||
|
self.input_features[OBS_STATE] = PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(self.max_state_dim,),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ACTION not in self.output_features:
|
||||||
|
self.output_features[ACTION] = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(self.max_action_dim,),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
|
return AdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self):
|
||||||
|
return Evo1SchedulerConfig(
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list[int]:
|
||||||
|
return [0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
@@ -0,0 +1,234 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||||
|
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
|
||||||
|
|
||||||
|
|
||||||
|
def _cfgget(config: Any, key: str, default=None):
|
||||||
|
if isinstance(config, dict):
|
||||||
|
return config.get(key, default)
|
||||||
|
return getattr(config, key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class EVO1(nn.Module):
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self._device = _cfgget(config, "device", "cuda")
|
||||||
|
self.return_cls_only = _cfgget(config, "return_cls_only", False)
|
||||||
|
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
|
||||||
|
image_size = _cfgget(config, "image_size", 448)
|
||||||
|
if image_size is None:
|
||||||
|
image_resolution = _cfgget(config, "image_resolution", (448, 448))
|
||||||
|
image_size = int(image_resolution[0])
|
||||||
|
|
||||||
|
self.embedder = InternVL3Embedder(
|
||||||
|
model_name=vlm_name,
|
||||||
|
image_size=image_size,
|
||||||
|
device=self._device,
|
||||||
|
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
|
||||||
|
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
|
||||||
|
use_flash_attn=_cfgget(config, "use_flash_attn", True),
|
||||||
|
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
|
||||||
|
gradient_checkpointing_use_reentrant=_cfgget(
|
||||||
|
config, "gradient_checkpointing_use_reentrant", False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
|
||||||
|
if action_head_type != "flowmatching":
|
||||||
|
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
|
||||||
|
|
||||||
|
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
|
||||||
|
per_action_dim = _cfgget(config, "per_action_dim", 7)
|
||||||
|
action_dim = horizon * per_action_dim
|
||||||
|
|
||||||
|
if isinstance(config, dict):
|
||||||
|
config["horizon"] = horizon
|
||||||
|
config["per_action_dim"] = per_action_dim
|
||||||
|
config["action_dim"] = action_dim
|
||||||
|
|
||||||
|
self.horizon = horizon
|
||||||
|
self.per_action_dim = per_action_dim
|
||||||
|
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
|
||||||
|
|
||||||
|
def _normalize_image_batches(
|
||||||
|
self,
|
||||||
|
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||||
|
prompt: str | list[str] | None,
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
|
||||||
|
if not images:
|
||||||
|
raise ValueError("EVO1 expects at least one image per sample.")
|
||||||
|
|
||||||
|
first = images[0]
|
||||||
|
if isinstance(first, (Image.Image, torch.Tensor)):
|
||||||
|
image_batches = [list(images)] # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
|
||||||
|
|
||||||
|
batch_size = len(image_batches)
|
||||||
|
if prompt is None:
|
||||||
|
prompts = [""] * batch_size
|
||||||
|
elif isinstance(prompt, str):
|
||||||
|
prompts = [prompt] * batch_size
|
||||||
|
else:
|
||||||
|
prompts = [str(p) for p in prompt]
|
||||||
|
if len(prompts) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_mask.dim() == 1:
|
||||||
|
image_mask = image_mask.unsqueeze(0)
|
||||||
|
if image_mask.shape[0] != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_batches, prompts, image_mask
|
||||||
|
|
||||||
|
def get_vl_embeddings(
|
||||||
|
self,
|
||||||
|
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
prompt: str | list[str] | None = None,
|
||||||
|
return_cls_only: bool | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if return_cls_only is None:
|
||||||
|
return_cls_only = self.return_cls_only
|
||||||
|
|
||||||
|
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
|
||||||
|
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
|
||||||
|
image_tensors_batch=image_batches,
|
||||||
|
image_masks=image_mask,
|
||||||
|
text_prompts=prompts,
|
||||||
|
return_cls_only=return_cls_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
|
||||||
|
if isinstance(state_input, list):
|
||||||
|
state_tensor = torch.tensor(state_input)
|
||||||
|
elif isinstance(state_input, torch.Tensor):
|
||||||
|
state_tensor = state_input
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported state input type: {type(state_input)}")
|
||||||
|
|
||||||
|
if state_tensor.ndim == 1:
|
||||||
|
state_tensor = state_tensor.unsqueeze(0)
|
||||||
|
|
||||||
|
return state_tensor.to(self._device)
|
||||||
|
|
||||||
|
def predict_action(
|
||||||
|
self,
|
||||||
|
fused_tokens: torch.Tensor,
|
||||||
|
state: torch.Tensor,
|
||||||
|
actions_gt: torch.Tensor | None = None,
|
||||||
|
action_mask: torch.Tensor | None = None,
|
||||||
|
embodiment_ids: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
if actions_gt is None:
|
||||||
|
return self.action_head.get_action(
|
||||||
|
fused_tokens,
|
||||||
|
state=state,
|
||||||
|
action_mask=action_mask,
|
||||||
|
embodiment_id=embodiment_ids,
|
||||||
|
)
|
||||||
|
return self.action_head(
|
||||||
|
fused_tokens,
|
||||||
|
state=state,
|
||||||
|
actions_gt=actions_gt,
|
||||||
|
action_mask=action_mask,
|
||||||
|
embodiment_id=embodiment_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_inference(
|
||||||
|
self,
|
||||||
|
images: list[Image.Image | torch.Tensor],
|
||||||
|
image_mask: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
state_input: list | torch.Tensor,
|
||||||
|
return_cls_only: bool | None = None,
|
||||||
|
action_mask: torch.Tensor | None = None,
|
||||||
|
embodiment_ids: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if image_mask.dim() == 1:
|
||||||
|
image_mask = image_mask.unsqueeze(0)
|
||||||
|
|
||||||
|
fused_tokens = self.get_vl_embeddings(
|
||||||
|
images=[images],
|
||||||
|
image_mask=image_mask,
|
||||||
|
prompt=[prompt],
|
||||||
|
return_cls_only=return_cls_only,
|
||||||
|
)
|
||||||
|
state_tensor = self.prepare_state(state_input)
|
||||||
|
action = self.predict_action(
|
||||||
|
fused_tokens,
|
||||||
|
state_tensor,
|
||||||
|
action_mask=action_mask,
|
||||||
|
embodiment_ids=embodiment_ids,
|
||||||
|
)
|
||||||
|
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
|
||||||
|
action = action.to(torch.float32)
|
||||||
|
return action
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
fused_tokens: torch.Tensor,
|
||||||
|
state: torch.Tensor | None = None,
|
||||||
|
actions_gt: torch.Tensor | None = None,
|
||||||
|
action_mask: torch.Tensor | None = None,
|
||||||
|
embodiment_ids: torch.Tensor | None = None,
|
||||||
|
):
|
||||||
|
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
|
||||||
|
|
||||||
|
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||||
|
for param in module.parameters():
|
||||||
|
param.requires_grad = trainable
|
||||||
|
|
||||||
|
def set_finetune_flags(self):
|
||||||
|
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
|
||||||
|
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
|
||||||
|
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
|
||||||
|
has_explicit_branch_flags = any(
|
||||||
|
flag is not None for flag in (finetune_language_model, finetune_vision_model)
|
||||||
|
)
|
||||||
|
finetune_language_model = bool(finetune_language_model)
|
||||||
|
finetune_vision_model = bool(finetune_vision_model)
|
||||||
|
finetune_vlm = bool(finetune_vlm)
|
||||||
|
|
||||||
|
if has_explicit_branch_flags:
|
||||||
|
self._set_module_trainable(self.embedder, False)
|
||||||
|
if hasattr(self.embedder.model, "language_model"):
|
||||||
|
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
|
||||||
|
if hasattr(self.embedder.model, "vision_model"):
|
||||||
|
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
|
||||||
|
if hasattr(self.embedder.model, "mlp1"):
|
||||||
|
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
|
||||||
|
elif not finetune_vlm:
|
||||||
|
self._set_module_trainable(self.embedder, False)
|
||||||
|
|
||||||
|
if not _cfgget(self.config, "finetune_action_head", False):
|
||||||
|
self._set_module_trainable(self.action_head, False)
|
||||||
@@ -0,0 +1,456 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _cfgget(config, key: str, default=None):
|
||||||
|
if isinstance(config, dict):
|
||||||
|
return config.get(key, default)
|
||||||
|
return getattr(config, key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPositionalEncoding(nn.Module):
|
||||||
|
def __init__(self, dim: int, max_len: int = 1000):
|
||||||
|
super().__init__()
|
||||||
|
pe = torch.zeros(max_len, dim)
|
||||||
|
position = torch.arange(0, max_len).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
|
||||||
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
pe = pe.unsqueeze(0)
|
||||||
|
self.register_buffer("pe", pe)
|
||||||
|
|
||||||
|
def forward(self, seq_len: int):
|
||||||
|
if seq_len > self.pe.size(1):
|
||||||
|
self._extend_pe(seq_len)
|
||||||
|
return self.pe[:, :seq_len, :]
|
||||||
|
|
||||||
|
def _extend_pe(self, new_max_len):
|
||||||
|
old_max_len, dim = self.pe.size(1), self.pe.size(2)
|
||||||
|
if new_max_len <= old_max_len:
|
||||||
|
return
|
||||||
|
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||||
|
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
|
||||||
|
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
|
||||||
|
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
|
||||||
|
extra_pe = extra_pe.unsqueeze(0)
|
||||||
|
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
|
||||||
|
self.pe = new_pe
|
||||||
|
|
||||||
|
|
||||||
|
class CategorySpecificLinear(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.num_categories = num_categories
|
||||||
|
if num_categories <= 1:
|
||||||
|
self.linear = nn.Linear(in_dim, out_dim)
|
||||||
|
else:
|
||||||
|
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
|
||||||
|
nn.init.xavier_uniform_(self.weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||||
|
if self.num_categories <= 1:
|
||||||
|
if x.dtype != self.linear.weight.dtype:
|
||||||
|
x = x.to(dtype=self.linear.weight.dtype)
|
||||||
|
return self.linear(x)
|
||||||
|
|
||||||
|
if x.dtype != self.weight.dtype:
|
||||||
|
x = x.to(dtype=self.weight.dtype)
|
||||||
|
|
||||||
|
orig_shape = x.shape
|
||||||
|
x_flat = x.reshape(-1, orig_shape[-1])
|
||||||
|
if category_id.dim() == 0:
|
||||||
|
cid = category_id.item()
|
||||||
|
out = x_flat @ self.weight[cid] + self.bias[cid]
|
||||||
|
else:
|
||||||
|
category_id = category_id.reshape(-1)
|
||||||
|
if category_id.numel() != x_flat.size(0):
|
||||||
|
raise ValueError(
|
||||||
|
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
|
||||||
|
)
|
||||||
|
weight_selected = self.weight[category_id]
|
||||||
|
bias_selected = self.bias[category_id]
|
||||||
|
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
|
||||||
|
out_shape = orig_shape[:-1] + (out.shape[-1],)
|
||||||
|
return out.view(out_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class CategorySpecificMLP(nn.Module):
|
||||||
|
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
|
||||||
|
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||||
|
out = self.activation(self.fc1(x, category_id))
|
||||||
|
out = self.fc2(out, category_id)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MultiEmbodimentActionEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.horizon = horizon
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_categories = num_categories
|
||||||
|
|
||||||
|
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
|
||||||
|
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
|
||||||
|
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
|
||||||
|
|
||||||
|
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
|
||||||
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
|
||||||
|
batch_size, horizon, action_dim = action_seq.shape
|
||||||
|
assert self.horizon == horizon, "Action sequence length must match horizon"
|
||||||
|
|
||||||
|
x = action_seq.reshape(batch_size * horizon, action_dim)
|
||||||
|
if category_id.dim() == 0:
|
||||||
|
cat_ids = category_id.expand(horizon * batch_size)
|
||||||
|
else:
|
||||||
|
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
|
||||||
|
|
||||||
|
out = self.activation(self.W1(x, cat_ids))
|
||||||
|
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
|
||||||
|
out = out.view(batch_size, horizon, -1) + pos_enc
|
||||||
|
out = out.view(batch_size * horizon, -1)
|
||||||
|
out = self.activation(self.W2(out, cat_ids))
|
||||||
|
out = self.W3(out, cat_ids)
|
||||||
|
return out.view(batch_size, horizon, self.embed_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
|
||||||
|
self.norm1 = nn.LayerNorm(embed_dim)
|
||||||
|
self.norm2 = nn.LayerNorm(embed_dim)
|
||||||
|
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
|
||||||
|
|
||||||
|
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
|
||||||
|
x = self.norm1(action_tokens)
|
||||||
|
attn_out, _ = self.attn(x, context_tokens, context_tokens)
|
||||||
|
x = action_tokens + attn_out
|
||||||
|
x2 = self.norm2(x)
|
||||||
|
if time_emb is not None:
|
||||||
|
x2 = x2 + time_emb.unsqueeze(1)
|
||||||
|
ff_out = self.ff(x2)
|
||||||
|
return x + ff_out
|
||||||
|
|
||||||
|
|
||||||
|
class FlowmatchingActionHead(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config=None,
|
||||||
|
embed_dim: int = 896,
|
||||||
|
hidden_dim: int = 1024,
|
||||||
|
action_dim: int = 16 * 7,
|
||||||
|
horizon: int = 16,
|
||||||
|
per_action_dim: int = 7,
|
||||||
|
num_heads: int = 8,
|
||||||
|
num_layers: int = 8,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
num_inference_timesteps: int = 20,
|
||||||
|
num_categories: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
embed_dim = _cfgget(config, "embed_dim", embed_dim)
|
||||||
|
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
|
||||||
|
action_dim = _cfgget(config, "action_dim", action_dim)
|
||||||
|
horizon = _cfgget(config, "horizon", horizon)
|
||||||
|
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
|
||||||
|
num_heads = _cfgget(config, "num_heads", num_heads)
|
||||||
|
num_layers = _cfgget(config, "num_layers", num_layers)
|
||||||
|
dropout = _cfgget(config, "dropout", dropout)
|
||||||
|
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
|
||||||
|
num_categories = _cfgget(config, "num_categories", num_categories)
|
||||||
|
self.config = config
|
||||||
|
else:
|
||||||
|
self.config = SimpleNamespace(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
|
horizon=horizon,
|
||||||
|
per_action_dim=per_action_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_layers=num_layers,
|
||||||
|
dropout=dropout,
|
||||||
|
num_inference_timesteps=num_inference_timesteps,
|
||||||
|
num_categories=num_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.horizon = horizon
|
||||||
|
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
|
||||||
|
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
|
||||||
|
|
||||||
|
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
hidden_dim=embed_dim * 4,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm_out = nn.LayerNorm(embed_dim)
|
||||||
|
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
|
||||||
|
self.mlp_head = CategorySpecificMLP(
|
||||||
|
input_dim=embed_dim,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
num_categories=num_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.state_encoder = None
|
||||||
|
state_dim = _cfgget(self.config, "state_dim")
|
||||||
|
if state_dim is not None:
|
||||||
|
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
|
||||||
|
self.state_encoder = CategorySpecificMLP(
|
||||||
|
input_dim=state_dim,
|
||||||
|
hidden_dim=state_hidden,
|
||||||
|
output_dim=embed_dim,
|
||||||
|
num_categories=num_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
if horizon > 1:
|
||||||
|
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||||
|
action_dim=self.per_action_dim,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
hidden_dim=embed_dim,
|
||||||
|
horizon=horizon,
|
||||||
|
num_categories=num_categories,
|
||||||
|
)
|
||||||
|
self.single_action_proj = None
|
||||||
|
else:
|
||||||
|
self.action_encoder = None
|
||||||
|
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
|
||||||
|
|
||||||
|
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
|
||||||
|
if self.horizon > 1 and self.action_encoder is not None:
|
||||||
|
return self.action_encoder(action_seq, embodiment_id)
|
||||||
|
if self.single_action_proj is None:
|
||||||
|
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
|
||||||
|
return self.single_action_proj(action_seq)
|
||||||
|
|
||||||
|
def _expand_action_mask(
|
||||||
|
self,
|
||||||
|
action_mask: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
per_action_dim: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if action_mask is None:
|
||||||
|
raise ValueError("action_mask must be provided for flow matching inference.")
|
||||||
|
|
||||||
|
if action_mask.dim() == 2:
|
||||||
|
expected_last_dim = self.horizon * per_action_dim
|
||||||
|
if action_mask.shape == (batch_size, expected_last_dim):
|
||||||
|
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
|
||||||
|
elif action_mask.shape == (batch_size, per_action_dim):
|
||||||
|
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
|
||||||
|
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
|
||||||
|
)
|
||||||
|
elif action_mask.dim() == 3:
|
||||||
|
expected_shape = (batch_size, self.horizon, per_action_dim)
|
||||||
|
if tuple(action_mask.shape) != expected_shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
|
||||||
|
)
|
||||||
|
expanded_mask = action_mask
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
|
||||||
|
|
||||||
|
return expanded_mask.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
fused_tokens: torch.Tensor,
|
||||||
|
state: torch.Tensor = None,
|
||||||
|
actions_gt: torch.Tensor = None,
|
||||||
|
embodiment_id: torch.LongTensor = None,
|
||||||
|
state_mask: torch.Tensor = None,
|
||||||
|
action_mask: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
if actions_gt is None:
|
||||||
|
return self.get_action(
|
||||||
|
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = fused_tokens.size(0)
|
||||||
|
device = fused_tokens.device
|
||||||
|
if embodiment_id is None:
|
||||||
|
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
context_tokens = fused_tokens
|
||||||
|
if state is not None and self.state_encoder is not None:
|
||||||
|
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||||
|
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||||
|
|
||||||
|
t = (
|
||||||
|
torch.distributions.Beta(2, 2)
|
||||||
|
.sample((batch_size,))
|
||||||
|
.clamp(0.02, 0.98)
|
||||||
|
.to(device)
|
||||||
|
.to(dtype=self.dtype)
|
||||||
|
)
|
||||||
|
time_index = (t * 999).long().clamp_(0, 999)
|
||||||
|
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
|
||||||
|
|
||||||
|
actions_gt_seq = actions_gt
|
||||||
|
noise = torch.rand_like(actions_gt) * 2 - 1
|
||||||
|
if action_mask is not None:
|
||||||
|
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
|
||||||
|
if action_mask.shape != noise.shape:
|
||||||
|
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
|
||||||
|
actions_gt_seq = actions_gt_seq * action_mask
|
||||||
|
noise = noise * action_mask
|
||||||
|
|
||||||
|
if self.horizon > 1:
|
||||||
|
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
|
||||||
|
else:
|
||||||
|
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
|
||||||
|
t_broadcast = t.view(batch_size, 1, 1)
|
||||||
|
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
|
||||||
|
|
||||||
|
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
|
||||||
|
target_dtype = self.dtype
|
||||||
|
action_tokens = action_tokens.to(dtype=target_dtype)
|
||||||
|
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||||
|
time_emb = time_emb.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
x = action_tokens
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, context_tokens, time_emb)
|
||||||
|
x = self.norm_out(x)
|
||||||
|
|
||||||
|
if self.horizon > 1:
|
||||||
|
x_flat = x.reshape(batch_size, -1)
|
||||||
|
x_pooled = self.seq_pool_proj(x_flat)
|
||||||
|
else:
|
||||||
|
x_pooled = x.squeeze(1)
|
||||||
|
|
||||||
|
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
|
||||||
|
return pred_velocity, noise
|
||||||
|
|
||||||
|
def get_action(
|
||||||
|
self,
|
||||||
|
fused_tokens: torch.Tensor,
|
||||||
|
state: torch.Tensor = None,
|
||||||
|
embodiment_id: torch.LongTensor = None,
|
||||||
|
action_mask: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
batch_size = fused_tokens.size(0)
|
||||||
|
device = fused_tokens.device
|
||||||
|
if embodiment_id is None:
|
||||||
|
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
context_tokens = fused_tokens
|
||||||
|
if state is not None and self.state_encoder is not None:
|
||||||
|
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||||
|
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||||
|
|
||||||
|
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
|
||||||
|
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
|
||||||
|
|
||||||
|
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
|
||||||
|
action_seq = (
|
||||||
|
action.view(batch_size, self.horizon, per_action_dim)
|
||||||
|
if self.horizon > 1
|
||||||
|
else action.view(batch_size, 1, per_action_dim)
|
||||||
|
)
|
||||||
|
action_mask = self._expand_action_mask(
|
||||||
|
action_mask,
|
||||||
|
batch_size=batch_size,
|
||||||
|
per_action_dim=per_action_dim,
|
||||||
|
device=action_seq.device,
|
||||||
|
dtype=action_seq.dtype,
|
||||||
|
)
|
||||||
|
action_seq = action_seq * action_mask
|
||||||
|
|
||||||
|
target_dtype = self.dtype
|
||||||
|
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
|
||||||
|
if num_steps <= 0:
|
||||||
|
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
|
||||||
|
dt = 1.0 / num_steps
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
t = i / num_steps
|
||||||
|
time_index = min(int(t * 999), 999)
|
||||||
|
time_emb = (
|
||||||
|
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
|
||||||
|
)
|
||||||
|
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
|
||||||
|
|
||||||
|
action_seq = action_seq * action_mask
|
||||||
|
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
|
||||||
|
time_emb = time_emb.to(dtype=target_dtype)
|
||||||
|
|
||||||
|
x = action_tokens
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, context_tokens, time_emb)
|
||||||
|
x = self.norm_out(x)
|
||||||
|
|
||||||
|
if self.horizon > 1:
|
||||||
|
x_flat = x.reshape(batch_size, -1)
|
||||||
|
x_pooled = self.seq_pool_proj(x_flat)
|
||||||
|
else:
|
||||||
|
x_pooled = x.squeeze(1)
|
||||||
|
|
||||||
|
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||||
|
action = action + dt * pred
|
||||||
|
action_seq = (
|
||||||
|
action.view(batch_size, self.horizon, per_action_dim)
|
||||||
|
if self.horizon > 1
|
||||||
|
else action.view(batch_size, 1, per_action_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
action_seq = action_seq * action_mask
|
||||||
|
return action_seq.reshape(batch_size, -1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return next(self.parameters()).dtype
|
||||||
@@ -0,0 +1,435 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import types
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms.functional import to_pil_image
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
else:
|
||||||
|
AutoModel = None
|
||||||
|
AutoTokenizer = None
|
||||||
|
|
||||||
|
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||||
|
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||||
|
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
|
||||||
|
IMG_START_TOKEN = "<img>" # nosec B105
|
||||||
|
IMG_END_TOKEN = "</img>" # nosec B105
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
|
||||||
|
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
|
||||||
|
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||||
|
return
|
||||||
|
|
||||||
|
original_forward = encoder.forward
|
||||||
|
|
||||||
|
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
|
||||||
|
original_checkpoint = torch.utils.checkpoint.checkpoint
|
||||||
|
|
||||||
|
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
|
||||||
|
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
|
||||||
|
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
|
||||||
|
|
||||||
|
torch.utils.checkpoint.checkpoint = checkpoint
|
||||||
|
try:
|
||||||
|
return original_forward(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
torch.utils.checkpoint.checkpoint = original_checkpoint
|
||||||
|
|
||||||
|
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||||
|
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
|
||||||
|
encoder._evo1_checkpoint_patch_applied = True
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_is_available() -> bool:
|
||||||
|
try:
|
||||||
|
import flash_attn # noqa: F401
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _internvl_transformers5_load_compatibility():
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
original_linspace = torch.linspace
|
||||||
|
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
|
||||||
|
|
||||||
|
def linspace(*args, **kwargs):
|
||||||
|
if kwargs.get("device") is None:
|
||||||
|
kwargs["device"] = torch.device("cpu")
|
||||||
|
return original_linspace(*args, **kwargs)
|
||||||
|
|
||||||
|
def mark_tied_weights_as_initialized(self, loading_info):
|
||||||
|
if not hasattr(self, "all_tied_weights_keys"):
|
||||||
|
self.all_tied_weights_keys = {}
|
||||||
|
return original_mark_tied(self, loading_info)
|
||||||
|
|
||||||
|
torch.linspace = linspace
|
||||||
|
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.linspace = original_linspace
|
||||||
|
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=10000)
|
||||||
|
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||||
|
aspect_ratio = orig_width / orig_height
|
||||||
|
target_ratios = {
|
||||||
|
(i, j)
|
||||||
|
for n in range(min_num, max_num + 1)
|
||||||
|
for i in range(1, n + 1)
|
||||||
|
for j in range(1, n + 1)
|
||||||
|
if i * j <= max_num and i * j >= min_num
|
||||||
|
}
|
||||||
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||||
|
|
||||||
|
best_ratio_diff = float("inf")
|
||||||
|
best_ratio = (1, 1)
|
||||||
|
area = orig_width * orig_height
|
||||||
|
for ratio in target_ratios:
|
||||||
|
target_ar = ratio[0] / ratio[1]
|
||||||
|
diff = abs(aspect_ratio - target_ar)
|
||||||
|
if diff < best_ratio_diff:
|
||||||
|
best_ratio_diff = diff
|
||||||
|
best_ratio = ratio
|
||||||
|
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
|
||||||
|
best_ratio = ratio
|
||||||
|
return best_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
|
||||||
|
orig_width, orig_height = image.size
|
||||||
|
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
|
||||||
|
target_width = image_size * ratio_w
|
||||||
|
target_height = image_size * ratio_h
|
||||||
|
blocks = ratio_w * ratio_h
|
||||||
|
resized_img = image.resize((target_width, target_height))
|
||||||
|
processed_images = []
|
||||||
|
for i in range(blocks):
|
||||||
|
box = (
|
||||||
|
(i % (target_width // image_size)) * image_size,
|
||||||
|
(i // (target_width // image_size)) * image_size,
|
||||||
|
((i % (target_width // image_size)) + 1) * image_size,
|
||||||
|
((i // (target_width // image_size)) + 1) * image_size,
|
||||||
|
)
|
||||||
|
processed_images.append(resized_img.crop(box))
|
||||||
|
if use_thumbnail and len(processed_images) != 1:
|
||||||
|
processed_images.append(image.resize((image_size, image_size)))
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
|
||||||
|
class InternVL3Embedder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name="OpenGVLab/InternVL3-1B",
|
||||||
|
image_size=448,
|
||||||
|
device="cuda",
|
||||||
|
num_language_layers: int | None = 14,
|
||||||
|
model_dtype: str | torch.dtype = "bfloat16",
|
||||||
|
use_flash_attn: bool = True,
|
||||||
|
enable_gradient_checkpointing: bool = True,
|
||||||
|
gradient_checkpointing_use_reentrant: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._requested_device = device
|
||||||
|
self.image_size = image_size
|
||||||
|
self.num_language_layers = num_language_layers
|
||||||
|
self.max_text_length = 1024
|
||||||
|
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
|
||||||
|
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
|
||||||
|
|
||||||
|
require_package("transformers", extra="evo1")
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
|
||||||
|
if isinstance(model_dtype, str):
|
||||||
|
try:
|
||||||
|
model_dtype = getattr(torch, model_dtype)
|
||||||
|
except AttributeError as exc:
|
||||||
|
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
|
||||||
|
|
||||||
|
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
|
||||||
|
if use_flash_attn and not resolved_use_flash_attn:
|
||||||
|
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||||
|
|
||||||
|
# InternVL3 remote code predates Transformers 5 post-init conventions:
|
||||||
|
# it computes stochastic-depth scalars via torch.linspace(...).item()
|
||||||
|
# while Transformers initializes under torch.device("meta"), and it
|
||||||
|
# does not populate all_tied_weights_keys before loading finalization.
|
||||||
|
with _internvl_transformers5_load_compatibility():
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=model_dtype,
|
||||||
|
trust_remote_code=True,
|
||||||
|
use_flash_attn=resolved_use_flash_attn,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
_fast_init=False,
|
||||||
|
).to(self._requested_device)
|
||||||
|
|
||||||
|
if hasattr(self.model.language_model, "model"):
|
||||||
|
layers = self.model.language_model.model.layers
|
||||||
|
else:
|
||||||
|
layers = self.model.language_model.layers
|
||||||
|
if self.num_language_layers is not None:
|
||||||
|
layers = layers[: self.num_language_layers]
|
||||||
|
|
||||||
|
if hasattr(self.model.language_model, "model"):
|
||||||
|
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
|
||||||
|
else:
|
||||||
|
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||||
|
self.model.language_model.lm_head = torch.nn.Identity()
|
||||||
|
|
||||||
|
self._configure_memory_features()
|
||||||
|
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||||
|
|
||||||
|
def _configure_memory_features(self) -> None:
|
||||||
|
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
|
||||||
|
|
||||||
|
if not self.enable_gradient_checkpointing:
|
||||||
|
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||||
|
self.model.vision_model.encoder.gradient_checkpointing = False
|
||||||
|
language_model = getattr(self.model, "language_model", None)
|
||||||
|
if language_model is not None:
|
||||||
|
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||||
|
language_model.gradient_checkpointing_disable()
|
||||||
|
elif hasattr(language_model, "gradient_checkpointing"):
|
||||||
|
language_model.gradient_checkpointing = False
|
||||||
|
if hasattr(language_model, "model"):
|
||||||
|
inner = language_model.model
|
||||||
|
if hasattr(inner, "gradient_checkpointing_disable"):
|
||||||
|
inner.gradient_checkpointing_disable()
|
||||||
|
elif hasattr(inner, "gradient_checkpointing"):
|
||||||
|
inner.gradient_checkpointing = False
|
||||||
|
return
|
||||||
|
|
||||||
|
def _enable_ckpt(module: nn.Module | None) -> bool:
|
||||||
|
if module is None:
|
||||||
|
return False
|
||||||
|
if hasattr(module, "gradient_checkpointing_enable"):
|
||||||
|
try:
|
||||||
|
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
|
||||||
|
except TypeError:
|
||||||
|
module.gradient_checkpointing_enable()
|
||||||
|
return True
|
||||||
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
|
module.gradient_checkpointing = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
enabled_any = _enable_ckpt(self.model)
|
||||||
|
|
||||||
|
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||||
|
encoder = self.model.vision_model.encoder
|
||||||
|
encoder.gradient_checkpointing = True
|
||||||
|
_patch_vision_encoder_checkpointing(
|
||||||
|
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
|
||||||
|
)
|
||||||
|
enabled_any = True
|
||||||
|
|
||||||
|
language_model = getattr(self.model, "language_model", None)
|
||||||
|
if language_model is not None:
|
||||||
|
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||||
|
if hasattr(language_model, "model"):
|
||||||
|
enabled_any = _enable_ckpt(language_model.model) or enabled_any
|
||||||
|
if hasattr(language_model, "config"):
|
||||||
|
language_model.config.use_cache = False
|
||||||
|
|
||||||
|
if hasattr(self.model, "config"):
|
||||||
|
self.model.config.use_cache = False
|
||||||
|
if hasattr(self.model, "enable_input_require_grads"):
|
||||||
|
self.model.enable_input_require_grads()
|
||||||
|
|
||||||
|
if enabled_any:
|
||||||
|
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Requested gradient checkpointing, but model does not expose checkpointing controls."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
|
||||||
|
if isinstance(image, torch.Tensor):
|
||||||
|
pil_image = to_pil_image(image.detach().cpu())
|
||||||
|
else:
|
||||||
|
pil_image = image.convert("RGB")
|
||||||
|
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
|
||||||
|
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
|
||||||
|
device=self.device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||||
|
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||||
|
return (tile_tensors - mean) / std
|
||||||
|
|
||||||
|
def _preprocess_images(
|
||||||
|
self,
|
||||||
|
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||||
|
) -> tuple[torch.Tensor, list[list[int]]]:
|
||||||
|
pixel_values_list = []
|
||||||
|
batch_num_tiles_list: list[list[int]] = []
|
||||||
|
|
||||||
|
for image_tensors in image_tensors_batch:
|
||||||
|
num_tiles_list: list[int] = []
|
||||||
|
for image in image_tensors:
|
||||||
|
tiles = self._preprocess_single_image(image)
|
||||||
|
pixel_values_list.append(tiles)
|
||||||
|
num_tiles_list.append(int(tiles.shape[0]))
|
||||||
|
batch_num_tiles_list.append(num_tiles_list)
|
||||||
|
|
||||||
|
if pixel_values_list:
|
||||||
|
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||||
|
else:
|
||||||
|
pixel_values = torch.empty(
|
||||||
|
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
|
||||||
|
)
|
||||||
|
return pixel_values, batch_num_tiles_list
|
||||||
|
|
||||||
|
def _build_multimodal_prompts(
|
||||||
|
self,
|
||||||
|
batch_num_tiles_list: list[list[int]],
|
||||||
|
text_prompts: Sequence[str],
|
||||||
|
) -> list[str]:
|
||||||
|
prompts = []
|
||||||
|
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
|
||||||
|
prompt_segments = []
|
||||||
|
for i, tile_count in enumerate(num_tiles_list):
|
||||||
|
token_count = self.model.num_image_token * tile_count
|
||||||
|
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
|
||||||
|
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
|
||||||
|
prompts.append("".join(prompt_segments) + text_prompt.strip())
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def _prepare_and_fuse_embeddings(
|
||||||
|
self,
|
||||||
|
prompts: Sequence[str],
|
||||||
|
vit_embeds: torch.Tensor,
|
||||||
|
image_masks: torch.Tensor,
|
||||||
|
batch_num_tiles_list: list[list[int]],
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
|
||||||
|
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
|
||||||
|
if true_sequence_length > self.max_text_length:
|
||||||
|
logger.warning(
|
||||||
|
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
|
||||||
|
self.max_text_length,
|
||||||
|
true_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_inputs = self.tokenizer(
|
||||||
|
list(prompts),
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_text_length,
|
||||||
|
).to(self.device)
|
||||||
|
input_ids = model_inputs["input_ids"]
|
||||||
|
attention_mask = model_inputs["attention_mask"]
|
||||||
|
|
||||||
|
img_token_mask = input_ids == self.img_context_token_id
|
||||||
|
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
|
||||||
|
|
||||||
|
batch_size, _, channels = input_embeds.shape
|
||||||
|
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
|
||||||
|
tokens_per_tile = self.model.num_image_token
|
||||||
|
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
|
||||||
|
|
||||||
|
vit_idx = 0
|
||||||
|
for batch_index in range(batch_size):
|
||||||
|
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
|
||||||
|
mask_b = img_token_mask[batch_index]
|
||||||
|
actual_vis_tokens = actual_vis_tokens_list[batch_index]
|
||||||
|
|
||||||
|
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
|
||||||
|
vit_idx += expected_vis_tokens
|
||||||
|
if actual_vis_tokens > 0:
|
||||||
|
if item_vit_embeds.shape[0] < actual_vis_tokens:
|
||||||
|
raise ValueError(
|
||||||
|
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
|
||||||
|
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
|
||||||
|
)
|
||||||
|
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
|
||||||
|
|
||||||
|
current_token_idx = 0
|
||||||
|
img_token_locations = torch.where(mask_b)[0]
|
||||||
|
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||||
|
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||||
|
if not bool(image_masks[batch_index, image_index].item()):
|
||||||
|
start_offset = current_token_idx
|
||||||
|
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||||
|
if start_offset < end_offset:
|
||||||
|
idxs = img_token_locations[start_offset:end_offset]
|
||||||
|
attention_mask[batch_index, idxs] = 0
|
||||||
|
current_token_idx += num_tokens_for_image
|
||||||
|
|
||||||
|
return input_embeds, attention_mask
|
||||||
|
|
||||||
|
def get_fused_image_text_embedding_from_tensor_images(
|
||||||
|
self,
|
||||||
|
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||||
|
image_masks: torch.Tensor,
|
||||||
|
text_prompts: Sequence[str],
|
||||||
|
return_cls_only: bool = True,
|
||||||
|
):
|
||||||
|
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
|
||||||
|
if pixel_values.shape[0] == 0:
|
||||||
|
logger.warning("InternVL3 received an empty image batch after preprocessing.")
|
||||||
|
hidden_size = getattr(self.model.config, "hidden_size", None)
|
||||||
|
if hidden_size is None and hasattr(self.model.language_model, "config"):
|
||||||
|
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
|
||||||
|
if hidden_size is None:
|
||||||
|
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
|
||||||
|
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
|
||||||
|
return empty
|
||||||
|
|
||||||
|
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
|
||||||
|
vit_embeds = self.model.extract_feature(pixel_values)
|
||||||
|
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
|
||||||
|
prompts,
|
||||||
|
vit_embeds,
|
||||||
|
image_masks.to(device=self.device),
|
||||||
|
batch_num_tiles_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = self.model.language_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
use_cache=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
|
||||||
|
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return next(self.model.parameters()).device
|
||||||
@@ -0,0 +1,450 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
from collections import deque
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||||
|
from lerobot.policies.evo1.evo1_model import EVO1
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
|
class EVO1Policy(PreTrainedPolicy):
|
||||||
|
config_class = Evo1Config
|
||||||
|
name = "evo1"
|
||||||
|
|
||||||
|
def __init__(self, config: Evo1Config, **kwargs):
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
|
||||||
|
if len(config.image_features) > config.max_views:
|
||||||
|
raise ValueError(
|
||||||
|
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.model = EVO1(self._build_model_config(config))
|
||||||
|
self.model.set_finetune_flags()
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls: builtins.type[T],
|
||||||
|
pretrained_name_or_path: str | Path,
|
||||||
|
*,
|
||||||
|
config: PreTrainedConfig | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
cache_dir: str | Path | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
strict: bool | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> T:
|
||||||
|
if strict is None:
|
||||||
|
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
|
||||||
|
return super().from_pretrained(
|
||||||
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
|
config=config,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
strict=strict,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_model_config(config: Evo1Config) -> dict:
|
||||||
|
return {
|
||||||
|
"device": config.device,
|
||||||
|
"return_cls_only": config.return_cls_only,
|
||||||
|
"vlm_name": config.vlm_model_name,
|
||||||
|
"vlm_num_layers": config.vlm_num_layers,
|
||||||
|
"vlm_dtype": config.vlm_dtype,
|
||||||
|
"use_flash_attn": config.use_flash_attn,
|
||||||
|
"action_head": config.action_head,
|
||||||
|
"action_horizon": config.chunk_size,
|
||||||
|
"per_action_dim": config.max_action_dim,
|
||||||
|
"state_dim": config.max_state_dim,
|
||||||
|
"embed_dim": config.embed_dim,
|
||||||
|
"hidden_dim": config.hidden_dim,
|
||||||
|
"state_hidden_dim": config.state_hidden_dim,
|
||||||
|
"num_heads": config.num_heads,
|
||||||
|
"num_layers": config.num_layers,
|
||||||
|
"dropout": config.dropout,
|
||||||
|
"num_inference_timesteps": config.num_inference_timesteps,
|
||||||
|
"num_categories": config.num_categories,
|
||||||
|
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
|
||||||
|
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
|
||||||
|
"finetune_vlm": config.finetune_vlm,
|
||||||
|
"finetune_language_model": config.finetune_language_model,
|
||||||
|
"finetune_vision_model": config.finetune_vision_model,
|
||||||
|
"finetune_action_head": config.finetune_action_head,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _camera_keys(self) -> list[str]:
|
||||||
|
return list(self.config.image_features)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _env_action_dim(self) -> int:
|
||||||
|
action_feature = self.config.action_feature
|
||||||
|
if action_feature is None:
|
||||||
|
return self.config.max_action_dim
|
||||||
|
return int(action_feature.shape[0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _compute_dtype(self) -> torch.dtype:
|
||||||
|
return next(self.model.action_head.parameters()).dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _training_compute_dtype(self) -> torch.dtype:
|
||||||
|
if str(self.config.device).startswith("cuda"):
|
||||||
|
return torch.bfloat16
|
||||||
|
return self._compute_dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _inference_compute_dtype(self) -> torch.dtype:
|
||||||
|
if str(self.config.device).startswith("cuda") and self.config.use_amp:
|
||||||
|
return torch.bfloat16
|
||||||
|
return self._compute_dtype
|
||||||
|
|
||||||
|
def get_optim_params(self) -> list[dict]:
|
||||||
|
decay, no_decay = [], []
|
||||||
|
for name, param in self.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
is_bias = name.endswith("bias") or ".bias" in name
|
||||||
|
is_norm = param.dim() == 1 or "norm" in name.lower()
|
||||||
|
if is_bias or is_norm:
|
||||||
|
no_decay.append(param)
|
||||||
|
else:
|
||||||
|
decay.append(param)
|
||||||
|
return [
|
||||||
|
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
|
||||||
|
{"params": no_decay, "weight_decay": 0.0},
|
||||||
|
]
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
|
||||||
|
prompts = batch.get(self.config.task_field)
|
||||||
|
if prompts is None and self.config.task_field != "task":
|
||||||
|
prompts = batch.get("task")
|
||||||
|
if prompts is None:
|
||||||
|
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
|
||||||
|
if isinstance(prompts, str):
|
||||||
|
return [prompts]
|
||||||
|
if isinstance(prompts, (list, tuple)):
|
||||||
|
return [str(prompt) for prompt in prompts]
|
||||||
|
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
|
||||||
|
|
||||||
|
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||||
|
if OBS_STATE not in batch:
|
||||||
|
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
|
||||||
|
state = batch[OBS_STATE]
|
||||||
|
if state.dim() == 1:
|
||||||
|
state = state.unsqueeze(0)
|
||||||
|
elif state.dim() == 3:
|
||||||
|
state = state[:, -1]
|
||||||
|
elif state.dim() != 2:
|
||||||
|
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
|
||||||
|
batch_size, state_dim = state.shape
|
||||||
|
if state_dim > self.config.max_state_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
|
||||||
|
)
|
||||||
|
explicit_mask = batch.get("state_mask")
|
||||||
|
if explicit_mask is not None:
|
||||||
|
if explicit_mask.dim() == 1:
|
||||||
|
explicit_mask = explicit_mask.unsqueeze(0)
|
||||||
|
elif explicit_mask.dim() == 3:
|
||||||
|
explicit_mask = explicit_mask[:, -1]
|
||||||
|
elif explicit_mask.dim() != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||||
|
)
|
||||||
|
if explicit_mask.shape != (batch_size, state_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
|
||||||
|
)
|
||||||
|
padded = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.config.max_state_dim,
|
||||||
|
dtype=state.dtype,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
padded[:, :state_dim] = state.to(device=self.config.device)
|
||||||
|
mask = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.config.max_state_dim,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
if explicit_mask is None:
|
||||||
|
mask[:, :state_dim] = True
|
||||||
|
else:
|
||||||
|
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||||
|
return padded.to(dtype=self._compute_dtype), mask
|
||||||
|
|
||||||
|
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||||
|
if ACTION not in batch:
|
||||||
|
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
|
||||||
|
action = batch[ACTION]
|
||||||
|
if action.dim() == 2:
|
||||||
|
action = action.unsqueeze(1)
|
||||||
|
batch_size, horizon, action_dim = action.shape
|
||||||
|
if horizon != self.config.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
|
||||||
|
)
|
||||||
|
if action_dim > self.config.max_action_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
|
||||||
|
)
|
||||||
|
explicit_mask = batch.get("action_mask")
|
||||||
|
if explicit_mask is not None:
|
||||||
|
if explicit_mask.dim() == 2:
|
||||||
|
if horizon == 1:
|
||||||
|
explicit_mask = explicit_mask.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
|
||||||
|
)
|
||||||
|
elif explicit_mask.dim() != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||||
|
)
|
||||||
|
if explicit_mask.shape != (batch_size, horizon, action_dim):
|
||||||
|
raise ValueError(
|
||||||
|
"action_mask shape "
|
||||||
|
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
|
||||||
|
)
|
||||||
|
padded = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
horizon,
|
||||||
|
self.config.max_action_dim,
|
||||||
|
dtype=action.dtype,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
padded[:, :, :action_dim] = action.to(device=self.config.device)
|
||||||
|
mask = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
horizon,
|
||||||
|
self.config.max_action_dim,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
if explicit_mask is None:
|
||||||
|
mask[:, :, :action_dim] = True
|
||||||
|
else:
|
||||||
|
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||||
|
return padded.to(dtype=self._compute_dtype), mask
|
||||||
|
|
||||||
|
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
|
||||||
|
mask = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.config.max_action_dim,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
mask[:, : self._env_action_dim] = True
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
|
||||||
|
embodiment_ids = batch.get("embodiment_id")
|
||||||
|
if embodiment_ids is None and self.config.embodiment_id_field:
|
||||||
|
embodiment_ids = batch.get(self.config.embodiment_id_field)
|
||||||
|
if embodiment_ids is None:
|
||||||
|
return torch.full(
|
||||||
|
(batch_size,),
|
||||||
|
self.config.default_embodiment_id,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
if embodiment_ids.dim() == 0:
|
||||||
|
embodiment_ids = embodiment_ids.unsqueeze(0)
|
||||||
|
elif embodiment_ids.dim() > 1:
|
||||||
|
embodiment_ids = embodiment_ids[:, -1]
|
||||||
|
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _tracks_vlm_gradients(self) -> bool:
|
||||||
|
return bool(
|
||||||
|
self.config.finetune_vlm
|
||||||
|
or self.config.finetune_language_model
|
||||||
|
or self.config.finetune_vision_model
|
||||||
|
)
|
||||||
|
|
||||||
|
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||||
|
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||||
|
if not camera_keys:
|
||||||
|
raise ValueError("EVO1 requires at least one visual observation feature.")
|
||||||
|
|
||||||
|
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
|
||||||
|
# from a real batch dim and not from C in the unbatched (C, H, W) case.
|
||||||
|
normalized: dict[str, Tensor] = {}
|
||||||
|
for camera_key in camera_keys[: self.config.max_views]:
|
||||||
|
image = batch[camera_key]
|
||||||
|
if image.dim() == 3:
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
elif image.dim() == 5:
|
||||||
|
image = image[:, -1]
|
||||||
|
elif image.dim() != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
|
||||||
|
)
|
||||||
|
normalized[camera_key] = image
|
||||||
|
|
||||||
|
batch_size = normalized[camera_keys[0]].shape[0]
|
||||||
|
image_batches: list[list[Tensor]] = []
|
||||||
|
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
|
||||||
|
|
||||||
|
for batch_index in range(batch_size):
|
||||||
|
sample_images: list[Tensor] = []
|
||||||
|
for camera_key in camera_keys[: self.config.max_views]:
|
||||||
|
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
|
||||||
|
if not sample_images:
|
||||||
|
raise ValueError("EVO1 received a batch without any image tensor.")
|
||||||
|
while len(sample_images) < self.config.max_views:
|
||||||
|
sample_images.append(torch.zeros_like(sample_images[0]))
|
||||||
|
image_batches.append(sample_images[: self.config.max_views])
|
||||||
|
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
|
||||||
|
|
||||||
|
return image_batches, image_masks
|
||||||
|
|
||||||
|
def _compute_fused_tokens(
|
||||||
|
self,
|
||||||
|
prompts: list[str],
|
||||||
|
image_batches: list[list[Tensor]],
|
||||||
|
image_masks: Tensor,
|
||||||
|
) -> Tensor:
|
||||||
|
track_vlm_gradients = self._tracks_vlm_gradients
|
||||||
|
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
|
||||||
|
embedder = getattr(self.model, "embedder", None)
|
||||||
|
embedder_was_training = embedder.training if embedder is not None else None
|
||||||
|
|
||||||
|
if not track_vlm_gradients and embedder is not None:
|
||||||
|
embedder.eval()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with grad_context:
|
||||||
|
fused_tokens = self.model.get_vl_embeddings(
|
||||||
|
images=image_batches,
|
||||||
|
image_mask=image_masks,
|
||||||
|
prompt=prompts,
|
||||||
|
return_cls_only=self.config.return_cls_only,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
|
||||||
|
embedder.train(embedder_was_training)
|
||||||
|
|
||||||
|
if not track_vlm_gradients:
|
||||||
|
fused_tokens = fused_tokens.detach()
|
||||||
|
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
|
||||||
|
|
||||||
|
def _compute_masked_loss(
|
||||||
|
self,
|
||||||
|
pred_velocity: Tensor,
|
||||||
|
target_velocity: Tensor,
|
||||||
|
action_mask: Tensor,
|
||||||
|
reduction: str,
|
||||||
|
) -> Tensor:
|
||||||
|
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
|
||||||
|
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
|
||||||
|
active = flat_mask.sum(dim=1).clamp_min(1.0)
|
||||||
|
per_sample_loss = sq_error.sum(dim=1) / active
|
||||||
|
if reduction == "none":
|
||||||
|
return per_sample_loss
|
||||||
|
if reduction != "mean":
|
||||||
|
raise ValueError(f"Unsupported reduction '{reduction}'")
|
||||||
|
return sq_error.sum() / active.sum()
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||||
|
prompts = self._normalize_task_batch(batch)
|
||||||
|
image_batches, image_masks = self._collect_image_batches(batch)
|
||||||
|
states, _state_mask = self._prepare_state(batch)
|
||||||
|
actions_gt, action_mask = self._prepare_actions(batch)
|
||||||
|
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||||
|
states = states.to(dtype=self._training_compute_dtype)
|
||||||
|
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
|
||||||
|
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
|
||||||
|
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||||
|
|
||||||
|
pred_velocity, noise = self.model(
|
||||||
|
fused_tokens,
|
||||||
|
state=states,
|
||||||
|
actions_gt=actions_gt,
|
||||||
|
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
|
||||||
|
embodiment_ids=embodiment_ids,
|
||||||
|
)
|
||||||
|
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
|
||||||
|
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
|
||||||
|
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
|
||||||
|
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
|
||||||
|
return loss, {
|
||||||
|
"loss": loss_mean,
|
||||||
|
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
prompts = self._normalize_task_batch(batch)
|
||||||
|
image_batches, image_masks = self._collect_image_batches(batch)
|
||||||
|
states, _state_mask = self._prepare_state(batch)
|
||||||
|
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||||
|
states = states.to(dtype=self._inference_compute_dtype)
|
||||||
|
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
|
||||||
|
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||||
|
action_mask = self._prepare_inference_action_mask(states.shape[0])
|
||||||
|
|
||||||
|
with (
|
||||||
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||||
|
if self.config.use_amp and str(self.config.device).startswith("cuda")
|
||||||
|
else nullcontext()
|
||||||
|
):
|
||||||
|
actions = self.model(
|
||||||
|
fused_tokens,
|
||||||
|
state=states,
|
||||||
|
action_mask=action_mask,
|
||||||
|
embodiment_ids=embodiment_ids,
|
||||||
|
)
|
||||||
|
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||||
|
return actions[:, :, : self._env_action_dim]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
if len(self._action_queue) == 0:
|
||||||
|
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||||
|
self._action_queue.extend(action_chunk.transpose(0, 1))
|
||||||
|
return self._action_queue.popleft()
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import (
|
||||||
|
batch_to_transition,
|
||||||
|
create_transition,
|
||||||
|
policy_action_to_transition,
|
||||||
|
transition_to_policy_action,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
ACTION,
|
||||||
|
DONE,
|
||||||
|
INFO,
|
||||||
|
OBS_PREFIX,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
REWARD,
|
||||||
|
TRUNCATED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def evo1_batch_to_transition(batch: dict[str, Any]):
|
||||||
|
transition = batch_to_transition(batch)
|
||||||
|
complementary_data = dict(transition.get("complementary_data") or {})
|
||||||
|
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
|
||||||
|
for key, value in batch.items():
|
||||||
|
if key in reserved or key.startswith(OBS_PREFIX):
|
||||||
|
continue
|
||||||
|
complementary_data.setdefault(key, value)
|
||||||
|
return create_transition(
|
||||||
|
observation=transition.get("observation"),
|
||||||
|
action=transition.get("action"),
|
||||||
|
reward=transition.get("reward", 0.0),
|
||||||
|
done=transition.get("done", False),
|
||||||
|
truncated=transition.get("truncated", False),
|
||||||
|
info=transition.get("info", {}),
|
||||||
|
complementary_data=complementary_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_evo1_pre_post_processors(
|
||||||
|
config: Evo1Config,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
input_steps = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features={**config.input_features, **config.output_features},
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
]
|
||||||
|
output_steps = [
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=config.output_features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device="cpu"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=evo1_batch_to_transition,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -46,14 +46,14 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
|||||||
|
|
||||||
from .act.configuration_act import ACTConfig
|
from .act.configuration_act import ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from .eo1.configuration_eo1 import EO1Config
|
||||||
|
from .evo1.configuration_evo1 import Evo1Config
|
||||||
from .groot.configuration_groot import GrootConfig
|
from .groot.configuration_groot import GrootConfig
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config
|
from .pi0.configuration_pi0 import PI0Config
|
||||||
from .pi05.configuration_pi05 import PI05Config
|
from .pi05.configuration_pi05 import PI05Config
|
||||||
from .pretrained import PreTrainedPolicy
|
from .pretrained import PreTrainedPolicy
|
||||||
from .sac.configuration_sac import SACConfig
|
from .sac.configuration_sac import SACConfig
|
||||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
|
|
||||||
from .sarm.configuration_sarm import SARMConfig
|
|
||||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from .utils import validate_visual_features_consistency
|
from .utils import validate_visual_features_consistency
|
||||||
@@ -89,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
|
||||||
Returns:
|
Returns:
|
||||||
The policy class corresponding to the given name.
|
The policy class corresponding to the given name.
|
||||||
|
|
||||||
@@ -132,18 +132,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .sac.modeling_sac import SACPolicy
|
from .sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
return SACPolicy
|
return SACPolicy
|
||||||
elif name == "reward_classifier":
|
|
||||||
from .sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
return Classifier
|
|
||||||
elif name == "smolvla":
|
elif name == "smolvla":
|
||||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
|
||||||
return SmolVLAPolicy
|
return SmolVLAPolicy
|
||||||
elif name == "sarm":
|
|
||||||
from .sarm.modeling_sarm import SARMRewardModel
|
|
||||||
|
|
||||||
return SARMRewardModel
|
|
||||||
elif name == "groot":
|
elif name == "groot":
|
||||||
from .groot.modeling_groot import GrootPolicy
|
from .groot.modeling_groot import GrootPolicy
|
||||||
|
|
||||||
@@ -156,6 +148,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .wall_x.modeling_wall_x import WallXPolicy
|
from .wall_x.modeling_wall_x import WallXPolicy
|
||||||
|
|
||||||
return WallXPolicy
|
return WallXPolicy
|
||||||
|
elif name == "eo1":
|
||||||
|
from .eo1.modeling_eo1 import EO1Policy
|
||||||
|
|
||||||
|
return EO1Policy
|
||||||
|
elif name == "evo1":
|
||||||
|
from .evo1.modeling_evo1 import EVO1Policy
|
||||||
|
|
||||||
|
return EVO1Policy
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return _get_policy_cls_from_policy_name(name=name)
|
return _get_policy_cls_from_policy_name(name=name)
|
||||||
@@ -173,7 +173,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
Args:
|
Args:
|
||||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||||
"smolvla", "reward_classifier", "wall_x".
|
"smolvla", "wall_x", "eo1", "evo1".
|
||||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -200,14 +200,16 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return SACConfig(**kwargs)
|
return SACConfig(**kwargs)
|
||||||
elif policy_type == "smolvla":
|
elif policy_type == "smolvla":
|
||||||
return SmolVLAConfig(**kwargs)
|
return SmolVLAConfig(**kwargs)
|
||||||
elif policy_type == "reward_classifier":
|
|
||||||
return RewardClassifierConfig(**kwargs)
|
|
||||||
elif policy_type == "groot":
|
elif policy_type == "groot":
|
||||||
return GrootConfig(**kwargs)
|
return GrootConfig(**kwargs)
|
||||||
elif policy_type == "xvla":
|
elif policy_type == "xvla":
|
||||||
return XVLAConfig(**kwargs)
|
return XVLAConfig(**kwargs)
|
||||||
elif policy_type == "wall_x":
|
elif policy_type == "wall_x":
|
||||||
return WallXConfig(**kwargs)
|
return WallXConfig(**kwargs)
|
||||||
|
elif policy_type == "eo1":
|
||||||
|
return EO1Config(**kwargs)
|
||||||
|
elif policy_type == "evo1":
|
||||||
|
return Evo1Config(**kwargs)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||||
@@ -378,14 +380,6 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
|
||||||
from .sac.reward_model.processor_classifier import make_classifier_processor
|
|
||||||
|
|
||||||
processors = make_classifier_processor(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||||
|
|
||||||
@@ -394,14 +388,6 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, SARMConfig):
|
|
||||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
|
||||||
|
|
||||||
processors = make_sarm_pre_post_processors(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
dataset_meta=kwargs.get("dataset_meta"),
|
|
||||||
)
|
|
||||||
elif isinstance(policy_cfg, GrootConfig):
|
elif isinstance(policy_cfg, GrootConfig):
|
||||||
from .groot.processor_groot import make_groot_pre_post_processors
|
from .groot.processor_groot import make_groot_pre_post_processors
|
||||||
|
|
||||||
@@ -427,6 +413,20 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
elif isinstance(policy_cfg, EO1Config):
|
||||||
|
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_eo1_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
elif isinstance(policy_cfg, Evo1Config):
|
||||||
|
from .evo1.processor_evo1 import make_evo1_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_evo1_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -542,7 +542,7 @@ def make_policy(
|
|||||||
|
|
||||||
logging.info("Loading policy's PEFT adapter.")
|
logging.info("Loading policy's PEFT adapter.")
|
||||||
|
|
||||||
peft_pretrained_path = cfg.pretrained_path
|
peft_pretrained_path = str(cfg.pretrained_path)
|
||||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||||
|
|
||||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||||
@@ -555,7 +555,9 @@ def make_policy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = policy_cls.from_pretrained(**kwargs)
|
policy = policy_cls.from_pretrained(**kwargs)
|
||||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
policy = PeftModel.from_pretrained(
|
||||||
|
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import field
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -109,7 +109,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3
|
|||||||
|
|
||||||
|
|
||||||
# config
|
# config
|
||||||
@dataclass
|
|
||||||
class GR00TN15Config(PretrainedConfig):
|
class GR00TN15Config(PretrainedConfig):
|
||||||
model_type = "gr00t_n1_5"
|
model_type = "gr00t_n1_5"
|
||||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
|
||||||
|
|
||||||
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
|
backbone_cfg: dict
|
||||||
|
action_head_cfg: dict
|
||||||
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
|
action_horizon: int
|
||||||
|
action_dim: int
|
||||||
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
|
compute_dtype: str = "float32"
|
||||||
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
|
|||||||
loss = F.mse_loss(predicted, target, reduction="none")
|
loss = F.mse_loss(predicted, target, reduction="none")
|
||||||
|
|
||||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||||
valid_actions = ~batch["action_is_pad"]
|
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
loss = loss * valid_actions.unsqueeze(-1)
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
|
|||||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||||
|
|
||||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||||
valid_mask = ~batch["action_is_pad"]
|
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
loss = loss * valid_mask.unsqueeze(-1)
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|||||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
features = image_outputs.pooler_output
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language tokens
|
# Process language tokens
|
||||||
def lang_embed_func(lang_tokens):
|
def lang_embed_func(lang_tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
return lang_emb
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
@@ -748,16 +747,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(
|
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
|
||||||
) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1292,8 +1283,11 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
@@ -227,6 +226,7 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||||
if use_adarms[0]:
|
if use_adarms[0]:
|
||||||
text_config = self.paligemma.config.text_config
|
text_config = self.paligemma.config.text_config
|
||||||
|
del self.paligemma.model.language_model
|
||||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
@@ -260,13 +260,15 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
features = image_outputs.pooler_output
|
||||||
|
norm = 2048**0.5
|
||||||
|
features = features / norm * norm
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -416,8 +418,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language instruction tokens
|
# Process language instruction tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
return lang_emb
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
@@ -431,8 +432,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
def fast_action_embed_func(fast_action_tokens):
|
def fast_action_embed_func(fast_action_tokens):
|
||||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||||
fast_emb_dim = fast_emb.shape[-1]
|
return fast_emb
|
||||||
return fast_emb * math.sqrt(fast_emb_dim)
|
|
||||||
|
|
||||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||||
embs.append(fast_action_emb)
|
embs.append(fast_action_emb)
|
||||||
@@ -665,7 +665,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
if t < max_decoding_steps - 1:
|
if t < max_decoding_steps - 1:
|
||||||
# embed the newly generated token
|
# embed the newly generated token
|
||||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
|
||||||
if prefix_embs.dtype == torch.bfloat16:
|
if prefix_embs.dtype == torch.bfloat16:
|
||||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
@@ -770,7 +769,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Embed the single previous token
|
# Embed the single previous token
|
||||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
|
||||||
if prefix_embs.dtype == torch.bfloat16:
|
if prefix_embs.dtype == torch.bfloat16:
|
||||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
|
||||||
|
del self.layers
|
||||||
|
del self.norm
|
||||||
# if not getattr(config, "use_adarms", False):
|
# if not getattr(config, "use_adarms", False):
|
||||||
# return
|
# return
|
||||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||||
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
del self.model
|
||||||
self.model = PiGemmaModel(config)
|
self.model = PiGemmaModel(config)
|
||||||
|
|
||||||
|
|
||||||
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.language_model
|
||||||
self.language_model = PiGemmaModel(config.text_config)
|
self.language_model = PiGemmaModel(config.text_config)
|
||||||
|
|
||||||
|
|
||||||
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.model
|
||||||
self.model = PaliGemmaModelWithPiGemma(config)
|
self.model = PaliGemmaModelWithPiGemma(config)
|
||||||
|
|
||||||
# Make modules available through conditional class for BC
|
# Make modules available through conditional class for BC
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from .action_queue import ActionQueue
|
|||||||
from .configuration_rtc import RTCConfig
|
from .configuration_rtc import RTCConfig
|
||||||
from .latency_tracker import LatencyTracker
|
from .latency_tracker import LatencyTracker
|
||||||
from .modeling_rtc import RTCProcessor
|
from .modeling_rtc import RTCProcessor
|
||||||
|
from .relative import reanchor_relative_rtc_prefix
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ActionInterpolator",
|
"ActionInterpolator",
|
||||||
@@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"LatencyTracker",
|
"LatencyTracker",
|
||||||
"RTCConfig",
|
"RTCConfig",
|
||||||
"RTCProcessor",
|
"RTCProcessor",
|
||||||
|
"reanchor_relative_rtc_prefix",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,116 +1,4 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
|
||||||
#
|
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Action interpolation for smoother robot control.
|
__all__ = ["ActionInterpolator"]
|
||||||
|
|
||||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
|
||||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class ActionInterpolator:
|
|
||||||
"""Interpolates between consecutive actions for smoother control.
|
|
||||||
|
|
||||||
When enabled with multiplier N, produces N actions per policy action
|
|
||||||
by linearly interpolating between the previous and current action.
|
|
||||||
|
|
||||||
Example with multiplier=3:
|
|
||||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
|
||||||
|
|
||||||
This effectively multiplies the control rate for smoother motion.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
|
||||||
|
|
||||||
# In control loop:
|
|
||||||
if interpolator.needs_new_action():
|
|
||||||
new_action = queue.get()
|
|
||||||
if new_action:
|
|
||||||
interpolator.add(new_action.cpu())
|
|
||||||
|
|
||||||
action = interpolator.get()
|
|
||||||
if action:
|
|
||||||
robot.send_action(action)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, multiplier: int = 1):
|
|
||||||
"""Initialize the interpolator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
|
||||||
"""
|
|
||||||
if multiplier < 1:
|
|
||||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
|
||||||
self.multiplier = multiplier
|
|
||||||
self._prev: Tensor | None = None
|
|
||||||
self._buffer: list[Tensor] = []
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enabled(self) -> bool:
|
|
||||||
"""Whether interpolation is active (multiplier > 1)."""
|
|
||||||
return self.multiplier > 1
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Reset interpolation state (call between episodes)."""
|
|
||||||
self._prev = None
|
|
||||||
self._buffer = []
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
def needs_new_action(self) -> bool:
|
|
||||||
"""Check if a new action is needed from the queue."""
|
|
||||||
return self._idx >= len(self._buffer)
|
|
||||||
|
|
||||||
def add(self, action: Tensor) -> None:
|
|
||||||
"""Add a new action and compute interpolated sequence.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: New action tensor from policy/queue (already on CPU).
|
|
||||||
"""
|
|
||||||
if self.multiplier > 1 and self._prev is not None:
|
|
||||||
self._buffer = []
|
|
||||||
for i in range(1, self.multiplier + 1):
|
|
||||||
t = i / self.multiplier
|
|
||||||
interp = self._prev + t * (action - self._prev)
|
|
||||||
self._buffer.append(interp)
|
|
||||||
else:
|
|
||||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
|
||||||
self._buffer = [action.clone()]
|
|
||||||
self._prev = action.clone()
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
def get(self) -> Tensor | None:
|
|
||||||
"""Get the next interpolated action.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Next action tensor, or None if buffer is exhausted.
|
|
||||||
"""
|
|
||||||
if self._idx >= len(self._buffer):
|
|
||||||
return None
|
|
||||||
action = self._buffer[self._idx]
|
|
||||||
self._idx += 1
|
|
||||||
return action
|
|
||||||
|
|
||||||
def get_control_interval(self, fps: float) -> float:
|
|
||||||
"""Get the control interval based on interpolation multiplier.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fps: Base frames per second.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Control interval in seconds (divided by multiplier).
|
|
||||||
"""
|
|
||||||
return 1.0 / (fps * self.multiplier)
|
|
||||||
|
|||||||
@@ -92,10 +92,10 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
int: Number of unconsumed actions.
|
int: Number of unconsumed actions.
|
||||||
"""
|
"""
|
||||||
if self.queue is None:
|
with self.lock:
|
||||||
return 0
|
if self.queue is None:
|
||||||
length = len(self.queue)
|
return 0
|
||||||
return length - self.last_index
|
return len(self.queue) - self.last_index
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
"""Check if the queue is empty.
|
"""Check if the queue is empty.
|
||||||
@@ -103,11 +103,10 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if no actions remain, False otherwise.
|
bool: True if no actions remain, False otherwise.
|
||||||
"""
|
"""
|
||||||
if self.queue is None:
|
with self.lock:
|
||||||
return True
|
if self.queue is None:
|
||||||
|
return True
|
||||||
length = len(self.queue)
|
return len(self.queue) - self.last_index <= 0
|
||||||
return length - self.last_index <= 0
|
|
||||||
|
|
||||||
def get_action_index(self) -> int:
|
def get_action_index(self) -> int:
|
||||||
"""Get the current action consumption index.
|
"""Get the current action consumption index.
|
||||||
@@ -115,7 +114,8 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
int: Index of the next action to be consumed.
|
int: Index of the next action to be consumed.
|
||||||
"""
|
"""
|
||||||
return self.last_index
|
with self.lock:
|
||||||
|
return self.last_index
|
||||||
|
|
||||||
def get_left_over(self) -> Tensor | None:
|
def get_left_over(self) -> Tensor | None:
|
||||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class RTCConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Infrastructure
|
# Infrastructure
|
||||||
enabled: bool = False
|
enabled: bool = True
|
||||||
|
|
||||||
# Core RTC settings
|
# Core RTC settings
|
||||||
# Todo change to exp
|
# Todo change to exp
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Relative-action helpers for Real-Time Chunking (RTC)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor import (
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
RelativeActionsProcessorStep,
|
||||||
|
TransitionKey,
|
||||||
|
create_transition,
|
||||||
|
to_relative_actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reanchor_relative_rtc_prefix(
|
||||||
|
prev_actions_absolute: torch.Tensor,
|
||||||
|
current_state: torch.Tensor,
|
||||||
|
relative_step: RelativeActionsProcessorStep,
|
||||||
|
normalizer_step: NormalizerProcessorStep | None,
|
||||||
|
policy_device: torch.device | str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||||
|
|
||||||
|
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||||
|
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||||
|
helper re-expresses those actions relative to the robot's current joint state
|
||||||
|
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||||
|
"""
|
||||||
|
state = current_state.detach().cpu()
|
||||||
|
if state.dim() == 1:
|
||||||
|
state = state.unsqueeze(0)
|
||||||
|
|
||||||
|
action_cpu = prev_actions_absolute.detach().cpu()
|
||||||
|
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||||
|
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||||
|
|
||||||
|
transition = create_transition(action=relative_actions)
|
||||||
|
if normalizer_step is not None:
|
||||||
|
transition = normalizer_step(transition)
|
||||||
|
|
||||||
|
return transition[TransitionKey.ACTION].to(policy_device)
|
||||||
@@ -1 +0,0 @@
|
|||||||
../../../../docs/source/policy_sarm_README.md
|
|
||||||
@@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||||
|
|
||||||
if reduction == "none":
|
if reduction == "none":
|
||||||
# Return per-sample losses (B,) by averaging over time and action dims
|
# Return per-sample losses (B,) by averaging over valid (time, action) entries
|
||||||
per_sample_loss = losses.mean(dim=(1, 2))
|
if actions_is_pad is None:
|
||||||
|
per_sample_loss = losses.mean(dim=(1, 2))
|
||||||
|
else:
|
||||||
|
num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1)
|
||||||
|
per_sample_loss = losses.sum(dim=(1, 2)) / num_valid
|
||||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||||
return per_sample_loss, loss_dict
|
return per_sample_loss, loss_dict
|
||||||
else:
|
else:
|
||||||
# Default: return scalar mean loss
|
# Default: return scalar mean loss over valid (time, action) entries
|
||||||
loss = losses.mean()
|
if actions_is_pad is None:
|
||||||
|
loss = losses.mean()
|
||||||
|
else:
|
||||||
|
num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1)
|
||||||
|
loss = losses.sum() / num_valid
|
||||||
loss_dict["loss"] = loss.item()
|
loss_dict["loss"] = loss.item()
|
||||||
return loss, loss_dict
|
return loss, loss_dict
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
crop_shape: tuple[int, int] | None = (84, 84)
|
crop_shape: tuple[int, int] | None = (84, 84)
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
pretrained_backbone_weights: str | None = None
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = False
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
# VQ-VAE
|
# VQ-VAE
|
||||||
n_vqvae_training_steps: int = 20000
|
n_vqvae_training_steps: int = 20000
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from transformers.utils import (
|
|||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal_2_10,
|
is_flash_attn_greater_or_equal,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
|||||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
|||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
target_dtype = torch.get_autocast_gpu_dtype()
|
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from transformers.utils import (
|
|||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
is_flash_attn_greater_or_equal_2_10,
|
is_flash_attn_greater_or_equal,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|||||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||||
|
|
||||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
target_dtype = torch.get_autocast_gpu_dtype()
|
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from .converters import (
|
|||||||
)
|
)
|
||||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||||
from .device_processor import DeviceProcessorStep
|
from .device_processor import DeviceProcessorStep
|
||||||
from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep
|
||||||
from .factory import (
|
from .factory import (
|
||||||
make_default_processors,
|
make_default_processors,
|
||||||
make_default_robot_action_processor,
|
make_default_robot_action_processor,
|
||||||
@@ -149,6 +149,7 @@ __all__ = [
|
|||||||
"RewardProcessorStep",
|
"RewardProcessorStep",
|
||||||
"DataProcessorPipeline",
|
"DataProcessorPipeline",
|
||||||
"IsaaclabArenaProcessorStep",
|
"IsaaclabArenaProcessorStep",
|
||||||
|
"LiberoActionProcessorStep",
|
||||||
"LiberoProcessorStep",
|
"LiberoProcessorStep",
|
||||||
"TimeLimitProcessorStep",
|
"TimeLimitProcessorStep",
|
||||||
"AddBatchDimensionProcessorStep",
|
"AddBatchDimensionProcessorStep",
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user