Compare commits

..

36 Commits

Author SHA1 Message Date
Jade Choghari cbb380df34 draft changes 2025-12-26 14:06:30 +00:00
Alexis Alva 12043b3b5c fix: use importlib.metadata for plugin discovery to support PEP 660 (#2687) 2025-12-24 15:45:14 +01:00
Salman Chishti a06f4b9140 Upgrade GitHub Actions for Node 24 compatibility (#2691) 2025-12-24 10:42:29 +01:00
Steven Palma 20c22a2799 chore(ci): make keyword matching more conservative (#2711) 2025-12-24 02:03:12 +01:00
Steven Palma 2f238fce15 feat(ci): adds release versioning to docs (#2709)
* feat(ci): adds release versioning to docs

* chore(ci): remove TODO
2025-12-24 00:40:56 +01:00
Pepijn ff271e8b51 pi fixes for dependencies (#2706)
* pi fixes for dependencies

* add walls sarm conflict

* also add conflicts for pi

* fix(ci): use --extra all instead of --all-extras + --no-extra

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-12-23 23:58:34 +01:00
Pepijn a142c365dd use syslink for wall-x readme (#2708)
* use syslink for wall-x readme

* remove whitespace
2025-12-23 14:13:32 +01:00
Steven Palma b2ef6ae720 chore: modernize contributing.md (#2677) 2025-12-23 12:10:44 +01:00
Tong Wu a64f2fd322 modify the README file for wallx (#2705)
* support wallx

* fix bugs in flow

* incorporate wallx model into lerobot

* update the policy methods

* reduce to least config and params & pass lerobot basic test

* fixed dtype bugs

* add wallx dependencies

* update

* remove flash-attn requirement && fix bug in inference and fast mode

* fix bug for inference

* add some small modifications

* fix pre-commit errors

* remove lerobot[wallx]

* fix ci

* fix precommit issues

* fix: exclude wallx extra properly in CI workflows

* fix: add uv conflicts for wallx transformers version

* fix: peft test import

* pre-commit

* only export WallXConfig from wall_x package to avoid peft import in CI

* remove torch dep

* precommit

* add import

* update doc files

* fix minor errors

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: vincentchen <chenlufang@x2robot.com>
Co-authored-by: Geoffrey19 <sympathischmann35@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2025-12-23 11:35:06 +01:00
Tong Wu 17c5a0774f feat: support wallx model (#2593)
* support wallx

* fix bugs in flow

* incorporate wallx model into lerobot

* update the policy methods

* reduce to least config and params & pass lerobot basic test

* fixed dtype bugs

* add wallx dependencies

* update

* remove flash-attn requirement && fix bug in inference and fast mode

* fix bug for inference

* add some small modifications

* fix pre-commit errors

* remove lerobot[wallx]

* fix ci

* fix precommit issues

* fix: exclude wallx extra properly in CI workflows

* fix: add uv conflicts for wallx transformers version

* fix: peft test import

* pre-commit

* only export WallXConfig from wall_x package to avoid peft import in CI

* remove torch dep

* precommit

* add import

---------

Co-authored-by: vincentchen <chenlufang@x2robot.com>
Co-authored-by: Geoffrey19 <sympathischmann35@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2025-12-22 10:12:39 +01:00
Pepijn 0071b1ff6e Add readme (#2698)
* Add readme

* change ref
2025-12-22 10:04:33 +01:00
Clément Verrier 00b5f65752 fix(optim): enable and resolve mypy type errors (#2683)
* fix(optim): enable and resolve mypy type errors

Resolves #1729

build(deps): add mypy as dependency and update pre-commit hook

* change build's type annotation
2025-12-20 17:19:42 +01:00
Francesco Capuano 2f6c870c4b Fixes ReadMe Policies Classification (#2682)
* fix: tdmpc is a model-based RL method, does not learn from expert demonstrations so no IL there

* fix: typo

* remove trailing space

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix: minor

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2025-12-20 17:11:02 +01:00
Steven Palma 0bd1969d0a feat(docs): modernize readme (#2660) 2025-12-18 19:45:13 +01:00
Pepijn f04958527e Add sarm (#2639)
* add initial modeling

* make rewind pretrained policy

* add annotation

* small fix

* add sarm

* subtasks

* fix spawn

* fix rewind discrepancies

* Add script to generate embedding for dataset (#2138)

* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>

* cleanup

* change order train log

* print batch size

* update sarm processor

* add reward output

* change expected features

* add image validation

* change validation

* get state input from dataset stats

* raise if no state key is found

* pass stats

* cleanup and refactor

* add episode inddex to complementary data

* add subtask init and detection

* revert lerobot_train changes

* pass dataset metadata to policy

* change loadig subtasks

* add small logging

* fix progress conversion and adding initial frame

* use large offset for initial frame (ugly)

* Remove rewind, use clip tokenizer

* add tests, implement formula 1,2 correctly and cleanup

* use task from dataset, cleanup visualizer

* simplify

* simplify and cleanup code and move compute_temporal_proportions to utils

* fix normalization in visualization

* Fix visualization and change prompt

* fix formatting

* add visualize subtask annotations

* use qwen thinking

* try different prompt

* format

* update prompt

* higher temp, long output

* different settings

* use instruct

* show full resp

* split message

* Temp: increase tolerance dataset

* Fix RA-BC (#2572)

* Add next observation loading for RA-BC progress deltas

* Compute weights based on temporal progress deltas instead of static rewards

* Add hard-masking for negative progress deltas in weight computation

* Feat/add dual head (#2582)

* Add dual dense sparse head and annotation

* Add docs

* add dual to procesor

* cleanup

* change sampling in visualize and cleanup

* remove validation

* remove compile

* Feat/test uniform (#2587)

* test uniform

* add different string for misaligned

* Fix rewind and add tests

* uncomment text implementation

* run precommit

* Add head mode for ra-bc

* fix visalization of single task

* add

* return per sample loss

* Fix RA_BC (#2602)

* update rabc implementation

* compute rabc beforehand

* fix import

* add only progress calulation

* use precomputed progress

* multi gpu processing

* import

* fix dataset meta data extraction

* add logging

* logging

* log

* progress per episode

* split differently

* move clip to gpu

* pre decode frames for an episode

* fix cuda initalization

* fix import

* multi processing

* rename

* fix import

* fix

* fix rabc

* use last known progress if oob

* use last known progress if oob

* add misalignment loss with random embeddings

* discard previous changes

* add selection of models to docs for ra_bc

* add transformers dep

* extend tolerance

* initial commit with new codebase

* add tests

* fix

* remove temporal sampler

* drop last frame for sampler

* use original ref

* some fixes

* fix visualization

* remove smoothing and fix order subtasks

* add stride rabc computation

* add push to hub

* add explanation

* add kappa expllaination

* better rabc logging

* feedback pr

* remove dataset tolerance

* revert dataset tool

* revert dataset changes

* add credit

* run precommit

* change path for generate ra_bc

* fix type

* include sarm in all in pyproject

* fix precommit

* lazy import matplotlib

* lazy import qwen

* remove rich console

* skip if transformers is not installed?

* run only when we have faker

* place transformer lazy loading

* Dont test if low transformer version

* fix

* increase transformer

* increase as 4.57.0 is yanked

* remove pi from all

* go back

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-12-18 12:50:32 +01:00
Steven Palma 4a151a9682 chore(ci): minor improvement bug-report template & pr auto label (#2676)
* chore(ci): minor improvement bug-report template

* chore(ci): change triggers for PR auto label
2025-12-18 00:23:23 +01:00
Steven Palma 8667b9ef08 chore(ci): minor improvements auto labeling (#2675) 2025-12-17 22:54:47 +01:00
Steven Palma 86eee5c1e2 fix(ci): close bracket pattern (#2674) 2025-12-17 22:40:33 +01:00
Steven Palma 469b855e42 fix(ci): better heuristic + issue type template fix (#2672)
* fix(ci): better heuristic + issue type template fix

* chore(ci): remove keywords in performance tag
2025-12-17 22:31:22 +01:00
Steven Palma 292333cafc chore(ci): update issue template (#2666) 2025-12-17 18:02:20 +01:00
Steven Palma f0c98e23f1 feat(ci): simple automatic labelling (#2667)
* ci: add pr labeler

* ci: add issue labeler

* ci: minor fixes for labelers

* fix(ci): add explicit path for pr labeler
2025-12-17 17:52:45 +01:00
Steven Palma 7621af5acd chore(ci): update PR template (#2665)
* chore: update code of conduct to transformers one

* chore: update PR template
2025-12-17 17:10:04 +01:00
Steven Palma 92fdbe9bbf docs(dataset): add visualization section (#2664) 2025-12-17 14:14:31 +01:00
Steven Palma b303d1ab38 feat(scripts): add more info to lerobot-info (#2663) 2025-12-17 14:14:23 +01:00
Steven Palma b1d162f333 fix(policies): add device back to smolvlm expert (#2662) 2025-12-17 12:12:03 +01:00
Steven Palma 2b304eeb84 feat(dataset): expose tolerance_s argument to training config (#2653) 2025-12-16 00:53:19 +01:00
Sota Nakamura 4e6048a221 finalize the dataset after recording (#2496)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-15 17:57:04 +01:00
./c² 81ebcac8d7 docs: update IL robots API example and add OpenCV workaround (#2648)
* docs: update IL robots API example and add OpenCV workaround

- Fix import path from lerobot.record to lerobot.scripts.lerobot_record
- Add required processor parameters to record_loop calls
- Document fourcc="MJPG" workaround for OpenCV async errors
- Improve code formatting in robot configuration examples

Fixes compatibility issues for users following the tutorial on embedded systems
and ensures API examples match current codebase requirements.

* Update il_robots.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: ./c² <cagataycali@icloud.com>

---------

Signed-off-by: ./c² <cagataycali@icloud.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-15 17:56:33 +01:00
Martino Russi a6c3a0fa09 Feat/add mj env (#2613)
* add sim support

* close fix threading issues
2025-12-15 16:22:27 +01:00
Woojin Wie c2fb644613 feat(robot): Add support for OMX robot (#2614)
* upload

* feat(omx): simplify motor initialization and remove default calibration files

* feat(omx): read motor positions without normalization for improved accuracy

* update calibration method for return factory value

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>

* change the drive mode

* refactor: clean up code by removing unnecessary blank lines in omx_follower and omx_leader modules

* feat(omx): update calibration method to set drive modes for motors

* feat(pyproject): add 'ROBOTIS' to extend-ignore-identifiers-re list

* feat(omx): enhance calibration method to write default drive modes to motors

* Update src/lerobot/robots/omx_follower/__init__.py

Add informations about the robot

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>

---------

Signed-off-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Signed-off-by: Woojin Wie <dnldnwls1123@gmail.com>
Co-authored-by: Junha02 <chajunha2023@naver.com>
Co-authored-by: Junha Cha <ckwnsgk1@gachon.ac.kr>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-15 15:50:29 +01:00
Jade Choghari 1d07a4aefd add auto in docs (#2645)
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-12-13 17:11:19 +01:00
Michel Aractingi ce348a3460 enable variable image sizes to pi0/pi0.5 (#2609)
* enable variable image sizes to pi0/pi0.5

* add square image assertion
2025-12-10 19:41:11 +01:00
Jade Choghari cb920235c4 docs: update X-VLA training strategies/commands (#2611) 2025-12-09 19:08:09 +01:00
Jade Choghari 7f40b3bf82 feat(dataset): add tool to convert images to video datasets (#2560)
* add video encoding tool

* style

* make it work

* more fixes
2025-12-08 18:50:21 +01:00
Michel Aractingi 2e9c9fd832 Replay while loop in sample actions with for loops (#2600) 2025-12-08 14:47:54 +01:00
Steven Palma f9cb5e659c chore(ci): skip workflows if not lerobot repository (#2601)
Co-authored-by: Alex Tyshka <atyshka15@gmail.com>
2025-12-08 12:44:36 +01:00
135 changed files with 17059 additions and 5643 deletions
+62 -36
View File
@@ -12,57 +12,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LeRobot
name: "🚀 Issue / Bug / Request"
description: Report a bug, suggest an improvement, or ask a technical question.
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to submit a bug report! 🐛
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
### Thanks for contributing to LeRobot! 🙌
Please choose the most relevant sections below. If this is a general "how-to" question, consider our [Discord](https://discord.gg/s3KuuzsPFb) for faster community support.
- type: dropdown
id: issue-type
attributes:
label: Ticket Type
description: What kind of ticket are you opening?
options:
- "🐛 Bug Report (Something isn't working)"
- "💡 Feature Request / Improvement"
- "❓ Technical Question"
- "🧹 Maintenance / Documentation"
validations:
required: true
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below.
label: Environment & System Info
description: |
For bugs or technical questions, please run `lerobot-info` and paste the output.
(Optional for feature requests).
render: Shell
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
placeholder: lerobot version, OS, python version, etc.
- type: textarea
id: description
validations:
required: true
attributes:
label: Description
description: |
Provide a clear summary of the issue or your proposal.
- **Bugs:** What is happening?
- **Features:** What is the goal/use case?
- **Questions:** What are you trying to achieve?
placeholder: |
A clear and concise description of the issue or suggestion.
- type: textarea
id: context-repro
attributes:
label: Context & Reproduction
description: |
Provide a code snippet, steps to reproduce a bug, or technical details about your proposal.
Please use code blocks for scripts and CLI commands.
placeholder: |
Steps to reproduce / Usage example:
1.
2.
3.
- type: textarea
id: logs
attributes:
label: Relevant logs or stack trace
description: If applicable, paste relevant error logs here.
render: Shell
- type: checkboxes
id: information-scripts-examples
id: extras
attributes:
label: Information
description: 'The problem arises when using:'
label: Checklist
options:
- label: "One of the scripts in the examples/ folder of LeRobot"
- label: "My own task or dataset (give details below)"
- label: I have searched existing tickets to ensure this isn't a duplicate.
- label: I am using the latest version of the `main` branch.
- label: I have verified this is not an environment-specific problem.
- type: textarea
id: reproduction
validations:
required: true
id: workaround
attributes:
label: Reproduction
description: |
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
Sharing error messages or stack traces could be useful as well!
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
label: Additional Info / Workarounds
description: Anything else we should know? If you have a workaround, please share it!
+40 -27
View File
@@ -1,41 +1,54 @@
## What this does
## Title
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
Examples:
| Title | Label |
|----------------------|-----------------|
| Fixes #[issue] | (🐛 Bug) |
| Adds new dataset | (🗃️ Dataset) |
| Optimizes something | (⚡️ Performance) |
## Type / Scope
## How it was tested
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
- **Scope**: (optional — name of module or package affected)
Explain/show how you tested your changes.
## Summary / Motivation
Examples:
- One-paragraph description of what changes and why.
- Why this change is needed and any trade-offs or design notes.
- Added `test_something` in `tests/test_stuff.py`.
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
- Optimized `some_function`, it now runs X times faster than previously.
## Related issues
## How to checkout & try? (for the reviewer)
- Fixes / Closes: # (if any)
- Related: # (if any)
Provide a simple way for the reviewer to try out your changes.
## What changed
Examples:
- Short, concrete bullets of the modifications (files/behaviour).
- Short note if this introduces breaking changes and migration steps.
```bash
pytest -sx tests/test_stuff.py::test_something
```
## How was this tested
```bash
lerobot-train --some.option=true
```
- Tests added: list new tests or test files.
- Manual checks / dataset runs performed.
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
## How to run locally (reviewer)
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
- Run the relevant tests:
**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
```bash
pytest -q tests/ -k <keyword>
```
- Run a quick example or CLI (if applicable):
```bash
lerobot-train --some.option=true
```
## Checklist (required before merge)
- [ ] Linting/formatting run (`pre-commit run -a`)
- [ ] All tests pass locally (`pytest`)
- [ ] Documentation updated
- [ ] CI is green
## Reviewer notes
- Anything the reviewer should focus on (performance, edge-cases, specific files) or general notes.
- Anyone in the community is free to review the PR.
+69
View File
@@ -0,0 +1,69 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CI:
- changed-files:
- any-glob-to-any-file:
- '.github/**'
- 'docker/**'
github_actions:
- changed-files:
- any-glob-to-any-file: '.github/**'
documentation:
- changed-files:
- any-glob-to-any-file:
- '**/*.md'
- '**/*.mdx'
- 'docs/**'
examples:
- changed-files:
- any-glob-to-any-file: 'examples/**'
tests:
- changed-files:
- any-glob-to-any-file: 'tests/**'
sensors:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/cameras/**'
configuration:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/configs/**'
dataset:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/datasets/**'
evaluation:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/envs/**'
robots:
- changed-files:
- any-glob-to-any-file:
- 'src/lerobot/teleoperators/**'
- 'src/lerobot/robots/**'
- 'src/lerobot/motors/**'
policies:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/policies/**'
processor:
- changed-files:
- any-glob-to-any-file: 'src/lerobot/processor/**'
@@ -31,7 +31,8 @@ jobs:
name: Upload Preview and Comment
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success'
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: lerobot
+8 -3
View File
@@ -33,6 +33,9 @@ on:
paths:
- "docs/**"
release:
types: [published]
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -42,14 +45,16 @@ jobs:
# This job builds and deploys the official documentation.
build_main_docs:
name: Build Main Docs
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
if: >
(github.event_name == 'push' || github.event_name == 'workflow_dispatch' || github.event_name == 'release') &&
github.repository == 'huggingface/lerobot'
permissions:
contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: lerobot
additional_args: --not_python_module
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
@@ -58,7 +63,7 @@ jobs:
# The result of this job triggers the 'Upload PR Documentation' workflow.
build_pr_docs:
name: Build PR Docs
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
permissions:
contents: read
pull-requests: write
+1 -2
View File
@@ -45,7 +45,6 @@ permissions:
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -63,7 +62,7 @@ jobs:
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
persist-credentials: false
lfs: true
+3 -3
View File
@@ -61,7 +61,7 @@ jobs:
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
@@ -85,7 +85,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
@@ -127,7 +127,7 @@ jobs:
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
+77
View File
@@ -0,0 +1,77 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow automatically labels issues based on their content.
name: Issue Labeler
on:
# Trigger on new issues and edits to existing issues
issues:
types: [opened, edited]
permissions:
contents: read
issues: write
jobs:
label-issue:
name: Auto Label Issue
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
steps:
- uses: actions/github-script@v8
with:
script: |
// Setup Input Text
const body = (context.payload.issue.body || '');
const title = (context.payload.issue.title || '');
const cleanBody = body.replace(/```[\s\S]*?```/g, '');
const text = `${title}\n${cleanBody}`.toLowerCase();
const labelsToAdd = new Set();
const matches = (re) => re.test(text);
// Keyword Heuristics
if (matches(/\b(bug|error|crash|exception)\b/i)) labelsToAdd.add('bug');
if (matches(/\b(new feature|enhancement|improvement|proposal|feature request)\b/i)) labelsToAdd.add('enhancement');
if (matches(/\b(question|how to|clarify|explain|how do i|help me|question about)\b/i)) labelsToAdd.add('question');
if (matches(/\b(documentation|docs?|readme|tutorial|wiki|typo|docstring)\b/i)) labelsToAdd.add('documentation');
if (matches(/\b(example|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset');
if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation');
if (matches(/\b(train|training|optimizer|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
if (matches(/\b(rerun|plot|render|rendering|visualizer)/i)) labelsToAdd.add('visualization');
if (matches(/\b(cameras?|opencv|realsense|lidars?|sensors?|imus?|microphones?|rgbd|encoders?)\b/i)) labelsToAdd.add('sensors');
if (matches(/\b(urdf|actuators?|calibration|end-effector|kinematics)\b/i)) labelsToAdd.add('robots');
if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators');
if (matches(/\b(policy|policies|model?)\b/i)) labelsToAdd.add('policies');
if (matches(/\b(processor|pipeline|preprocessor|postprocessor)s?\b/i)) labelsToAdd.add('processor');
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmarks?)\b/i)) labelsToAdd.add('evaluation');
if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests');
if (matches(/\b(ci|github actions?|github workflows?|gha|docker|pypi)\b/i)) labelsToAdd.add('CI');
if (matches(/\b(perf|latency|throughput|fps|speed|performance|slow|fast|slower|faster|memory usage)\b/i)) labelsToAdd.add('performance');
if (matches(/\b(dependency|dependencies|pip|install error|importerror|package not found|pyproject)\b/i)) labelsToAdd.add('dependencies');
if (matches(/\b(configuration|config|arguments?|input feature|dracuss)\b/i)) labelsToAdd.add('configuration');
// Apply Labels
const labels = Array.from(labelsToAdd).filter(Boolean);
if (labels.length > 0) {
console.log(`Adding labels: ${labels.join(', ')}`);
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels,
});
}
+4 -2
View File
@@ -43,6 +43,7 @@ jobs:
name: Build CPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
steps:
@@ -51,7 +52,7 @@ jobs:
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
@@ -77,6 +78,7 @@ jobs:
name: Build GPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
steps:
@@ -85,7 +87,7 @@ jobs:
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
+39
View File
@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow labels pull requests based on the files that were changed.
name: Pull Request Labeler
on:
# Allows labeling pull requests when they are opened or updated
# zizmor: ignore[dangerous-triggers] Needed to label PRs from forks
pull_request_target:
branches:
- main
types: [opened, synchronize, reopened, ready_for_review]
permissions:
contents: read
pull-requests: write
jobs:
triage:
name: Label PR
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot' && !github.event.pull_request.draft
steps:
- uses: actions/labeler@v6
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
sync-labels: true # Removes labels if files are removed from the PR
+2 -2
View File
@@ -43,12 +43,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'
+4 -4
View File
@@ -29,6 +29,7 @@ jobs:
build-and-publish:
name: Build and publish Python distributions
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
outputs:
version: ${{ steps.extract_info.outputs.tag_version }}
permissions:
@@ -37,12 +38,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'
@@ -134,7 +135,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
@@ -176,4 +177,3 @@ jobs:
# TODO(Steven): Publish draft/pre-release and to test pypi weekly
# TODO(Steven): Separate build and publish job
# TODO(Steven): Tag documentation with the same version as the package
+1 -1
View File
@@ -43,7 +43,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses]
uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
with:
fetch-depth: 0
persist-credentials: false
+1
View File
@@ -45,6 +45,7 @@ jobs:
stale:
name: Close Stale Issues and PRs
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
permissions:
actions: write
contents: write # only for delete-branch option
+4 -3
View File
@@ -43,12 +43,13 @@ jobs:
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
@@ -77,7 +78,7 @@ jobs:
echo "Dependencies unbound:" && cat pyproject.toml
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -100,7 +101,7 @@ jobs:
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
lfs: true
persist-credentials: false
+1 -1
View File
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.19.1
hooks:
- id: mypy
args: [--config-file=pyproject.toml]
+2 -2
View File
@@ -52,7 +52,7 @@ decisions when appropriate.
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official email address,
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
@@ -60,7 +60,7 @@ representative at an online or offline event.
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[feedback@huggingface.co](mailto:feedback@huggingface.co).
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
+56 -296
View File
@@ -1,323 +1,83 @@
# How to contribute to 🤗 LeRobot?
# How to contribute to 🤗 LeRobot
Everyone is welcome to contribute, and we value everybody's contribution. Code
is thus not the only way to help the community. Answering questions, helping
others, reaching out and improving the documentations are immensely valuable to
the community.
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
It also helps us if you spread the word: reference the library from blog posts
on the awesome projects it made possible, shout out on Twitter when it has
helped you, or simply ⭐️ the repo to say "thank you".
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
Whichever way you choose to contribute, please be mindful to respect our
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
## Ways to Contribute
## You can contribute in so many ways!
You can contribute in many ways:
Some of the ways you can contribute to 🤗 LeRobot:
- **Fixing issues:** Resolve bugs or improve existing code.
- **New features:** Develop new features.
- **Extend:** Implement new models/policies, robots, or simulation environments and upload datasets to the Hugging Face Hub.
- **Documentation:** Improve examples, guides, and docstrings.
- **Feedback:** Submit tickets related to bugs or desired new features.
- Fixing outstanding issues with the existing code.
- Implementing new models, datasets or simulation environments.
- Contributing to the examples or to the documentation.
- Submitting issues related to bugs or desired new features.
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
## Development Setup
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
To contribute code, you need to set up a development environment.
## Submitting a new issue or feature request
### 1. Fork and Clone
Do your best to follow these guidelines when submitting an issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.
### Did you find a bug?
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
the problems they encounter. So thank you for reporting an issue.
First, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on Github under Issues).
Did not find it? :( So we can act quickly on it, please follow these steps:
- Include your **OS type and version**, the versions of **Python** and **PyTorch**.
- A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
- The full traceback if an exception is raised.
- Attach any other additional information, like screenshots, you think may help.
### Do you want a new feature?
A good feature request addresses the following points:
1. Motivation first:
- Is it related to a problem/frustration with the library? If so, please explain
why. Providing a code snippet that demonstrates the problem is best.
- Is it related to something you would need for a project? We'd love to hear
about it!
- Is it something you worked on and think could benefit the community?
Awesome! Tell us what problem it solved for you.
2. Write a _paragraph_ describing the feature.
3. Provide a **code snippet** that demonstrates its future use.
4. In case this is related to a paper, please attach a link.
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
If your issue is well written we're already 80% of the way there by the time you
post it.
## Adding new policies, datasets or environments
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
environments ([aloha](https://github.com/huggingface/gym-aloha),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/lerobot) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
git clone git@github.com:<your Github handle>/lerobot.git
cd lerobot
git remote add upstream https://github.com/huggingface/lerobot.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
git checkout main
git fetch upstream
git rebase upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
git checkout -b a-descriptive-name-for-my-changes
```
🚨 **Do not** work on the `main` branch.
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda:
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
```
If you're using `uv`, it can manage python versions so you can instead do:
```bash
uv venv --python 3.10 && source .venv/bin/activate
```
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
using `poetry`
```bash
poetry sync --extras "dev test"
```
using `uv`
```bash
uv sync --extra dev --extra test
```
You can also install the project with all its dependencies (including environments):
using `poetry`
```bash
poetry sync --all-extras
```
using `uv`
```bash
uv sync --all-extras
```
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing.
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
The equivalent of `pip install some-package`, would just be:
using `poetry`
```bash
poetry add some-package
```
using `uv`
```bash
uv add some-package
```
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
using `poetry`
```bash
poetry lock
```
using `uv`
```bash
uv lock
```
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
```bash
pytest tests/<TEST_TO_RUN>.py
```
6. Follow our style.
`lerobot` relies on `ruff` to format its source code
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
automatically as Git commit hooks.
Install `pre-commit` hooks:
```bash
pre-commit install
```
You can run these hooks whenever you need on staged files with:
```bash
pre-commit
```
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
git add modified_file.py
git commit
```
Note, if you already committed some changes that have a wrong formatting, you can use:
```bash
pre-commit run --all-files
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
git fetch upstream
git rebase upstream/main
```
Push the changes to your account using:
```bash
git push -u origin a-descriptive-name-for-my-changes
```
7. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
8. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
4. Make sure existing tests pass;
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
Fork the repository on GitHub, then clone your fork:
```bash
brew install git-lfs
git lfs install
git clone https://github.com/<your-handle>/lerobot.git
cd lerobot
git remote add upstream https://github.com/huggingface/lerobot.git
```
On Ubuntu:
### 2. Environment Installation
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
## Running Tests & Quality Checks
### Code Style (Pre-commit)
Install `pre-commit` hooks to run checks automatically before you commit:
```bash
sudo apt-get install git-lfs
git lfs install
pre-commit install
```
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
To run checks manually on all files:
```bash
pre-commit run --all-files
```
### Running Tests
We use `pytest`. First, ensure you have test artifacts by installing **git-lfs**:
```bash
git lfs install
git lfs pull
```
We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:
Run the full suite (this may require extras installed):
```bash
python -m pytest -sv ./tests
pytest -sv ./tests
```
You can specify a smaller set of tests in order to test only the feature
you're working on.
Or run a specific test file during development:
```bash
pytest -sv tests/test_specific_feature.py
```
## Submitting Issues & Pull Requests
Use the templates for required fields and examples.
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
One member of the LeRobot team will then review your contribution.
Thank you for contributing to LeRobot!
+103 -290
View File
@@ -1,7 +1,5 @@
<p align="center">
<img alt="LeRobot, Hugging Face Robotics Library" src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lerobot-logo-thumbnail.png" width="100%">
<br/>
<br/>
<img alt="LeRobot, Hugging Face Robotics Library" src="./media/readme/lerobot-logo-thumbnail.png" width="100%">
</p>
<div align="center">
@@ -12,323 +10,130 @@
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
[![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
[![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb)
<!-- [![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot) -->
</div>
<h2 align="center">
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
Build Your Own HopeJR Robot!</a></p>
</h2>
**LeRobot** aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry so that everyone can contribute to and benefit from shared datasets and pretrained models.
<div align="center">
<img
src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/hope_jr/hopejr.png"
alt="HopeJR robot"
title="HopeJR robot"
width="60%"
/>
🤗 A hardware-agnostic, Python-native interface that standardizes control across diverse platforms, from low-cost arms (SO-100) to humanoids.
<p><strong>Meet HopeJR A humanoid robot arm and hand for dexterous manipulation!</strong></p>
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
<p>Perfect for advanced manipulation tasks! 🤖</p>
🤗 A standardized, scalable LeRobotDataset format (Parquet + MP4 or images) hosted on the Hugging Face Hub, enabling efficient storage, streaming and visualization of massive robotic datasets.
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
See the full HopeJR tutorial here.</a></p>
</div>
🤗 State-of-the-art policies that have been shown to transfer to the real-world ready for training and deployment.
<br/>
🤗 Comprehensive support for the open-source ecosystem to democratize physical AI.
<h2 align="center">
<p><a href="https://huggingface.co/docs/lerobot/so101">
Build Your Own SO-101 Robot!</a></p>
</h2>
## Quick Start
<div align="center">
<table>
<tr>
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101.webp" alt="SO-101 follower arm" title="SO-101 follower arm" width="90%"/></td>
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101-leader.webp" alt="SO-101 leader arm" title="SO-101 leader arm" width="90%"/></td>
</tr>
</table>
<p><strong>Meet the updated SO100, the SO-101 Just €114 per arm!</strong></p>
<p>Train it in minutes with a few simple moves on your laptop.</p>
<p>Then sit back and watch your creation act autonomously! 🤯</p>
<p><a href="https://huggingface.co/docs/lerobot/so101">
See the full SO-101 tutorial here.</a></p>
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
<p>Check out the <a href="https://huggingface.co/docs/lerobot/lekiwi">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lekiwi/kiwi.webp" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
</div>
<br/>
<h3 align="center">
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
</h3>
---
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
#### Examples of pretrained models on simulation environments
<table>
<tr>
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
</tr>
<tr>
<td align="center">ACT policy on ALOHA env</td>
<td align="center">TDMPC policy on SimXArm env</td>
<td align="center">Diffusion policy on PushT env</td>
</tr>
</table>
## Installation
LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
```
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
>
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
>
> ```bash
> conda install ffmpeg=7.1.1 -c conda-forge
> ```
>
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
### Install LeRobot 🤗
#### From Source
First, clone the repository and navigate into the directory:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
```bash
pip install -e .
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
- [xarm](https://github.com/huggingface/gym-xarm)
- [pusht](https://github.com/huggingface/gym-pusht)
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
pip install -e ".[aloha, pusht]"
```
### Installation from PyPI
**Core Library:**
Install the base package with:
LeRobot can be installed directly from PyPI.
```bash
pip install lerobot
lerobot-info
```
_This installs only the default dependencies._
> [!IMPORTANT]
> For detailed installation guide, please see the [Installation Documentation](https://huggingface.co/docs/lerobot/installation).
**Extra Features:**
To install additional functionality, use one of the following:
## Robots & Control
<div align="center">
<img src="./media/readme/robots_control_video.webp" width="640px" alt="Reachy 2 Demo">
</div>
LeRobot provides a unified `Robot` class interface that decouples control logic from hardware specifics. It supports a wide range of robots and teleoperation devices.
```python
from lerobot.robots.myrobot import MyRobot
# Connect to a robot
robot = MyRobot(config=...)
robot.connect()
# Read observation and send action
obs = robot.get_observation()
action = model.select_action(obs)
robot.send_action(action)
```
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
For detailed hardware setup guides, see the [Hardware Documentation](https://huggingface.co/docs/lerobot/integrate_hardware).
## LeRobot Dataset
To solve the data fragmentation problem in robotics, we utilize the **LeRobotDataset** format.
- **Structure:** Synchronized MP4 videos (or images) for vision and Parquet files for state/action data.
- **HF Hub Integration:** Explore thousands of robotics datasets on the [Hugging Face Hub](https://huggingface.co/lerobot).
- **Tools:** Seamlessly delete episodes, split by indices/fractions, add/remove features, and merge multiple datasets.
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset from the Hub
dataset = LeRobotDataset("lerobot/aloha_mobile_cabinet")
# Access data (automatically handles video decoding)
episode_index=0
print(f"{dataset[episode_index]['action'].shape=}\n")
```
Learn more about it in the [LeRobotDataset Documentation](https://huggingface.co/docs/lerobot/lerobot-dataset-v3)
## SoTA Models
LeRobot implements state-of-the-art policies in pure PyTorch, covering Imitation Learning, Reinforcement Learning, and Vision-Language-Action (VLA) models, with more coming soon. It also provides you with the tools to instrument and inspect your training process.
<p align="center">
<img alt="Gr00t Architecture" src="./media/readme/VLA_architecture.jpg" width="640px">
</p>
Training a policy is as simple as running a script configuration:
```bash
pip install 'lerobot[all]' # All available features
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
pip install 'lerobot[feetech]' # Feetech motor support
lerobot-train \
--policy=act \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
_Replace `[...]` with your desired features._
| Category | Models |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
**Available Tags:**
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
### Weights & Biases
## Inference & Evaluation
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
Evaluate your policies in simulation or on real hardware using the unified evaluation script. LeRobot supports standard benchmarks like **LIBERO**, **MetaWorld** and more to come.
```bash
wandb login
# Evaluate a policy on the LIBERO benchmark
lerobot-eval \
--policy.path=lerobot/pi0_libero_finetuned \
--env.type=libero \
--env.task=libero_object \
--eval.n_episodes=10
```
(note: you will also need to enable WandB in the configuration. See below.)
Learn how to implement your own simulation environment or benchmark and distribute it from the HF Hub by following the [EnvHub Documentation](https://huggingface.co/docs/lerobot/envhub)
### Visualize datasets
## Resources
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--mode local \
--episode-index 0
```
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
### The `LeRobotDataset` format
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`.
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high (VideoFrame):
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
│ ├ observation.state (list of float32): position of an arm joints (for instance)
│ ... (more observations)
│ ├ action (list of float32): goal position of an arm joints (for instance)
│ ├ episode_index (int64): index of the episode for this sample
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
│ ├ timestamp (float32): timestamp in the episode
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
│ └ index (int64): general index in the whole dataset
├ meta: a LeRobotDatasetMetadata object containing:
│ ├ info: a dictionary of metadata on the dataset
│ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
│ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
│ │ ├ features (dict): all features contained in the dataset with their shapes and types
│ │ ├ total_episodes (int): total number of episodes in the dataset
│ │ ├ total_frames (int): total number of frames in the dataset
│ │ ├ robot_type (str): robot type used for recording
│ │ ├ data_path (str): formattable string for the parquet files
│ │ └ video_path (str): formattable string for the video files (if using videos)
│ ├ episodes: a DataFrame containing episode metadata with columns:
│ │ ├ episode_index (int): index of the episode
│ │ ├ tasks (list): list of tasks for this episode
│ │ ├ length (int): number of frames in this episode
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
│ │ └ dataset_to_index (int): end index of this episode in the dataset
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ │ └ ...
│ └ tasks: a DataFrame containing task information with task names as index and task_index as values
├ root (Path): local directory where the dataset is stored
├ image_transforms (Callable): optional image transformations to apply to visual modalities
└ delta_timestamps (dict): optional delta timestamps for temporal queries
```
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
#### Reproduce state-of-the-art (SOTA)
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
You can reproduce their training by loading the config from their run. Simply running:
```bash
lerobot-train --config_path=lerobot/diffusion_pusht
```
reproduces SOTA results for Diffusion Policy on the PushT task.
## Contribute
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
### Add a pretrained policy
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
To upload these to the hub, run the following:
```bash
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
### Acknowledgment
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
## Citation
If you want, you can cite this work with:
If you use LeRobot in your research, please cite:
```bibtex
@misc{cadene2024lerobot,
@@ -339,6 +144,14 @@ If you want, you can cite this work with:
}
```
## Star History
## Contribute
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
<p align="center">
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
</p>
<div align="center">
<sub>Built by the <a href="https://huggingface.co/lerobot">LeRobot</a> team at <a href="https://huggingface.co">Hugging Face</a> with ❤️</sub>
</div>
+6
View File
@@ -41,7 +41,13 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: walloss
title: WALL-OSS
title: "Policies"
- sections:
- local: sarm
title: SARM
title: "Reward Models"
- sections:
- local: async
title: Use Async Inference
+22 -5
View File
@@ -201,7 +201,8 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
from lerobot.scripts.lerobot_record import record_loop
from lerobot.processor import make_default_processors
NUM_EPISODES = 5
FPS = 30
@@ -209,12 +210,19 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
# Create the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
# Create robot configuration
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
@@ -243,6 +251,9 @@ init_rerun(session_name="recording")
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
@@ -251,6 +262,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
@@ -265,6 +279,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
+35
View File
@@ -0,0 +1,35 @@
# WALL-OSS
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
---
## Model Overview
| Feature | Description |
| ------------------ | ----------------------------------------------------- | --- |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
---
## Citation
If you use this work, please cite:
```bibtex
@article{zhai2025igniting,
title = {Igniting VLMs Toward the Embodied Space},
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
journal = {arXiv preprint arXiv:2509.11766},
year = {2025}
}
```
---
## License
This port follows the **Apache 2.0 License**.
+586
View File
@@ -0,0 +1,586 @@
# SARM: Stage-Aware Reward Modeling
SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
## Why Reward Models?
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
## Overview
SARM has following features:
1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
SARM trains on a compact **stage+tau** target for each frame:
- **stage**: integer stage index `k ∈ {0, ..., K-1}`
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)
At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).
This matches **Formula (2)** from the paper:
```
progress_t = P_{k-1} + α̅_k × τ_t
```
Where:
- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time
- `P_{k-1}` is cumulative prior (sum of previous subtask proportions)
- `α̅_k` is the temporal proportion for subtask k
This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
## Inputs and Targets (What the new code expects)
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)
At minimum, each training sample needs:
- `task` (string): task description
- `policy.image_key` images and `policy.state_key` states from the dataset
---
## Annotation Modes
You can choose from **3 annotation modes** that determine how progress labels are computed:
| Mode | Annotations Required | Heads | Use Case |
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |
### Mode Details
<hfoptions id="mode_explanation">
<hfoption id="single_stage">
**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.
- **Sparse head**: 1 stage ("task"), linear progress
- **Dense head**: Not used
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available
## Set Up Your Environment
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install SARM dependencies by running:
```bash
pip install -e ".[sarm]"
```
Workflow:
```
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
```
</hfoption>
<hfoption id="dense_only">
**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.
- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
- **Dense head**: Multiple fine-grained stages from VLM annotations
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages
Workflow:
```
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```
</hfoption>
<hfoption id="dual">
**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.
- **Sparse head**: High-level stages from VLM annotations
- **Dense head**: Fine-grained stages from VLM annotations
- **Best for**: Complex multi-stage tasks where both granularities are useful
Workflow:
```
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```
</hfoption>
</hfoptions>
---
## Step 1: Subtask Annotation
<hfoptions id="annotation_mode">
<hfoption id="single_stage">
**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.
</hfoption>
<hfoption id="dense_only">
Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.
```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
--repo-id your-username/your-dataset \
--dense-only \
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
--video-key observation.images.base \
--num-workers 4 \
--push-to-hub
```
**What gets saved:**
- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
- (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)
</hfoption>
<hfoption id="dual">
Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.
```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
--repo-id your-username/your-dataset \
--sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
--video-key observation.images.base \
--num-workers 4 \
--push-to-hub
```
**What gets saved:**
- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
- `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
- (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)
</hfoption>
</hfoptions>
### Annotation Arguments
| Argument | Description |
| ---------------------- | ------------------------------------------------------------------------------- |
| `--repo-id` | HuggingFace dataset repository ID |
| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
| `--num-workers` | Number of parallel GPU workers (default: 1) |
| `--episodes` | Specific episode indices to annotate (default: all) |
| `--skip-existing` | Skip episodes that already have annotations |
| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |
> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.
---
## Step 2: Verify Annotations
<hfoptions id="verify_mode">
<hfoption id="single_stage">
**No verification needed!** Skip this step.
</hfoption>
<hfoption id="dense_only">
Visualize annotations using the `--visualize-only` flag:
```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
--repo-id your-username/your-dataset \
--visualize-only \
--visualize-type dense \
--num-visualizations 5 \
--video-key observation.images.base \
--output-dir ./subtask_viz
```
</hfoption>
<hfoption id="dual">
Visualize annotations using the `--visualize-only` flag:
```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
--repo-id your-username/your-dataset \
--visualize-only \
--visualize-type both \
--num-visualizations 5 \
--video-key observation.images.base \
--output-dir ./subtask_viz
```
</hfoption>
</hfoptions>
This generates visualizations showing video frames with subtask boundaries overlaid and timeline of subtasks.
### Visualization Arguments
| Argument | Description |
| ---------------------- | -------------------------------------------------------------- |
| `--visualize-only` | Only visualize existing annotations (no generation) |
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |
**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.
---
## Step 3: Train SARM
<hfoptions id="train_mode">
<hfoption id="single_stage">
Train with **no annotations** - uses linear progress from 0 to 1:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=single_stage \
--policy.image_key=observation.images.base \
--output_dir=outputs/train/sarm_single \
--batch_size=32 \
--steps=5000 \
--wandb.enable=true \
--wandb.project=sarm \
--policy.repo_id=your-username/your-model-name
```
</hfoption>
<hfoption id="dense_only">
Train with **dense annotations only** (sparse auto-generated):
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dense_only \
--policy.image_key=observation.images.base \
--output_dir=outputs/train/sarm_dense \
--batch_size=32 \
--steps=5000 \
--wandb.enable=true \
--wandb.project=sarm \
--policy.repo_id=your-username/your-model-name
```
</hfoption>
<hfoption id="dual">
Train with **both sparse and dense annotations**:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dual \
--policy.image_key=observation.images.base \
--output_dir=outputs/train/sarm_dual \
--batch_size=32 \
--steps=5000 \
--wandb.enable=true \
--wandb.project=sarm \
--policy.repo_id=your-username/your-model-name
```
</hfoption>
</hfoptions>
### Multi-GPU Training
Add `accelerate launch --multi_gpu --num_processes=4` to use multiple GPUs for training.
### Training Arguments
| Argument | Description | Default |
| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
| `--policy.image_key` | Camera key for images | `observation.images.top` |
| `--policy.state_key` | Key for joint states | `observation.state` |
| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
---
## Step 4: Visualize Predictions
Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.
<hfoptions id="viz_mode">
<hfoption id="single_stage">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
--num-visualizations 5 \
--head-mode sparse \
--output-dir ./sarm_viz
```
</hfoption>
<hfoption id="dense_only">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
--num-visualizations 5 \
--head-mode dense \
--output-dir ./sarm_viz
```
</hfoption>
<hfoption id="dual">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
--num-visualizations 5 \
--head-mode both \
--output-dir ./sarm_viz
```
</hfoption>
</hfoptions>
The visualization shows:
- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`)
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
- **Sample frames**: Key frames from the episode with progress/stage labels
### Visualization Arguments
| Argument | Description |
| ---------------------- | --------------------------------------------------------- |
| `--visualize-only` | Only visualize predictions (no RABC computation) |
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
---
## Step 5 (Optional): Train Policy with RA-BC
Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:
1. **Precompute progress values** for all frames using the trained SARM model
2. **Train policy** with RA-BC weighting using the precomputed values
### How RA-BC Works
For each training sample, RA-BC computes the progress delta:
```
r_i = φ(o_{t+Δ}) - φ(o_t)
```
Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.
The weighting follows **Equations 8-9** from the paper:
- **Soft weight**: `w̃_i = clip((r_i 2σ)) / (4σ + ε), 0, 1)`
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
### Step 5a: Compute SARM Progress Values
First, run the SARM model on all frames in your dataset to compute progress values:
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--head-mode sparse \
--num-visualizations 5 \
--push-to-hub
```
This script:
- Processes all frames and computes progress values
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
- Generates visualizations of the first N episodes (default: 5)
**Arguments:**
| Argument | Description | Default |
| ---------------------- | -------------------------------------------------------------- | ---------- |
| `--reward-model-path` | Path to trained SARM model | (required) |
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
| `--device` | Device for inference | `cuda` |
| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` |
**Output format** (`sarm_progress.parquet`):
| Column | Description |
| ----------------- | ---------------------------------------------- |
| `index` | Global frame index in dataset |
| `episode_index` | Episode number |
| `frame_index` | Local frame index within episode |
| `progress_sparse` | Sparse head progress value [0, 1] |
| `progress_dense` | Dense head progress value [0, 1] (if computed) |
### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_head_mode=sparse \
--rabc_kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
```
The training script automatically:
- Loads the precomputed progress values from the parquet file
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
- Computes sample weights based on progress improvement
- Applies weighted loss during training
**RA-BC Arguments:**
| Argument | Description | Default |
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
### Tuning RA-BC Kappa
The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.
**How the weighting works:**
| Condition | Weight |
| ------------------- | ----------------------- |
| `delta > kappa` | 1.0 (hard threshold) |
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
| `delta < 0` | 0.0 (negative progress) |
**Diagnosing kappa issues:**
Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator |
| ------------------ | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `rabc_delta_mean` | > 0 | Should be positive |
| `rabc_delta_std` | > 0 | Variance in data quality |
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**Setting kappa based on your data:**
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]
# Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07
```
**When RA-BC may not help:**
If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
### Multi-GPU Training with RA-BC
```bash
accelerate launch \
--multi_gpu \
--num_processes=4 \
src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
```
---
## Tips & Best Practices
### Choosing a Mode
- **Start with `single_stage`** for quick experiments - no annotation overhead
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful
### Annotation Quality
1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
2. **Verify with visualization**: Always check a few episodes before training
3. **Consistent naming**: Use the same subtask names across all episodes
### RA-BC
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
---
## Citation
```bibtex
@article{chen2025sarm,
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
journal={arXiv preprint arXiv:2509.25358},
year={2025}
}
```
+6 -1
View File
@@ -4,11 +4,12 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
## About the Unitree G1
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
We offer support for both 29 and 23 DOF G1. We introduce:
- **`unitree g1` robot class, handling low level communication with the humanoid**
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **GR00T locomotion policy** for bipedal walking and balance
- **MuJoCo simulation mode** for testing policies without the physical robot
---
@@ -191,6 +192,10 @@ Press `Ctrl+C` to stop the policy.
---
## Extra: Running in Simulation Mode (MuJoCo)
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
+104 -3
View File
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
--operation.feature_names "['observation.images.top']"
```
#### Convert to Video
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
```bash
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
# Convert with custom video codec and quality settings
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.num_workers 8
```
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Push to Hub
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
```bash
lerobot-edit-dataset \
@@ -96,7 +159,45 @@ lerobot-edit-dataset \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]" \
--push_to_hub
--push_to_hub true
```
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
# Dataset Visualization
## Online Visualization
When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at:
https://huggingface.co/spaces/lerobot/visualize_dataset
## Local Visualization
You can also visualize episodes from a dataset locally using our command-line tool.
**From the Hugging Face Hub:**
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
**From a local folder:**
Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`:
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--mode local \
--episode-index 0
```
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
For advanced usage—including visualizing datasets stored on a remote server—run:
```bash
lerobot-dataset-viz --help
```
+74
View File
@@ -0,0 +1,74 @@
# WALL-OSS
WALL-OSS is an open-source foundation model for embodied intelligence, proposed by the [XSquare Robot](https://x2robot.com/en/research/68bc2cde8497d7f238dde690) team in 2025. The LeRobot implementation is adapted from their open-source [WallX](https://github.com/X-Square-Robot/wall-x) repository.
X Square Robots WALL-OSS is now integrated into Hugging Faces LeRobot ecosystem. This is an exciting collaborative project between the LeRobot and X Square Robot teams. You can now post-train, evaluate, and deploy WALL-OSS directly through LeRobot. With this, were aiming to make it easier for the open-source robotics community to customize and deploy WALL-OSS foundation models. Read and explore WALL-OSS [paper](https://arxiv.org/pdf/2509.11766) and [code](https://github.com/X-Square-Robot/wall-x).
## Model Overview
The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it.
Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include:
- **Embodied perceptionenhanced multimodal pretraining**: Large-scale training on unified visionlanguageaction data to strengthen spatial, causal, and manipulation understanding.
- **Unified Cross-Level Chain-of-Thought (Uni-CoT)**: A single differentiable framework that unifies high-level instruction reasoning, sub-task decomposition, and fine-grained action synthesis, forming a continuous chain from “understanding” to “execution.”
- **Mixture-of-Experts (MoE) action heads**: Dynamically activating experts depending on the task phase and modeling actions in discrete or continuous space to maintain stable VLM priors.
- **Two-stage training paradigm**:
- **Inspiration stage**: Injecting discrete action priors to strengthen spatial understanding and semantic-action alignment.
- **Integration stage**: Using flow matching to achieve high-frequency continuous control.
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install WallX dependencies by running:
```bash
pip install -e ".[wallx]"
```
## Usage
To use WallX in LeRobot, specify the policy type as:
```python
policy.type=wall_x
```
## Training
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=wall_x \
--output_dir=./outputs/wallx_training \
--job_name=wallx_training \
--policy.repo_id=your_repo_id \
--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow \
--policy.prediction_mode=diffusion \
--policy.attn_implementation=eager \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
### Training Arguments
| Argument | Description |
| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--dataset.repo_id` | The Hugging Face Hub repository ID for your training dataset (e.g., `lerobot/aloha_sim_insertion_human`) |
| `--policy.type` | Specifies using the WallX policy architecture |
| `--output_dir` | Local directory where training checkpoints and logs will be saved |
| `--job_name` | A name identifier for this training run (used in logging/tracking) |
| `--policy.repo_id` | Your Hugging Face Hub repo ID where the trained model will be pushed |
| `--policy.pretrained_path` | Path to pretrained WallX weights to initialize from (the official WALL-OSS checkpoint) |
| `--policy.prediction_mode` | The action prediction strategy: `diffusion` or `fast` - `diffusion` uses iterative denoising for action generation, `fast` uses next token prediction instead |
| `--policy.attn_implementation` | Attention implementation backend - `eager` uses standard PyTorch attention (alternatives include `flash_attention_2` or `sdpa`) |
| `--steps` | Total number of training steps to run |
| `--policy.device` | Device to train on (`cuda` for GPU, `cpu` for CPU) |
| `--batch_size` | Number of samples per training batch |
## License
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x).
+42 -84
View File
@@ -24,7 +24,7 @@ Built from pure Transformer encoders, X-VLA scales naturally with model size and
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
alt="XVLA Architecture 2"
style="width: 32%; max-width: 450px; height: auto;"
style="width: 60%; height: auto;"
/>
</p>
@@ -120,7 +120,7 @@ Adapted for Google Robot platforms.
### Recommended Training Configuration
When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
```bash
lerobot-train \
@@ -129,25 +129,26 @@ lerobot-train \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--steps=3000 \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True \
--policy.action_mode=YOUR_ACTION_MODE
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
```
### Training Parameters Explained
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------- |
| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
| `train_soft_prompts` | `True` | Allow soft prompts to train |
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
| `train_soft_prompts` | `true` | Allow soft prompts to train |
**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
### Example: Training on Bimanual Robot
@@ -157,14 +158,15 @@ lerobot-train \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
--policy.dtype=bfloat16 \
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
--steps=3000 \
--policy.device=cuda \
--policy.action_mode=so101_bimanual \
--policy.freeze_vision_encoder=True \
--policy.freeze_language_encoder=True \
--policy.train_policy_transformer=True \
--policy.train_soft_prompts=True
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true
```
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
@@ -172,71 +174,7 @@ lerobot-train \
**🔥 Full-finetune all components with a custom learning-rate scheme**
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
To enable this behavior, you must:
1. Implement a custom optimizer and register it in your training config
```
from dataclasses import dataclass, asdict
from lerobot.optim.optimizers import OptimizerConfig
import torch
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamW(OptimizerConfig):
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Expect `named_parameters()` as input.
Apply lr = lr / 10 for all VLM-related parameters.
"""
assert isinstance(params, dict), \
"Custom LR optimizer requires `named_parameters()` as inputs."
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
vlm_group, other_group = [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
else:
other_group.append(p)
param_groups = [
{"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
{"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
]
return torch.optim.AdamW(param_groups, **kwargs)
```
2. Modify X-VLAs get_optim_params to return named parameters
Replace:
```
def get_optim_params(self) -> dict:
"""Return only trainable parameters for optimization."""
return filter(lambda p: p.requires_grad, self.parameters())
```
with:
```
def get_optim_params(self):
"""Return trainable named parameters."""
return filter(lambda kv: kv[1].requires_grad, self.named_parameters())
```
This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
❕Note
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
@@ -326,6 +264,26 @@ domain_id = 3
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
#### Fine-tuning Datasets
| Dataset Name | Domain ID |
| ---------------- | --------- |
| Bridge | 0 |
| RT1 | 1 |
| Calvin | 2 |
| libero | 3 |
| widowx-air | 4 |
| AIR-AGILEX-HQ | 5 |
| robotwin2_abs_ee | 6 |
| robotwin2_clean | 6 |
| robocasa-human | 7 |
| VLABench | 8 |
| AGIBOT-challenge | 9 |
| AIR-AGILEX | 10 |
| AIRBOT | 18 |
### 3. Processor Steps
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
-243
View File
@@ -1,243 +0,0 @@
# Synthetic Data Generation Script - Summary
## ✅ What Was Created
### Main Script: `annotate_pgen.py` (717 lines)
A production-ready script implementing the Hi-Robot synthetic data generation pipeline.
**Key Features:**
- ✅ Loads LeRobot datasets with skill annotations
- ✅ Generates synthetic user prompts and robot utterances using Qwen VLM
- ✅ **Temporal sampling** - generates dialogue every N seconds (default: 1s)
- ✅ Adds `task_index_high_level` feature to dataset parquets
- ✅ Saves high-level tasks to `meta/tasks_high_level.parquet`
- ✅ Exports debug JSONL for quality analysis
- ✅ Supports both Qwen2-VL and Qwen3-VL models
- ✅ Multi-view camera support
- ✅ Episode-aware processing with automatic first-frame sampling
- ✅ Modular architecture for easy extension
### Supporting Files Created
1. **`run_pgen.sh`** - Convenience script with sensible defaults
2. **`README_PGEN.md`** - Comprehensive documentation with examples
3. **`example_pgen_usage.md`** - Practical examples and performance estimates
4. **`SAMPLING_DIAGRAM.md`** - Visual explanation of temporal sampling strategy
5. **`PGEN_SUMMARY.md`** - This file
## 🚀 Key Innovation: Temporal Sampling
The script processes **ALL episodes** in the dataset efficiently via `--sample-interval`:
```bash
# Instead of calling VLM for every frame (expensive):
# 15,000 frames × VLM call = ~5 hours
# Generate dialogue every 1 second (efficient):
python annotate_pgen.py --repo-id dataset --model qwen --sample-interval 1.0
# 15,000 frames processed, only ~500 VLM calls (30x speedup!)
```
**How it works:**
- Process ALL frames in ALL episodes (complete coverage)
- Generate dialogue at sampled timepoints (e.g., every 1 second)
- Propagate task indices to intermediate frames
- Always sample first frame of each episode
- All frames get labeled, but VLM is only called for samples
- No dummy values or skipped episodes
**Benefits:**
- 30-100x speedup depending on interval
- Maintains temporal coherence
- Reduces cost without losing quality
- Configurable based on skill duration
## 📊 Efficiency Comparison
For a typical 15,000 frame dataset at 30 fps:
| Method | VLM Calls | Time | Cost |
|--------|-----------|------|------|
| Every frame | 15,000 | ~5 hours | $$$$ |
| Every 0.5s | 1,000 | ~20 min | $$$ |
| **Every 1s** (default) | **500** | **~10 min** | **$$** |
| Every 2s | 250 | ~5 min | $ |
## 🎯 Usage
### Quick Test (5s sampling for fast iteration)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/test_quick
```
### Production Run (Recommended Settings)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/full_pgen
```
### High-Quality with Qwen3
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--sample-interval 0.5 \
--temperature 0.6 \
--output-dir ./outputs/high_quality
```
## 📦 Output Structure
After running, you'll have:
```
dataset_root/
├── meta/
│ ├── tasks_high_level.parquet # High-level tasks with prompts/utterances
│ └── syn_annotations.jsonl # Debug: full context for each sample
└── data/
└── chunk-000/
└── file-000.parquet # Updated with task_index_high_level
```
**New feature added to all parquet files:**
- `task_index_high_level` (int64): Links to tasks_high_level.parquet
## 🔧 All Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--repo-id` / `--data-dir` | - | Dataset source |
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model |
| `--device` | cuda | Device to use |
| `--dtype` | bfloat16 | Model precision |
| `--temperature` | 0.7 | Sampling temperature |
| **`--sample-interval`** | **1.0** | **Generate every N seconds (all episodes processed)** |
| `--num-image-views-per-sample` | 1 | Number of cameras |
| `--batch-size` | 1 | Batch size (currently unused) |
| `--output-dir` | None | Output directory |
| `--push-to-hub` | False | Push to HuggingFace |
## 🎨 Generated Data Format
Each sampled frame produces:
```json
{
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "Can you pick up the pink brick?",
"robot_utterance": "Sure, I'll grab the pink lego brick.",
"skill": "robot arm picks up pink lego brick",
"episode_id": 0,
"frame_index": 45,
"timestamp": 1.5,
"skill_history": ["robot arm moves towards pink lego brick"],
"task_description": "pink lego brick into the transparent box"
}
```
**Scenario Types:**
- specific_object, negative_task, situated_correction, implicit_request, constraint_based
**Response Types:**
- confirmation, clarification, acknowledgment, constraint_acknowledgment
## 🔬 Code Architecture
```python
# Main components (modular design)
class QwenPgen:
"""VLM wrapper supporting Qwen2/3"""
def call_qwen(images, prompt) -> dict
def construct_prompt(task, history, skill) -> str:
"""Build contextual prompt with history"""
def annotate_sample(pgen, images, ...) -> dict:
"""Generate dialogue for one sample"""
def generate_synthetic_data(dataset, pgen, ...) -> tuple:
"""Process entire dataset with temporal sampling"""
# Core sampling logic:
# - Track last_sample_timestamp per episode
# - Sample if time_elapsed >= sample_interval
# - Always sample first frame of episodes
# - Propagate task_index to intermediate frames
def main():
"""CLI entrypoint with argparse"""
```
## ✨ Next Steps
1. **Quick test with large interval:**
```bash
# Fast iteration - samples every 5 seconds
python examples/dataset/annotate_pgen.py \
--data-dir /path/to/dataset \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/quick_test
```
2. **Verify output quality:**
```bash
head outputs/quick_test/meta/syn_annotations.jsonl
```
3. **Production run:**
```bash
# Standard 1 second sampling for production
bash examples/dataset/run_pgen.sh
```
4. **Use in training:**
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
ds = LeRobotDataset(repo_id="...", root="outputs/pgen_annotations")
# Access high-level task for each frame
frame = ds[100]
task_idx = frame["task_index_high_level"].item()
```
## 📚 Documentation Files
- **`README_PGEN.md`**: Full API reference and troubleshooting
- **`example_pgen_usage.md`**: Practical examples with performance estimates
- **`SAMPLING_DIAGRAM.md`**: Visual explanation of temporal sampling
- **`PGEN_SUMMARY.md`**: This overview document
## 🎯 Success Criteria
✅ Script generates synthetic dialogue using Qwen VLM
✅ Adds `task_index_high_level` feature to dataset
✅ Saves tasks to `tasks_high_level.parquet`
✅ Implements efficient temporal sampling (30-100x speedup)
✅ Handles episode boundaries correctly
✅ Produces diverse interaction types (scenarios + responses)
✅ Maintains temporal coherence within episodes
✅ Includes comprehensive documentation and examples
✅ Ready for production use on real datasets
## 💡 Key Takeaway
**The script processes ALL episodes with intelligent sampling:**
- `--sample-interval` controls how often VLM is called (default: 1.0s)
- ALL frames in ALL episodes get labeled (complete coverage)
- Intermediate frames inherit from most recent sample (temporal coherence)
- Achieves 30-100x speedup while maintaining quality
- Adjust interval based on use case: 5.0s for testing, 1.0s for production, 0.5s for fine detail
This makes the synthetic data generation **practical, scalable, and complete** for real-world datasets!
-243
View File
@@ -1,243 +0,0 @@
# Synthetic Data Generation for Hierarchical Robot Policies
This directory contains `annotate_pgen.py`, a script for generating synthetic user prompts and robot utterances for hierarchical policy training using Vision-Language Models (VLMs).
## Overview
The script implements the synthetic data generation pipeline described in the Hi-Robot paper:
1. **Load** a LeRobot dataset with skill annotations (from `annotate.py`)
2. **Generate** synthetic dialogue using Qwen VLM:
- User prompts (_t): Natural requests that lead to specific skills
- Robot utterances (u_t): Acknowledgments and clarifications
3. **Save** results as a new dataset feature `task_index_high_level`
## Prerequisites
1. First, annotate your dataset with skills using `annotate.py`:
```bash
python examples/dataset/annotate.py \
--repo-id lerobot/svla_so101_pickplace \
--video-key observation.images.base \
--model Qwen/Qwen2-VL-7B-Instruct
```
This creates `meta/skills.json` with skill segmentation for each episode.
## Usage
### Basic Usage
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/pgen_dataset
```
**Note**: The script processes **all episodes** in the dataset. It generates dialogue every 1 second (`--sample-interval 1.0`) using temporal sampling. Frames between samples reuse the last generated dialogue. This makes the process efficient while ensuring complete dataset coverage.
### Advanced Options
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--temperature 0.8 \
--sample-interval 0.5 \
--num-image-views-per-sample 2 \
--output-dir ./outputs/pgen_dataset \
--push-to-hub
```
This example uses a more powerful model and samples every 0.5 seconds for finer granularity.
### Fast Testing (larger interval)
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 5.0 \
--output-dir ./outputs/pgen_quick_test
```
Use a larger interval (5.0 seconds) for rapid iteration during development. All episodes are still processed.
### Using Local Dataset
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--output-dir ./outputs/pgen_dataset
```
## Output Files
The script produces several outputs:
1. **`meta/tasks_high_level.parquet`**: High-level tasks with user prompts and robot utterances
- Columns: task_index, user_prompt, robot_utterance, skill, scenario_type, response_type
2. **`meta/syn_annotations.jsonl`**: Debug file with all generated dialogues
- One JSON object per line with full context for each frame
3. **Modified dataset**: New dataset with `task_index_high_level` feature added to all parquet files
## Scenario and Response Types
The generator produces diverse interaction types:
### Scenario Types
- **specific_object**: Direct specification of objects/actions
- **negative_task**: Instructions about what NOT to do
- **situated_correction**: Adjustments based on current state
- **implicit_request**: Implied needs without direct commands
- **constraint_based**: Specific constraints or preferences
### Response Types
- **confirmation**: Simple acknowledgment ("OK, I'll do X")
- **clarification**: Seeking confirmation ("Just to confirm...")
- **acknowledgment**: Action acknowledgment ("Got it, doing X")
- **constraint_acknowledgment**: Acknowledging constraints ("Sure, I'll X while Y")
## Example Generated Data
```json
{
"episode_id": 0,
"frame_index": 45,
"timestamp": 2.5,
"skill_current": "robot arm picks up pink lego brick",
"skill_history": ["robot arm moves towards pink lego brick"],
"task_description": "pink lego brick into the transparent box",
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "Can you grab the pink brick?",
"robot_utterance": "Sure, I'll pick up the pink lego brick."
}
```
## Accessing the Data
After running the script, access the synthetic data in your code:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import pandas as pd
# Load modified dataset
dataset = LeRobotDataset(repo_id="lerobot/svla_so101_pickplace_with_high_level_tasks")
# Access frame with high-level task
frame = dataset[100]
high_level_task_idx = frame["task_index_high_level"].item()
# Load high-level tasks
tasks_df = pd.read_parquet(dataset.root / "meta" / "tasks_high_level.parquet")
task_info = tasks_df.iloc[high_level_task_idx]
print(f"User prompt: {task_info['user_prompt']}")
print(f"Robot utterance: {task_info['robot_utterance']}")
print(f"Skill: {task_info['skill']}")
```
## Architecture
The script is modular and extensible:
```python
# Core components
class QwenPgen:
"""VLM wrapper for generation"""
def call_qwen(images, prompt) -> dict
def construct_prompt(task, history, skill) -> str
"""Build prompt for VLM"""
def annotate_sample(pgen, images, ...) -> dict
"""Generate dialogue for one sample"""
def generate_synthetic_data(dataset, pgen, ...) -> tuple
"""Process entire dataset"""
```
## Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--repo-id` | - | HuggingFace dataset ID |
| `--data-dir` | - | Local dataset path |
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model name |
| `--device` | cuda | Device (cuda/cpu) |
| `--dtype` | bfloat16 | Model precision |
| `--temperature` | 0.7 | Sampling temperature |
| `--sample-interval` | 1.0 | Generate dialogue every N seconds (all episodes processed) |
| `--num-image-views-per-sample` | 1 | Number of cameras |
| `--output-dir` | None | Output directory |
| `--push-to-hub` | False | Push to HuggingFace Hub |
## Sampling Strategy
The script uses **temporal sampling** to efficiently generate dialogue:
- **Default**: Generate dialogue every 1 second (`--sample-interval 1.0`)
- **Efficiency**: If a dataset runs at 30fps, this samples ~3% of frames
- **Propagation**: Frames between samples reuse the last generated task_index
- **Episode-aware**: Always samples the first frame of each episode
### Example with 30 fps dataset:
```bash
# Sample every 1 second (every 30 frames)
--sample-interval 1.0 # ~3,000 generations for a 100 episode dataset (3 sec/episode)
# Sample every 0.5 seconds (every 15 frames)
--sample-interval 0.5 # ~6,000 generations (more granular)
# Sample every 2 seconds (every 60 frames)
--sample-interval 2.0 # ~1,500 generations (more efficient)
```
### Why sampling works:
- Skills typically last 1-3 seconds
- Dialogue doesn't need to change every frame
- Reduces computational cost by 30-100x
- Still provides good coverage for training
## Tips
1. **Quick testing**: Use larger `--sample-interval` (e.g., 5.0 or 10.0) for rapid iteration
2. **Monitor GPU**: VLM inference is memory-intensive
3. **Check outputs**: Review `syn_annotations.jsonl` for quality
4. **Adjust temperature**: Higher = more diverse, lower = more consistent
5. **Multiple views**: Use `--num-image-views-per-sample 2+` for better context
6. **Tune sampling**: Start with 1.0s, increase for speed (testing), decrease for granularity (production)
## Troubleshooting
### No skills.json found
Run `annotate.py` first to generate skill annotations.
### Out of memory
- Reduce batch size to 1
- Use smaller model (Qwen2-VL-7B instead of Qwen3-VL-30B)
- Process fewer samples at a time
### Poor quality generations
- Adjust temperature (try 0.6-0.9)
- Check that skills.json has good annotations
- Ensure images are loading correctly
## Citation
Based on the Hi-Robot paper's synthetic data generation approach:
```
@article{hirobot2024,
title={Hi-Robot: Hierarchical Robot Learning with Vision-Language Models},
year={2024}
}
```
-141
View File
@@ -1,141 +0,0 @@
# Temporal Sampling Strategy Visualization
## How `--sample-interval` Works
### Example: 30 fps dataset, `--sample-interval 1.0` (1 second)
```
Timeline (seconds): 0.0 0.5 1.0 1.5 2.0 2.5 3.0
│ │ │ │ │ │ │
Frames: 0───15───30───45───60───75───90───105──120──135──150
│ │ │ │ │ │ │
▼ ▼ ▼ ▼
Sampled: YES NO YES NO YES NO YES
│ │ │ │
Task Index: [0]──────────────>[1]──────────────>[2]──────────────>[3]
│ │ │ │
VLM Called: ✓ Gen ✓ Gen ✓ Gen ✓ Gen
dialogue dialogue dialogue dialogue
│ │ │ │
Frames 0-29 ─────┘ │ │ │
get task 0 │ │ │
│ │ │
Frames 30-59 ────────────────────────┘ │ │
get task 1 │ │
│ │
Frames 60-89 ──────────────────────────────────────────┘ │
get task 2 │
Frames 90-119 ────────────────────────────────────────────────────────────┘
get task 3
```
## Comparison: Different Sampling Intervals
### `--sample-interval 2.0` (every 2 seconds)
```
Timeline: 0.0 1.0 2.0 3.0 4.0 5.0 6.0
│ │ │ │ │ │ │
Sampled: YES NO YES NO YES NO YES
│ │ │ │
Tasks: [0]───────────────>[1]───────────────>[2]───────────────>[3]
VLM Calls: 4 (fewer calls, faster but less granular)
```
### `--sample-interval 1.0` (every 1 second) - **DEFAULT**
```
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
│ │ │ │ │ │ │ │ │ │ │ │ │
Sampled: YES NO YES NO YES NO YES NO YES NO YES NO YES
│ │ │ │ │ │ │
Tasks: [0]─────────>[1]─────────>[2]─────────>[3]─────────>[4]─────────>[5]─────>[6]
VLM Calls: 7 (balanced coverage and speed)
```
### `--sample-interval 0.5` (every 0.5 seconds)
```
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
│ │ │ │ │ │ │ │ │ │ │ │ │
Sampled: YES YES YES YES YES YES YES YES YES YES YES YES YES
│ │ │ │ │ │ │ │ │ │ │ │ │
Tasks: [0]─>[1]─>[2]─>[3]─>[4]─>[5]─>[6]─>[7]─>[8]─>[9]─>[10]>[11]>[12]
VLM Calls: 13 (high granularity, slower but more detailed)
```
## Episode Boundaries
The script always samples the **first frame** of each episode:
```
Episode 0 Episode 1 Episode 2
├─────────────────────────────────┤├─────────────────────────────────┤├──────...
│ ││ ││
Frame: 0 30 60 90 120 130 160 190 220 250 260 290 320
Time: 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0
│ │ │ │ │ │ │ │ │ │ │ │ │
▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
Sample:YES YES YES YES YES YES YES YES YES YES YES YES YES
│ │ │ │ │ │ │ │ │ │ │ │ │
Task: 0────1─────2─────3────4 5─────6─────7─────8────9 10────11───12
Note: Frames 0, 130, 260 are ALWAYS sampled (episode starts)
Even if they're within the sample-interval window
```
## Real-World Example: svla_so101_pickplace Dataset
Typical stats:
- **Total episodes**: 50
- **Avg episode length**: 300 frames (10 seconds at 30 fps)
- **Total frames**: 15,000
### Without Sampling (every frame)
```
Frames processed: 15,000
VLM calls: 15,000
Time estimate: ~5 hours
Unique tasks: ~12,000 (lots of duplicates)
```
### With `--sample-interval 1.0` (every 1 second)
```
Frames processed: 15,000 ✓
VLM calls: 500
Time estimate: ~10 minutes
Unique tasks: ~450 (meaningful variety)
Efficiency gain: 30x faster
```
### With `--sample-interval 2.0` (every 2 seconds)
```
Frames processed: 15,000 ✓
VLM calls: 250
Time estimate: ~5 minutes
Unique tasks: ~220
Efficiency gain: 60x faster
```
## Key Points
1. **All frames get labeled**: Every frame gets a `task_index_high_level`
2. **Only sampled frames call VLM**: Huge efficiency gain
3. **Temporal coherence**: Nearby frames share the same task
4. **Episode-aware**: Always samples episode starts
5. **Configurable**: Adjust `--sample-interval` based on your needs
## Choosing Your Sampling Interval
| Use Case | Recommended Interval | Why |
|----------|---------------------|-----|
| Quick testing | 2.0s | Fastest iteration |
| Standard training | 1.0s | Good balance |
| High-quality dataset | 0.5s | Better coverage |
| Fine-grained control | 0.33s | Very detailed |
| Dense annotations | 0.1s | Nearly every frame |
**Rule of thumb**: Match your sampling interval to your typical skill duration.
If skills last 1-3 seconds, sampling every 1 second captures each skill multiple times.
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-143
View File
@@ -1,143 +0,0 @@
# Example: Synthetic Data Generation with Sampling
## Quick Start
### 1. Test with 100 frames and 1 second sampling
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--num-samples 100 \
--sample-interval 1.0 \
--output-dir ./outputs/test_pgen
```
**Expected behavior** (assuming 30 fps):
- Total frames: 100
- Frames sampled: ~4 (every 30 frames = 1 second)
- Efficiency: 96% fewer VLM calls
- Output: All 100 frames get `task_index_high_level`, but only 4 unique dialogues generated
### 2. Process full dataset with different sampling rates
#### Conservative (every 2 seconds)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 2.0 \
--output-dir ./outputs/pgen_2s
```
#### Standard (every 1 second) - **RECOMMENDED**
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 1.0 \
--output-dir ./outputs/pgen_1s
```
#### Fine-grained (every 0.5 seconds)
```bash
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--sample-interval 0.5 \
--output-dir ./outputs/pgen_0.5s
```
## Performance Estimates
For a dataset with:
- 100 episodes
- 10 seconds per episode (average)
- 30 fps
- Total frames: 30,000
| Sampling Interval | Frames Sampled | % Sampled | Speedup | Time Estimate |
|-------------------|----------------|-----------|---------|---------------|
| Every frame (0.033s) | 30,000 | 100% | 1x | ~10 hours |
| 0.5 seconds | 2,000 | 6.7% | 15x | ~40 min |
| **1.0 seconds** | **1,000** | **3.3%** | **30x** | **~20 min** |
| 2.0 seconds | 500 | 1.7% | 60x | ~10 min |
*Note: Times are approximate and depend on GPU, model size, and generation speed*
## Understanding the Output
### Console Output Example
```
[cyan]Generating synthetic data for 30000 frames...[/cyan]
[cyan]Sampling interval: 1.0s (fps: 30)[/cyan]
Generating synthetic dialogue: 100%|████████| 30000/30000 [20:15<00:00, 24.68it/s]
[green]✓ Sampled 1000 frames out of 30000 (3.3%)[/green]
[green]✓ Generated 450 unique high-level tasks[/green]
```
### What happens:
1. **Frame 0 (t=0.0s)**: Generate dialogue → Task index 0
2. **Frames 1-29 (t=0.033s-0.967s)**: Reuse task index 0
3. **Frame 30 (t=1.0s)**: Generate new dialogue → Task index 1
4. **Frames 31-59 (t=1.033s-1.967s)**: Reuse task index 1
5. And so on...
### Result:
- Every frame has a `task_index_high_level`
- Only sampled frames have unique dialogues generated
- Intermediate frames inherit from the most recent sample
- Maintains temporal coherence within episodes
## Checking Your Results
After running, verify the output:
```bash
# Check the generated tasks
python -c "
import pandas as pd
from pathlib import Path
tasks = pd.read_parquet('outputs/test_pgen/meta/tasks_high_level.parquet')
print(f'Total unique tasks: {len(tasks)}')
print(f'Sample tasks:')
print(tasks[['user_prompt', 'robot_utterance', 'skill']].head())
"
# Check debug output
head outputs/test_pgen/meta/syn_annotations.jsonl
# Load and verify dataset
python -c "
from lerobot.datasets.lerobot_dataset import LeRobotDataset
ds = LeRobotDataset(repo_id='local_with_high_level_tasks',
root='outputs/test_pgen')
print(f'Dataset has {len(ds)} frames')
print(f'Features: {list(ds.features.keys())}')
assert 'task_index_high_level' in ds.features
print('✓ task_index_high_level feature added successfully!')
"
```
## Common Use Cases
### Development/Testing
```bash
--sample-interval 2.0 # Fast iteration
--num-samples 500 # Small subset
```
### Production Training
```bash
--sample-interval 1.0 # Good coverage
# Process all samples (no --num-samples)
```
### High-Quality Dataset
```bash
--sample-interval 0.5 # Fine-grained
--temperature 0.6 # More consistent
--model Qwen/Qwen3-VL-30B-A3B-Instruct # Larger model
```
-17
View File
@@ -1,17 +0,0 @@
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
breakpoint()
prefix_output = model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
# prefix_output to be used for the language head
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
prefix_output = prefix_output.last_hidden_state
-58
View File
@@ -1,58 +0,0 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# import make_pre_post_processors
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
)
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
# rename map --rename_map='{
# "observation.images.side": "observation.images.base_0_rgb",
# "observation.images.up": "observation.images.left_wrist_0_rgb"
# }'
rename_map = {
"observation.images.side": "observation.images.base_0_rgb",
"observation.images.up": "observation.images.left_wrist_0_rgb"
}
policy = make_policy(
cfg=cfg,
ds_meta=dataset.meta,
rename_map=rename_map,
)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
# Test training forward pass
policy.train()
loss, loss_dict = policy.forward(batch)
print(f"Training loss: {loss_dict}")
# Test inference
policy.eval()
with torch.no_grad():
actions = policy.predict_action_chunk(batch)
print(f"Predicted actions shape: {actions.shape}")
-23
View File
@@ -1,23 +0,0 @@
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=32,
shuffle=True,
)
batch = next(iter(dataloader))
print(batch.keys())
print(batch['task_index_high_level'].shape)
print(batch['task_index_high_level'])
print(batch['user_prompt'][0])
print(batch['robot_utterance'][0])
print(batch['task'][0])
breakpoint()
-334
View File
@@ -1,334 +0,0 @@
Generate annotate_pgen.py using Qwen for synthetic data generation
You are writing a Python script called annotate_pgen.py.
This script generates synthetic user prompts (_t) and robot utterances (u_t) for Hi Robotstyle hierarchical policy training, using Qwen 3vl as the generator model (pgen).
SCRIPT PURPOSE
The script must:
Load Dlabeled which is a LeRobot Dataset that has been annotate using the annotate.py script, which contains:
images: list of image paths at time t
skill_current: the annotated skill label (ℓ̂_t)
skill_history: list of previous skill labels (ℓ̂₀ … ℓ̂_{t1}), those where annotated, and you can find details on them stored in teh dataset inside the the DATA_PATH/meta/skills.json
you will find something like
{
"coarse_description": "pink lego brick into the transparent box",
"skill_to_task_index": {
"robot arm picks up pink lego brick": 19,
"robot arm approaches transparent box": 3,
"robot arm retracts from transparent box": 28,
"robot arm moves towards pink lego brick": 12,
"robot arm releases red lego brick into box": 26,
"robot arm releases red lego brick into transparent box": 27,
"robot arm closes gripper to pick up the pink lego brick": 5,
"robot arm lifts the pink lego brick": 7,
etc..
},
"episodes": {
"0": {
"episode_index": 0,
"description": "pink lego brick into the transparent box",
"skills": [
{
"name": "robot arm moves towards pink lego brick",
"start": 0.0,
"end": 1.8
},
{
"name": "robot arm picks up pink lego brick",
"start": 1.8,
"end": 3.1
},
{
"name": "robot arm moves towards transparent box",
"start": 3.1,
"end": 5.5
},
{
"name": "robot arm releases pink lego brick into transparent box",
"start": 5.5,
"end": 7.0
},
{
"name": "robot arm retracts from transparent box",
"start": 7.0,
"end": 10.1
}
]
},
"1": {
"episode_index": 1,
"description": "pink lego brick into the transparent box",
"skills": [
{
"name": "robot arm moves towards red lego brick",
"start": 0.0,
"end": 1.2
},
{
"name": "robot arm picks up red lego brick",
"start": 1.2,
"end": 2.0
},
{
"name": "robot arm moves towards transparent box",
"start": 2.0,
"end": 3.8
},
{
"name": "robot arm places red lego brick into transparent box",
"start": 3.8,
"end": 5.0
},
{
"name": "robot arm moves away from transparent box",
"start": 5.0,
"end": 8.9
}
]
},
notice how task_description: is a high-level description (e.g., "make a sandwich") stored in description for each episode
For each sample, call Qwen VLM to generate:
synthetic user prompt _t
synthetic robot response u_t
Save results to D_syn in Parquet format insdie DATA_PATH/meta/tasks.parquet ; note tasks.parquet already contains the other tasks, so you need to update
Should be modular, clean, easy to extend, with:
a PGEN_PROMPT_TEMPLATE
a construct_prompt() method
a call_qwen() method
a annotate_sample() method
a CLI entrypoint (if __name__ == "__main__":)
📦 INPUT FORMAT (Dlabeled)
The script should expect Dlabeled as a .jsonl file where each line has:
{
"episode_id": "ep_001",
"t": 37,
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
"skill_current": "pick up the KitKat",
"skill_history": ["open fridge", "pick up lettuce", "place lettuce"],
"task_description": "making a sandwich"
}
📤 OUTPUT FORMAT (D_syn)
Each line of synthetically generated data should be:
{
"episode_id": "ep_001",
"t": 37,
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
"skill_current": "pick up the KitKat",
"skill_history": [...],
"user_prompt": "Can you grab me something sweet?",
"robot_utterance": "Sure, I can pick up the KitKat.",
"task_description": "making a sandwich"
}
Store as syn_annotations.jsonl. for debugging
🧠 pgen MODEL (Qwen) REQUIREMENTS
Use HuggingFace Transformers:
Qwen/Qwen2-VL-7B-Instruct (or any Qwen2-VL Vision-Language model available)
Use the image + text chat interface
Vision inputs should be loaded with PIL
Use a single forward pass that outputs BOTH _t and u_t in a structured JSON
📝 PROMPT FORMAT FOR pgen
Create a template like:
You are a robot-assistant dialogue generator for hierarchical robot policies.
You will receive:
- A list of images showing the current robot scene.
- The high-level task: {task_description}
- Previous skill steps completed: {skill_history}
- The next skill to be performed by the robot: {skill_current}
Generate two things in JSON:
1. "user_prompt": a natural-sounding user request that logically leads to the robot performing the skill "{skill_current}" given the task and history.
2. "robot_utterance": a natural robot reply acknowledging or clarifying the request.
The responses must be grounded in the visual scene, the task, and the skill history.
Respond ONLY in JSON:
{
"user_prompt": "...",
"robot_utterance": "..."
}
This resposne will have a corresponsing task_index, and the task will be saved in task.parqeut and you must update each dataset parquet in for example /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace/data/chunk-000/
file-000.parquet to include this new feature called task_index_high_level consider udpatign the metadata in info.json as well
📌 LOGIC REQUIRED
construct_prompt(sample)
Loads sample dict
Inserts:
task_description
skill_history
skill_current
Returns a full text prompt string
call_qwen(images, prompt)
Loads images into Qwen-VL multimodal input format
Calls model.generate
Parses JSON output
annotate_sample(sample)
Builds prompt
Calls Qwen
Returns augmented sample with user_prompt + robot_utterance
🚀 CLI Usage
The script should run as:
python annotate_pgen.py \
--output-dir PATH \
--model Qwen/Qwen2-VL-7B-Instruct \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--batch-size 1
Include arguments via argparse.
🔧 OTHER REQUIREMENTS
Use tqdm for progress bars
Log errors gracefully and continue
Support GPU acceleration (device="cuda")
Cache model loading so it's not reloaded every call
Make the prompt deterministic but allow temperature parameter
Add a flag --num-image-views-per-sample
Add automatic JSON parsing with helpful error messages
🎯 FINAL DELIVERABLE
Cursor must now generate:
A full Python file named annotate_pgen.py implementing the above functionality end-to-end.
It should be production-ready, runnable on real data, cleanly structured, and easy to modify.
from the paper:
Next, we use a large vision-language model (VLM) pgen
to produce synthetic user prompts and interjections t,
and corresponding robot utterance ut. Given Dlabeled, we
prompt pgen with both the visual context I1
t ,...,In
t and the
skill labelˆ
t (e.g., pick up the lettuce). pgen then imag-
ines an appropriate interaction that might have led toˆ
t in a
real user interaction: it generates possible user prompts t
(e.g., “Can you add some lettuce for me?”) along with the
robots verbal responses and clarifications ut. We detail the
A. Synthetic Data Generation
A.1. Scenario and Response Categorization
To ensure the quality and diversity of the synthetic data,
we incorporate structured scenario classification and re-
sponse categorization into the prompt design for pgen, fol-
lowing (Stephan et al., 2024). Specifically, we classify
interactions into different scenario types, such as nega-
tive task (where the user instructs the robot what not to
do), situated correction (where the user adjusts an earlier
command based on the evolving task state), and specific
constraint (where the user specifies particular constraints,
such as dietary preferences). In addition, we categorize
the robots responses into types such as simple confirma-
tions, clarifications, and error handling. These classifica-
tions guide the generation process to ensure a broad range
of user-robot interactions.
A.2. Prompt Construction for Contextual Grounding
In prompt P, we include a detailed description of the task
(e.g., bussing a table, making a sandwich, grocery shop-
ping) and instruct the model to ground responses in visual
observations and prior context. A key advantage of lever-
aging large pretrained VLMs is their ability to incorporate
world knowledge when generating interactions. For in-
stance, the model can infer dietary constraints when gener-
ating prompts for sandwich-making, producing user com-
mands such as “Can you make a sandwich for me? Im
lactose intolerant” and an appropriate robot response like
“Sure, I wont put cheese on it.” Similarly, it can reason
over ambiguous or implicit requests, such as inferring that
“I want something sweet” in a grocery shopping scenario
should lead to suggestions like chocolate or candy.
To maintain consistency in multi-step tasks, we condition
pgen on prior skill labels within an episodeˆ
ˆ
0,...,
t1,
allowing it to generate coherent user commands that
account for past actions. For instance, if the robot
has already placed lettuce and tomato on a sandwich,
the generated user prompt might request additional in-
gredients that logically follow. This ensures that the
synthetic interactions reflect realistic task progression
rather than isolated commands. As such, we leverage
ˆ
ˆ
ˆ
pgen(t,ut|I1
t ,...,In
t ,
0,...,
t1,
t,P) to produce a richer,
more diverse synthetic dataset Dsyn that provides mean-
ingful supervision for training our high-level policy.
While in this work we generate a separate Dsyn and train
a separate high-level policy for each task (e.g., sandwich
making vs. table cleaning) for clarity and ease of bench-
marking, the architecture is readily amenable to a unified
multi-task formulation. In principle, the same hierarchical
approach could be used to train a single high-level policy
across a multitude of tasks, facilitating knowledge transfer
The result should be a new LeRobotDataset with a new feature called task_index_high_level inside each dataset parquet
-10
View File
@@ -1,10 +0,0 @@
# python examples/dataset/annotate.py \
# --repo-id lerobot/svla_so101_pickplace \
# --video-key observation.images.side \
# --model Qwen/Qwen3-VL-30B-A3B-Instruct \
python examples/dataset/annotate.py \
--repo-id lerobot/svla_so101_pickplace \
--video-key observation.images.side \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--episodes 3 5 7 44
-42
View File
@@ -1,42 +0,0 @@
#!/bin/bash
# Example script to run synthetic data generation with Qwen VLM
# This generates user prompts and robot utterances for hierarchical policy training
# Configuration
REPO_ID="lerobot/svla_so101_pickplace"
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0 # Generate dialogue every 1 second (all episodes processed)
# Run synthetic data generation (processes ALL episodes)
python examples/dataset/annotate_pgen.py \
--repo-id "$REPO_ID" \
--model "$MODEL" \
--output-dir "$OUTPUT_DIR" \
--temperature "$TEMPERATURE" \
--batch-size "$BATCH_SIZE" \
--sample-interval "$SAMPLE_INTERVAL" \
--num-image-views-per-sample 1
# For faster testing, increase sample interval:
# --sample-interval 5.0 # Samples every 5 seconds (much faster)
# To push to hub after generation:
# Add --push-to-hub flag
# Efficient batch processing: 4 episodes at once
# python examples/dataset/annotate_pgen.py \
# --repo-id "$REPO_ID" \
# --model "$MODEL" \
# --output-dir "$OUTPUT_DIR" \
# --video-mode \
# --video-key observation.images.up \
# --video-batch-size "$BATCH_SIZE" \
# --sample-interval 1.0
-44
View File
@@ -1,44 +0,0 @@
#!/bin/bash
# Quick test to verify the fix for task_indices length mismatch
# This should now work correctly even with --num-samples < full dataset length
echo "Testing annotate_pgen.py with --num-samples=100 on full dataset..."
python examples/dataset/annotate_pgen.py \
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--num-samples 100 \
--sample-interval 1.0 \
--output-dir /fsx/jade_choghari/outputs/pgen_test_fixed
if [ $? -eq 0 ]; then
echo "✓ SUCCESS: Script completed without errors!"
echo ""
echo "Verifying output..."
# Check that all frames have task_index_high_level
python -c "
from lerobot.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
ds = LeRobotDataset(repo_id='local_test', root='/fsx/jade_choghari/outputs/pgen_test_fixed')
print(f'Dataset has {len(ds)} frames')
print(f'Features: {list(ds.features.keys())}')
# Check that task_index_high_level exists
assert 'task_index_high_level' in ds.features, 'task_index_high_level not in features!'
# Sample some frames
for idx in [0, 50, 99, 100, 500, 1000, 11938]:
if idx < len(ds):
frame = ds[idx]
task_idx = frame['task_index_high_level'].item()
print(f'Frame {idx}: task_index_high_level = {task_idx}')
print('✓ All checks passed!')
"
else
echo "✗ FAILED: Script exited with error code $?"
fi
Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.9 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 185 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 464 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 219 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 199 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 774 KiB

Before

Width:  |  Height:  |  Size: 160 KiB

After

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 481 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 151 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 130 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 407 KiB

+92 -6
View File
@@ -96,7 +96,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
# Motors
@@ -120,6 +120,13 @@ intelrealsense = [
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
wallx = [
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
@@ -133,6 +140,7 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -140,7 +148,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
@@ -159,7 +167,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -173,6 +182,7 @@ all = [
"lerobot[phone]",
"lerobot[libero]",
"lerobot[metaworld]",
"lerobot[sarm]"
]
[project.scripts]
@@ -227,6 +237,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
[tool.ruff.lint.isort]
combine-as-imports = true
@@ -263,6 +274,7 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"ROBOTIS",
]
# TODO: Uncomment when ready to use
@@ -317,9 +329,9 @@ disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.optim.*"
ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.model.*"
@@ -369,3 +381,77 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]
+1 -1
View File
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
bi_so100_follower,
koch_follower,
make_robot_from_config,
omx_follower,
so100_follower,
so101_follower,
)
+18 -1
View File
@@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin):
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
tolerance_s: float = 1e-4
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
@@ -64,9 +65,17 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
@@ -130,6 +139,14 @@ class TrainPipelineConfig(HubMixin):
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
else:
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -59,6 +59,7 @@ python examples/dataset_annotation/subtask_annotation.py \
import argparse
import json
import multiprocessing as mp
import random
import re
import subprocess
import tempfile
@@ -66,21 +67,100 @@ import textwrap
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import pandas as pd
import torch
from qwen_vl_utils import process_vision_info
from rich.console import Console
from pydantic import BaseModel, Field
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.sarm_utils import (
Subtask,
SubtaskAnnotation,
Timestamp,
compute_temporal_proportions,
)
# Pydantic Models for SARM Subtask Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(
annotations: dict[int, Any], fps: int = 30, subtask_order: list[str] | None = None
) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
Implements SARM Paper Formula (1): ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
Args:
annotations: Dict mapping episode index to SubtaskAnnotation object.
fps: Frames per second (unused, kept for API compatibility)
subtask_order: Optional list defining the output order of subtasks.
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k), ordered by subtask_order if provided.
"""
subtask_proportions: dict[str, list[float]] = {}
for annotation in annotations.values():
total_duration = 0
durations: dict[str, int] = {}
for subtask in annotation.subtasks:
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
start_seconds = (
int(start_parts[0]) * 60 + int(start_parts[1])
if len(start_parts) == 2
else int(start_parts[0])
)
end_seconds = (
int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
)
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
avg_proportions = {name: sum(props) / len(props) for name, props in subtask_proportions.items()}
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
# Reorder according to subtask_order if provided
if subtask_order:
avg_proportions = {
name: avg_proportions.get(name, 0.0) for name in subtask_order if name in avg_proportions
}
return avg_proportions
def create_sarm_prompt(subtask_list: list[str]) -> str:
@@ -177,8 +257,8 @@ class VideoAnnotator:
model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
model: "Qwen3VLMoeForConditionalGeneration | None" = None,
processor: "AutoProcessor | None" = None,
model: Qwen3VLMoeForConditionalGeneration | None = None, # noqa: F821
processor: AutoProcessor | None = None, # noqa: F821
):
"""
Initialize the video annotator with local model.
@@ -193,16 +273,17 @@ class VideoAnnotator:
"""
self.subtask_list = subtask_list
self.prompt = create_sarm_prompt(subtask_list)
self.console = Console()
self.device = device
# Use provided model/processor or load new ones
if model is not None and processor is not None:
self.model = model
self.processor = processor
self.console.print(f"[green]✓ Using shared model on {device}[/green]")
print(f"Using shared model on {device}")
else:
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
print(f"Loading model: {model_name}...")
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
@@ -210,7 +291,7 @@ class VideoAnnotator:
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
print(f"Model loaded successfully on {device}")
def extract_episode_segment(
self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1
@@ -229,25 +310,22 @@ class VideoAnnotator:
Path to extracted video file
"""
# Create temporary file for extracted video
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
tmp_path = Path(tmp_file.name)
tmp_file.close()
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
tmp_path = Path(tmp_file.name)
try:
# Check if ffmpeg is available
subprocess.run(
subprocess.run( # nosec B607
["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
)
except (subprocess.CalledProcessError, FileNotFoundError):
raise RuntimeError("ffmpeg not found, cannot extract episode segment") from e
except (subprocess.CalledProcessError, FileNotFoundError) as err:
raise RuntimeError("ffmpeg not found, cannot extract episode segment") from err
try:
# Calculate duration
duration = end_timestamp - start_timestamp
self.console.print(
f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
)
print(f"Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)")
# Use ffmpeg to extract segment with minimal quality loss
cmd = [
@@ -275,7 +353,7 @@ class VideoAnnotator:
# Verify the output file was created and is not empty
if not tmp_path.exists() or tmp_path.stat().st_size == 0:
self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
print("Video extraction failed (0 bytes) - skipping episode")
if tmp_path.exists():
tmp_path.unlink()
raise RuntimeError("FFmpeg produced empty video file")
@@ -285,13 +363,11 @@ class VideoAnnotator:
# Fail if file is too small (< 100KB likely means extraction failed)
if file_size_mb < 0.1:
self.console.print(
f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]"
)
print(f"Extracted video too small ({file_size_mb:.2f}MB) - skipping episode")
tmp_path.unlink()
raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")
print(f"Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)")
return tmp_path
@@ -307,6 +383,8 @@ class VideoAnnotator:
max_retries: int = 3,
) -> SubtaskAnnotation:
"""Annotate a video segment using local GPU."""
from qwen_vl_utils import process_vision_info
file_path = Path(file_path)
if end_timestamp is None:
@@ -355,7 +433,7 @@ class VideoAnnotator:
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)[0].strip()
@@ -371,7 +449,7 @@ class VideoAnnotator:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
return SubtaskAnnotation.model_validate(json.loads(match.group()))
raise ValueError("No JSON found")
raise ValueError("No JSON found") from None
except Exception as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Failed after {max_retries} attempts") from e
@@ -381,16 +459,12 @@ class VideoAnnotator:
extracted_path.unlink()
def display_annotation(
annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int, prefix: str = ""
):
def display_annotation(annotation: SubtaskAnnotation, episode_idx: int, fps: int, prefix: str = ""):
"""Display annotation summary."""
subtask_summary = ", ".join(
f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks
)
console.print(
f"[green]Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}[/green]"
)
print(f"Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}")
def timestamp_to_seconds(timestamp: str) -> float:
@@ -402,6 +476,272 @@ def timestamp_to_seconds(timestamp: str) -> float:
return int(parts[0])
def extract_frame(video_path: Path, timestamp: float) -> np.ndarray | None:
"""Extract a single frame from video at given timestamp."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
ret, frame = cap.read()
cap.release()
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None
def draw_timeline(ax, subtasks, total_duration, colors):
"""Draw a timeline with color-coded subtask segments."""
import matplotlib.patches as mpatches
bar_height, bar_y = 0.6, 0.5
for i, subtask in enumerate(subtasks):
start = timestamp_to_seconds(subtask.timestamps.start)
end = timestamp_to_seconds(subtask.timestamps.end)
color = colors[i % len(colors)]
rect = mpatches.FancyBboxPatch(
(start, bar_y - bar_height / 2),
end - start,
bar_height,
boxstyle="round,pad=0.02,rounding_size=0.1",
facecolor=color,
edgecolor="white",
linewidth=1.5,
alpha=0.85,
)
ax.add_patch(rect)
# Add label if segment is wide enough
duration = end - start
if duration > total_duration * 0.06:
ax.text(
(start + end) / 2,
bar_y,
subtask.name,
ha="center",
va="center",
fontsize=8,
fontweight="bold",
color="white",
rotation=0 if duration > total_duration * 0.12 else 45,
)
if i > 0:
ax.axvline(x=start, ymin=0.1, ymax=0.9, color="white", linestyle="--", linewidth=1.5, alpha=0.7)
ax.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2, alpha=0.9)
if subtasks:
ax.axvline(
x=timestamp_to_seconds(subtasks[-1].timestamps.end),
ymin=0.1,
ymax=0.9,
color="white",
linestyle="--",
linewidth=1.5,
alpha=0.7,
)
ax.set_xlim(-total_duration * 0.02, total_duration * 1.02)
ax.set_ylim(-0.1, 1.1)
ax.set_xlabel("Time (seconds)", fontsize=10, color="white", labelpad=5)
for spine in ["top", "right", "left"]:
ax.spines[spine].set_visible(False)
ax.spines["bottom"].set_color("#444444")
ax.tick_params(axis="x", colors="#888888", labelsize=8)
ax.tick_params(axis="y", left=False, labelleft=False)
def visualize_episode(
ep_idx: int,
annotation: SubtaskAnnotation,
video_path: Path,
video_start: float,
video_end: float,
output_path: Path,
video_key: str,
ann_type: str,
):
"""Create visualization for a single episode with frames and timeline."""
import matplotlib.pyplot as plt
if annotation is None:
print(f"No {ann_type} annotation for episode {ep_idx}")
return
subtasks = annotation.subtasks
if not subtasks:
print(f"No subtasks for episode {ep_idx}")
return
colors = plt.cm.tab10(np.linspace(0, 1, max(len(subtasks), 10)))
total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
# Extract middle frame from each subtask
sample_frames, frame_times = [], []
for subtask in subtasks:
start = timestamp_to_seconds(subtask.timestamps.start)
end = timestamp_to_seconds(subtask.timestamps.end)
mid = (start + end) / 2
frame_times.append(mid)
sample_frames.append(extract_frame(video_path, video_start + mid))
# Create figure
fig_width = max(16, len(subtasks) * 2.5)
fig = plt.figure(figsize=(fig_width, 10))
fig.patch.set_facecolor("#1a1a2e")
gs = fig.add_gridspec(
2,
max(len(subtasks), 1),
height_ratios=[2, 1],
hspace=0.3,
wspace=0.1,
left=0.05,
right=0.95,
top=0.88,
bottom=0.1,
)
fig.suptitle(
f"Episode {ep_idx} - {ann_type.capitalize()} Annotations",
fontsize=18,
fontweight="bold",
color="white",
y=0.96,
)
fig.text(
0.5,
0.91,
f"Camera: {video_key} | Duration: {video_end - video_start:.1f}s | {len(subtasks)} subtasks",
ha="center",
fontsize=11,
color="#888888",
)
# Plot frames
for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks, strict=True)):
ax = fig.add_subplot(gs[0, i])
ax.set_facecolor("#16213e")
if frame is not None:
ax.imshow(frame)
else:
ax.text(
0.5, 0.5, "N/A", ha="center", va="center", fontsize=12, color="white", transform=ax.transAxes
)
ax.set_title(subtask.name, fontsize=10, fontweight="bold", color=colors[i % len(colors)], pad=8)
ax.axis("off")
ax.text(
0.5,
-0.08,
f"t={frame_times[i]:.1f}s",
ha="center",
fontsize=9,
color="#888888",
transform=ax.transAxes,
)
# Plot timeline
ax_timeline = fig.add_subplot(gs[1, :])
ax_timeline.set_facecolor("#16213e")
draw_timeline(ax_timeline, subtasks, total_duration, colors)
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight")
plt.close()
print(f"Saved: {output_path}")
def visualize_annotations(
dataset: LeRobotDataset,
sparse_annotations: dict[int, SubtaskAnnotation],
dense_annotations: dict[int, SubtaskAnnotation] | None,
video_key: str,
output_dir: Path,
num_episodes: int = 5,
annotation_type: str = "sparse",
episode_indices: list[int] | None = None,
):
"""
Visualize subtask annotations for a set of episodes.
Args:
dataset: LeRobotDataset instance
sparse_annotations: Dict mapping episode index to sparse annotations
dense_annotations: Dict mapping episode index to dense annotations (or None)
video_key: Camera/video key to use
output_dir: Directory to save visualization images
num_episodes: Number of episodes to visualize (ignored if episode_indices provided)
annotation_type: "sparse", "dense", or "both"
episode_indices: Specific episode indices to visualize (optional)
"""
# Determine available episodes based on annotation type
if annotation_type == "sparse":
available = set(sparse_annotations.keys())
elif annotation_type == "dense":
available = set(dense_annotations.keys()) if dense_annotations else set()
else: # both
sparse_set = set(sparse_annotations.keys())
dense_set = set(dense_annotations.keys()) if dense_annotations else set()
available = sparse_set | dense_set
if not available:
print("Error: No annotations found to visualize.")
return
# Select episodes to visualize
if episode_indices:
episodes = sorted([e for e in episode_indices if e in available])
missing = set(episode_indices) - available
if missing:
print(f"Episodes not found in annotations: {sorted(missing)}")
else:
episodes = sorted(random.sample(list(available), min(num_episodes, len(available))))
print(f"Visualizing {len(episodes)} episodes: {episodes}")
output_dir.mkdir(parents=True, exist_ok=True)
# Generate visualizations
for i, ep_idx in enumerate(episodes, 1):
print(f"Processing episode {ep_idx} ({i}/{len(episodes)})")
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
print(f"Video not found: {video_path}")
continue
video_start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
video_end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
if annotation_type == "both":
# Visualize both sparse and dense
for ann_type, annotations in [("sparse", sparse_annotations), ("dense", dense_annotations)]:
if annotations and ep_idx in annotations:
output_path = output_dir / f"episode_{ep_idx:04d}_{ann_type}.png"
visualize_episode(
ep_idx,
annotations.get(ep_idx),
video_path,
video_start,
video_end,
output_path,
video_key,
ann_type,
)
else:
annotations = sparse_annotations if annotation_type == "sparse" else dense_annotations
if annotations and ep_idx in annotations:
output_path = output_dir / f"episode_{ep_idx:04d}_{annotation_type}.png"
visualize_episode(
ep_idx,
annotations.get(ep_idx),
video_path,
video_start,
video_end,
output_path,
video_key,
annotation_type,
)
print(f"Visualizations saved to: {output_dir.absolute()}")
def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
@@ -533,7 +873,7 @@ def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") ->
end=f"{int(e) // 60:02d}:{int(e) % 60:02d}",
),
)
for n, s, e in zip(names, starts, ends)
for n, s, e in zip(names, starts, ends, strict=True)
]
)
return annotations
@@ -546,7 +886,6 @@ def process_single_episode(
video_key: str,
fps: int,
annotator: VideoAnnotator,
console: Console,
) -> tuple[int, SubtaskAnnotation | None, str | None]:
"""Process a single episode annotation."""
try:
@@ -574,7 +913,6 @@ def worker_process_episodes(
) -> tuple[dict, dict | None]:
"""Worker for parallel processing across GPUs."""
device = f"cuda:{gpu_id}"
console = Console()
dataset = LeRobotDataset(repo_id, download_videos=False)
sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype)
@@ -595,14 +933,14 @@ def worker_process_episodes(
for ep_idx in episode_indices:
_, sparse_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator, console
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator
)
if sparse_ann:
sparse_annotations[ep_idx] = sparse_ann
if dense_annotator:
_, dense_ann, _ = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator, console
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator
)
if dense_ann:
dense_annotations[ep_idx] = dense_ann
@@ -632,25 +970,36 @@ def main():
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU")
parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use")
# Visualization options
parser.add_argument(
"--visualize-only",
action="store_true",
help="Only visualize existing annotations (no generation)",
)
parser.add_argument(
"--num-visualizations",
type=int,
default=5,
help="Number of episodes to visualize (default: 5)",
)
parser.add_argument(
"--visualize-type",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="Type of annotations to visualize (default: sparse)",
)
parser.add_argument(
"--output-dir",
type=str,
default="./subtask_viz",
help="Output directory for visualizations (default: ./subtask_viz)",
)
args = parser.parse_args()
console = Console()
# Validate arguments
if args.dense_only and not args.dense_subtasks:
return console.print("[red]Error: --dense-only requires --dense-subtasks[/red]")
if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
return console.print("[red]Error: --dense-subtasks requires --sparse-subtasks or --dense-only[/red]")
sparse_subtask_list = (
[s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
)
dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
auto_sparse = sparse_subtask_list is None
dense_mode = dense_subtask_list is not None
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
console.print(f"[cyan]Loading dataset: {args.repo_id}[/cyan]")
# Load dataset first (needed for both annotation and visualization)
print(f"Loading dataset: {args.repo_id}")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
@@ -660,7 +1009,44 @@ def main():
video_key = (
args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0]
)
console.print(f"[cyan]Using camera: {video_key}, FPS: {fps}[/cyan]")
print(f"Using camera: {video_key}, FPS: {fps}")
# Handle visualization-only mode
if args.visualize_only:
print("Visualization-only mode")
sparse_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
dense_annotations = load_annotations_from_dataset(dataset.root, prefix="dense")
if not sparse_annotations and not dense_annotations:
return print("Error: No annotations found. Run annotation first.")
print(f"Found {len(sparse_annotations)} sparse, {len(dense_annotations)} dense annotations")
visualize_annotations(
dataset=dataset,
sparse_annotations=sparse_annotations,
dense_annotations=dense_annotations if dense_annotations else None,
video_key=video_key,
output_dir=Path(args.output_dir),
num_episodes=args.num_visualizations,
annotation_type=args.visualize_type,
episode_indices=args.episodes,
)
return
# Validate arguments for annotation mode
if args.dense_only and not args.dense_subtasks:
return print("Error: --dense-only requires --dense-subtasks")
if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
return print("Error: --dense-subtasks requires --sparse-subtasks or --dense-only")
sparse_subtask_list = (
[s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
)
dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
auto_sparse = sparse_subtask_list is None
dense_mode = dense_subtask_list is not None
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
# Determine episodes
episode_indices = args.episodes or list(range(dataset.meta.total_episodes))
@@ -670,8 +1056,8 @@ def main():
episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
if not episode_indices:
return console.print("[green]All episodes already annotated![/green]")
console.print(f"[cyan]Annotating {len(episode_indices)} episodes[/cyan]")
return print("All episodes already annotated!")
print(f"Annotating {len(episode_indices)} episodes")
# GPU setup
gpu_ids = args.gpu_ids or list(
@@ -686,7 +1072,7 @@ def main():
if auto_sparse:
sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key))
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
console.print(f"[green]Auto-generated {len(episode_indices)} sparse 'task' annotations[/green]")
print(f"Auto-generated {len(episode_indices)} sparse 'task' annotations")
# VLM annotation (for sparse if not auto, and for dense)
need_vlm = (not auto_sparse) or dense_mode
@@ -694,7 +1080,7 @@ def main():
if need_vlm:
if args.num_workers > 1 and not auto_sparse:
# Parallel processing
console.print(f"[cyan]Parallel processing with {args.num_workers} workers[/cyan]")
print(f"Parallel processing with {args.num_workers} workers")
episodes_per_worker = [[] for _ in range(args.num_workers)]
for i, ep_idx in enumerate(episode_indices):
episodes_per_worker[i % args.num_workers].append(ep_idx)
@@ -751,52 +1137,66 @@ def main():
)
for i, ep_idx in enumerate(episode_indices):
console.print(f"[cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/cyan]")
print(f"Episode {ep_idx} ({i + 1}/{len(episode_indices)})")
if sparse_annotator:
_, sparse_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator, console
ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator
)
if sparse_ann:
sparse_annotations[ep_idx] = sparse_ann
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
elif err:
console.print(f"[red]Sparse failed: {err}[/red]")
print(f"Sparse failed: {err}")
if dense_annotator:
_, dense_ann, err = process_single_episode(
ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator, console
ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator
)
if dense_ann:
dense_annotations[ep_idx] = dense_ann
save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
elif err:
console.print(f"[red]Dense failed: {err}[/red]")
print(f"Dense failed: {err}")
# Save temporal proportions
def save_proportions(annotations, prefix, is_auto=False):
props: dict[str, float] = {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps)
def save_proportions(annotations, prefix, subtask_list=None, is_auto=False):
props: dict[str, float] = (
{"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps, subtask_list)
)
path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json"
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(props, f, indent=2)
console.print(f"[green]Saved {prefix} temporal proportions[/green]")
print(f"Saved {prefix} temporal proportions")
save_proportions(sparse_annotations, "sparse", auto_sparse)
save_proportions(sparse_annotations, "sparse", sparse_subtask_list, auto_sparse)
if dense_mode and dense_annotations:
save_proportions(dense_annotations, "dense")
save_proportions(dense_annotations, "dense", dense_subtask_list)
console.print(
f"\n[bold green]Complete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations[/bold green]"
)
print(f"\nComplete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations")
# Visualize annotations after generation
if args.num_visualizations > 0:
print(f"\nGenerating {args.num_visualizations} visualizations...")
visualize_type = "both" if dense_mode else "sparse"
visualize_annotations(
dataset=dataset,
sparse_annotations=sparse_annotations,
dense_annotations=dense_annotations,
video_key=video_key,
output_dir=Path(args.output_dir),
num_episodes=args.num_visualizations,
annotation_type=visualize_type,
)
if args.push_to_hub:
try:
dataset.push_to_hub(push_videos=True)
console.print(f"[green]Pushed to {args.output_repo_id or args.repo_id}[/green]")
print(f"Pushed to {args.output_repo_id or args.repo_id}")
except Exception as e:
console.print(f"[red]Push failed: {e}[/red]")
print(f"Push failed: {e}")
if __name__ == "__main__":
main()
main()
+2
View File
@@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
tolerance_s=cfg.tolerance_s,
)
else:
dataset = StreamingLeRobotDataset(
@@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
-8
View File
@@ -58,7 +58,6 @@ from lerobot.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
load_tasks_high_level,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -162,7 +161,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.tasks_high_level = load_tasks_high_level(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -1052,12 +1050,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# Optionally add high level task index
if "task_index_high_level" in self.features:
high_level_task_idx = item["task_index_high_level"].item()
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
return item
def __repr__(self):
-4
View File
@@ -60,7 +60,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -353,9 +352,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
return tasks
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH)
return tasks
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
+2
View File
@@ -35,6 +35,8 @@ def make_optimizer_and_scheduler(
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
if cfg.optimizer is None:
raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig")
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler
+45 -18
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
@@ -29,6 +30,17 @@ from lerobot.utils.constants import (
)
from lerobot.utils.io_utils import deserialize_json_into_object
# Type alias for parameters accepted by optimizer build() methods.
# This matches PyTorch's optimizer signature while also supporting:
# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
OptimizerParams = (
Iterable[torch.nn.Parameter] # From model.parameters()
| Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides
| dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR
| dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups
)
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
@@ -45,13 +57,24 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
NOTE: Multiple optimizers are useful when you have different models to optimize.
For example, you can have one optimizer for the policy and another one for the value function
in reinforcement learning settings.
Args:
params: Parameters to optimize. Accepts multiple formats depending on the optimizer:
- Iterable[Parameter]: From model.parameters() - standard PyTorch usage
- Iterable[dict]: List of param groups with 'params' key and optional
'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
- dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
that apply differential learning rates by parameter name (e.g., XVLA)
- dict[str, Iterable]: For multi-optimizer configs where each key maps to
a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)
Returns:
The optimizer or a dictionary of optimizers.
"""
@@ -67,7 +90,7 @@ class AdamConfig(OptimizerConfig):
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.Adam(params, **kwargs)
@@ -82,7 +105,7 @@ class AdamWConfig(OptimizerConfig):
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)
@@ -98,7 +121,7 @@ class SGDConfig(OptimizerConfig):
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
def build(self, params: dict) -> torch.optim.Optimizer:
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
@@ -139,21 +162,19 @@ class XVLAAdamWConfig(OptimizerConfig):
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
def build(self, params: dict) -> torch.optim.Optimizer:
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
"""
Build AdamW optimizer with differential learning rates.
Expects `named_parameters()` as input (dict of name -> param).
Applies:
- lr * 0.1 for all VLM-related parameters
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
- full lr for all other parameters
Args:
params: Dictionary of parameter names to parameters (from named_parameters())
params: Must be a dict[str, Parameter] from dict(model.named_parameters())
or equivalent.
Returns:
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
Raises:
AssertionError: If params is not a dict (e.g., from model.parameters())
"""
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
@@ -174,7 +195,7 @@ class XVLAAdamWConfig(OptimizerConfig):
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
param_groups = [
param_groups: list[dict[str, Any]] = [
{
"params": vlm_group,
"lr": self.lr * 0.1,
@@ -224,19 +245,25 @@ class MultiAdamConfig(OptimizerConfig):
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
to iterables of parameters. The keys should match the keys in optimizer_groups.
Typically from policies that need separate optimizers (e.g., SAC with
actor/critic/temperature).
Returns:
Dictionary mapping parameter group names to their optimizers
Raises:
AssertionError: If params is not a dict
"""
assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."
optimizers = {}
for name, params in params_dict.items():
for name, group_params in params.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
@@ -248,7 +275,7 @@ class MultiAdamConfig(OptimizerConfig):
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs)
return optimizers
+1 -1
View File
@@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
num_warmup_steps: int
num_warmup_steps: int | None
@property
def type(self) -> str:
+3
View File
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
@@ -29,8 +30,10 @@ __all__ = [
"PI0Config",
"PI05Config",
"SmolVLAConfig",
"SARMConfig",
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
"WallXConfig",
]
+1
View File
@@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
**kwargs,
):
"""
Args:
@@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
**kwargs,
):
"""
Args:
+38 -2
View File
@@ -37,10 +37,12 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -61,7 +63,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -105,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
@@ -113,6 +119,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
elif name == "wall_x":
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -130,7 +140,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -161,6 +171,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -337,6 +349,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -344,6 +364,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
@@ -354,6 +375,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, WallXConfig):
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
processors = make_wall_x_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
@@ -435,6 +464,13 @@ def make_policy(
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, "stats"):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
+1 -1
View File
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
name = "groot"
config_class = GrootConfig
def __init__(self, config: GrootConfig):
def __init__(self, config: GrootConfig, **kwargs):
"""Initialize Groot policy wrapper."""
super().__init__(config)
config.validate_features()
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi0")
@dataclass
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
+37 -22
View File
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
@@ -93,10 +93,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
return dist.sample((bsize,)).to(device)
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
@@ -337,6 +338,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]
@@ -356,6 +358,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
@@ -519,11 +522,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, False],
precision=config.dtype,
image_size=config.image_resolution[0],
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -812,16 +821,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
@@ -846,15 +852,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -906,6 +908,7 @@ class PI0Policy(PreTrainedPolicy):
def __init__(
self,
config: PI0Config,
**kwargs,
):
"""
Args:
@@ -1234,9 +1237,15 @@ class PI0Policy(PreTrainedPolicy):
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training."""
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training.
Args:
batch: Training batch containing observations and actions.
reduction: How to reduce the loss. Options:
- "mean": Return scalar mean loss (default, backward compatible)
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
"""
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1250,11 +1259,17 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, :original_action_dim]
loss = losses.mean()
loss_dict = {
"loss": loss.item(),
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
}
return loss, loss_dict
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05")
@dataclass
@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
@@ -60,8 +65,8 @@ class PI05Config(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for action
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
+81 -289
View File
@@ -41,17 +41,13 @@ else:
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_LANGUAGE_PROMPT_TOKENS,
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
OBS_LANGUAGE_TARGET_TOKENS,
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -340,6 +336,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]
@@ -359,6 +356,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
@@ -433,8 +431,6 @@ class PaliGemmaWithExpertModel(
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
# prefix_output to be used for the language head
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
@@ -524,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True],
precision=config.dtype,
image_size=config.image_resolution[0],
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -584,13 +586,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
if dtype is not None:
result = result.to(dtype=dtype)
return result
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
def sample_noise(self, shape, device):
return torch.normal(
@@ -609,29 +608,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks=None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
"""Embed images with SigLIP, prompt tokens, and optionally target tokens with embedding layer.
Args:
images: List of image tensors
img_masks: List of image masks
prompt_tokens: Prompt tokens (input for generation)
target_tokens: Target tokens to predict (can be None for inference)
prompt_masks: Attention masks for prompt tokens
target_masks: Attention masks for target tokens
Returns:
embs: Concatenated embeddings [images, prompt_tokens, (target_tokens if provided)]
pad_masks: Padding masks
att_masks: Attention masks (with causal masking for target prediction if target_tokens provided)
total_T_images: Total number of image tokens
"""
self, images, img_masks, tokens, masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer."""
embs = []
pad_masks = []
att_masks = []
total_T_images = 0
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
@@ -643,48 +626,29 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
total_T_images += num_img_embs
# Process prompt tokens
def prompt_embed_func(prompt_tokens):
prompt_emb = self.paligemma_with_expert.embed_language_tokens(prompt_tokens)
prompt_emb_dim = prompt_emb.shape[-1]
return prompt_emb * math.sqrt(prompt_emb_dim)
att_masks += [0] * num_img_embs
prompt_emb = self._apply_checkpoint(prompt_embed_func, prompt_tokens)
embs.append(prompt_emb)
pad_masks.append(prompt_masks)
# Process language tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
num_prompt_embs = prompt_emb.shape[1]
att_masks += [0] * num_prompt_embs # Prompt tokens can attend to all previous tokens (images + prompt)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
pad_masks.append(masks)
# Process target tokens if provided (these are predicted, so use causal masking)
if target_tokens is not None:
def target_embed_func(target_tokens):
target_emb = self.paligemma_with_expert.embed_language_tokens(target_tokens)
target_emb_dim = target_emb.shape[-1]
return target_emb * math.sqrt(target_emb_dim)
target_emb = self._apply_checkpoint(target_embed_func, target_tokens)
embs.append(target_emb)
# Create target pad masks (non-zero tokens are valid)
pad_masks.append(target_masks)
num_target_embs = target_emb.shape[1]
# Causal masking for target tokens: each target token can attend to images, all prompt tokens,
# and previous target tokens
att_masks += [1] * num_target_embs # Use 1 for causal attention on target tokens
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, total_T_images
return embs, pad_masks, att_masks
def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -733,20 +697,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return embs, pad_masks, att_masks, adarms_cond
def forward(self, images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss.
Args:
images: List of image tensors
img_masks: List of image masks
prompt_tokens: Prompt tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
prompt_masks: Attention masks for prompt_tokens
target_tokens: Target tokens to predict (e.g., tokens for "pick up the cup")
target_masks: Attention masks for target_tokens
actions: Ground truth actions
noise: Optional noise for flow matching
time: Optional time for flow matching
"""
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss."""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
@@ -756,57 +708,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# Embed prefix (images + prompt_tokens + target_tokens)
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
# Prepare attention masks for prefix-only pass (for target token prediction)
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# prefix-only transformer run for target token prediction
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs, None], # SUFFIX = None
use_cache=False,
adarms_cond=[None, None],
)
# LM HEAD → TARGET LOGITS
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_prompt + T_target
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
# Extract logits for target token prediction (shifted by 1 for autoregressive training)
# Position i predicts token i+1, so we take logits from positions before target tokens:
# - Position (start_index-1) (last prompt token) predicts target_tokens[0]
# - Position (start_index) (first target token) predicts target_tokens[1], etc.
T_prompt = prompt_tokens.size(1)
T_target = target_tokens.size(1)
start_index = total_T_images + T_prompt
end_index = start_index + T_target
logits_target = logits[:, start_index-1:end_index-1, :] # (B, T_target, vocab)
# Compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
# Reshape for loss computation
logits_flat = logits_target.reshape(-1, logits_target.size(-1)) # (B*T_target, vocab)
targets_flat = target_tokens.reshape(-1) # (B*T_target)
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_target)
loss_per_token = loss_per_token.reshape(target_tokens.shape) # (B, T_target)
# Apply mask and compute mean loss over valid tokens
masked_loss = loss_per_token * target_masks.float()
target_loss = masked_loss.sum() / target_masks.sum().clamp(min=1)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
@@ -814,14 +719,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# Concatenate prefix (images + prompt_tokens + target_tokens) and suffix (actions) masks
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
# Prepare attention masks for full forward pass (prefix + suffix)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
@@ -832,7 +736,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=False,
adarms_cond=[None, adarms_cond],
)
# prefix_out to be used for the language head
return suffix_out
suffix_out = self._apply_checkpoint(
@@ -847,104 +750,25 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
fm_loss = F.mse_loss(u_t, v_t, reduction="none")
return {
"flow_loss": fm_loss,
"target_loss": target_loss,
"loss": 10 * fm_loss.mean() + target_loss,
}
@torch.no_grad()
def _generate_target_tokens(
self, images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_length, device
):
"""Generate target tokens autoregressively using next token prediction."""
bsize = prompt_tokens.shape[0]
# Get lm_head for token generation
lm_head = self.paligemma_with_expert.paligemma.lm_head
# Embed prefix without target tokens first
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, prompt_tokens, target_tokens=None, prompt_masks=prompt_masks, target_masks=None
)
# Initialize generated tokens list
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
for t in range(max_length):
# Prepare attention masks for current prefix
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# Forward pass through model to get logits
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
adarms_cond=[None, None],
)
# Get logits from the last position
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
next_token_logits = logits[:, -1, :] # (B, vocab)
# Greedy decoding - take the most likely token
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
# Store generated token
generated_tokens[:, t] = next_token
# Check for EOS token - if all batches have generated EOS, stop
if tokenizer.eos_token_id is not None:
if (next_token == tokenizer.eos_token_id).all():
break
# Embed the generated token and append to prefix
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
def next_token_embed_func(next_token_unsqueezed):
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
next_emb_dim = next_emb.shape[-1]
return next_emb * math.sqrt(next_emb_dim)
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
# Append to prefix embeddings
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
# Update masks - new token is valid and uses causal attention
prefix_pad_masks = torch.cat([
prefix_pad_masks,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
return generated_tokens
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self,
images,
img_masks,
prompt_tokens,
prompt_masks,
tokens,
masks,
noise=None,
num_steps=None,
tokenizer=None,
max_target_tokens=50,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
bsize = prompt_tokens.shape[0]
device = prompt_tokens.device
bsize = tokens.shape[0]
device = tokens.device
if noise is None:
# Sample noise with padded dimension as expected by action_in_proj
@@ -955,33 +779,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
# Generate target tokens autoregressively during inference (if tokenizer provided)
generated_target_tokens = None
target_masks = None
if tokenizer is not None:
generated_target_tokens = self._generate_target_tokens(
images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_target_tokens, device
)
# Decode and print the generated target tokens
for i in range(bsize):
# Remove padding tokens (0) and special tokens
valid_tokens = generated_target_tokens[i][generated_target_tokens[i] != 0]
decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True)
print(f"[Inference] Generated target {i}: {decoded_text}")
# Create mask for generated tokens (all valid where token != 0)
target_masks = generated_target_tokens != 0
# Embed prefix with prompt and optionally generated target tokens
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
images, img_masks, prompt_tokens, target_tokens=generated_target_tokens,
prompt_masks=prompt_masks, target_masks=target_masks
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -993,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
@@ -1026,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -1058,7 +853,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=suffix_embs.dtype)
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
outputs_embeds, _ = self.paligemma_with_expert.forward(
@@ -1085,6 +880,7 @@ class PI05Policy(PreTrainedPolicy):
def __init__(
self,
config: PI05Config,
**kwargs,
):
"""
Args:
@@ -1103,14 +899,6 @@ class PI05Policy(PreTrainedPolicy):
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
# Load tokenizer for subtask decoding
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
except Exception as e:
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
self.tokenizer = None
self.reset()
@@ -1411,16 +1199,10 @@ class PI05Policy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
# Use prompt tokens (WITHOUT target) for inference - we'll generate the target
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(
images, img_masks, prompt_tokens, prompt_masks,
tokenizer=self.tokenizer,
**kwargs
)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1428,29 +1210,39 @@ class PI05Policy(PreTrainedPolicy):
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training."""
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training.
Args:
batch: Training batch containing observations and actions.
reduction: How to reduce the loss. Options:
- "mean": Return scalar mean loss (default, backward compatible)
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
"""
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
target_tokens, target_masks = batch[f"{OBS_LANGUAGE_TARGET_TOKENS}"], batch[f"{OBS_LANGUAGE_TARGET_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
# Compute loss
# prompt_tokens = instruction tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
# target_tokens = target tokens to predict (e.g., "pick up the cup")
loss_dict = self.model.forward(images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Extract the total loss
loss = loss_dict["loss"]
# Prepare detailed loss dictionary for logging
detailed_loss_dict = {
"loss": loss.item(),
"flow_loss": loss_dict["flow_loss"].mean().item(),
"target_loss": loss_dict["target_loss"].item(),
actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, :original_action_dim]
loss_dict = {
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
}
return loss, detailed_loss_dict
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
+9 -30
View File
@@ -47,15 +47,13 @@ from lerobot.utils.constants import (
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
prompt_key: str = "prompt"
target_key: str = "target"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -66,8 +64,6 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("user_prompt")
# TODO: check if this necessary
state = deepcopy(state)
@@ -80,33 +76,16 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Clean high level tasks first (if available)
cleaned_high_level_tasks = []
if high_level_tasks is not None:
for high_level_task in high_level_tasks:
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
# Process tasks to create prompts (input) and targets (what to predict)
prompts = [] # Input prompts ending with "Subtask:"
targets = [] # Target text to predict (the subtask)
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
# Store the subtask text as target for prediction
targets.append(cleaned_text)
if cleaned_high_level_tasks:
cleaned_high_level_task = cleaned_high_level_tasks[i]
# Prompt ends with "Subtask:" - model will predict the target
prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
else:
raise ValueError("No high level tasks found")
prompts.append(prompt)
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.prompt_key] = prompts
transition[TransitionKey.COMPLEMENTARY_DATA][self.target_key] = targets
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
@@ -154,14 +133,14 @@ def make_pi05_pre_post_processors(
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateAndLanguageTokenizerProcessorStep
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
+49
View File
@@ -0,0 +1,49 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+21
View File
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi05 import PI05Config
from .modeling_pi05 import PI05Policy
from .processor_pi05 import make_pi05_pre_post_processors
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
@@ -0,0 +1,164 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
DEFAULT_IMAGE_SIZE = 224
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05Config(PreTrainedConfig):
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,995 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ONLY AN EXAMPLE FILE, NEVER USED, IT IS OLD CODE
"""
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
[Paper](https://huggingface.co/papers/2501.09747)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
lerobot-train \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
lerobot-train \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
```
"""
from collections import deque
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
PRECISION = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
config_class = PI0FASTConfig
name = "pi0fast"
def __init__(
self,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
def get_optim_params(self) -> dict:
return self.parameters()
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model.generate_actions(batch)
actions = actions[:, : self.config.n_action_steps]
original_action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
def block_causal_update_causal_mask(
attention_mask,
token_type_ids=None,
past_key_values=None,
cache_position=None,
input_tensor=None,
attn_implementation: str = "eager",
dtype: torch.dtype = "float32",
):
"""
Update the causal mask during training and generation. It can be customized to different attention masks.
"""
if attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache or isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
# Handle precomputed attention masks
if attention_mask is not None and attention_mask.dim() == 4:
return attention_mask
# Causal mask initialization
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Standard causal masking (triu ensures tokens can only attend to past)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
# Apply block causal mask
if token_type_ids is not None:
token_type_ids = token_type_ids.to(causal_mask.device).bool()
cumsum = torch.cumsum(token_type_ids, dim=1)
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
# Combine causal_mask with block-wise attention mask
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
causal_mask = causal_mask[:, None, :, :]
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
else:
# Apply past cache position constraint
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
mask_length = attention_mask.shape[-1]
# Apply padding mask
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def prepare_inputs_for_generation(
# self,
input_ids,
past_key_values=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
pixel_values=None,
attention_mask=None,
token_type_ids=None,
use_cache=True,
num_logits_to_keep=None,
labels=None,
self=None,
**kwargs,
):
# create block causal attention
if cache_position[0] > 0 and input_ids.shape[1] > 0:
input_tensor = input_ids[:, -1:]
new_positions = (
torch.ones(
(position_ids.shape[0], input_ids.shape[1]),
dtype=position_ids.dtype,
device=position_ids.device,
).cumsum(-1)
+ position_ids[:, -1:]
)
position_ids = torch.cat([position_ids, new_positions], dim=-1)
else:
input_tensor = inputs_embeds
attention_mask = block_causal_update_causal_mask(
attention_mask=attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
input_tensor=input_tensor,
token_type_ids=token_type_ids,
dtype=self.dtype,
attn_implementation=self.config.text_config._attn_implementation,
)
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
# Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
class PI0FAST(nn.Module):
def __init__(self, config: PI0FASTConfig):
super().__init__()
self.config = config
# TODO: move tokenizers in Policy
fast_tokenizer_path = "physical-intelligence/fast"
pi0_paligemma_path = "google/paligemma-3b-pt-224"
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self.fast_skip_tokens = self.config.fast_skip_tokens
self.max_input_seq_len = self.config.max_input_seq_len
self.action_horizon = self.config.chunk_size
self.action_dim = self.config.action_feature.shape[
0
] # self.config.max_action_dim # self.config.action_feature.shape[0]
precision = config.precision
torch_precision = PRECISION.get(precision, torch.float32)
self.pad_token_id = (
self.paligemma_tokenizer.pad_token_id
if hasattr(self.paligemma_tokenizer, "pad_token_id")
else self.paligemma_tokenizer.eos_token_id
)
paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": precision,
"vocab_size": 257152,
"_attn_implementation": "eager",
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_pytorch_tanh",
"torch_dtype": precision,
"vision_use_head": False,
},
)
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma
)
# change important stuff in bf16
params_to_change_dtype = [
"language_model",
"vision_tower",
"multi_modal",
]
for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
# TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
# AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.pi0_paligemma.vision_tower.eval()
for params in self.pi0_paligemma.vision_tower.parameters():
params.requires_grad = False
# To avoid unused params issue with distributed training
if self.config.freeze_lm_head:
for name, params in self.pi0_paligemma.named_parameters():
if "embed_tokens" in name: # lm heads and embedding layer are tied
params.requires_grad = False
def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
def prepare_images(self, batch):
"""Preprocess LeRobot batch into Pi0 inputs"""
images = []
img_masks = []
present_img_keys = [key for key in self.image_keys if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
num_empty_cameras = 0
for key in self.image_keys:
if key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(
img,
*self.config.resize_imgs_with_padding,
pad_value=0,
interpolate_like_pi=self.config.interpolate_like_pi,
)
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
else:
if num_empty_cameras >= self.config.empty_cameras:
continue
img = torch.ones_like(img) * -1
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
num_empty_cameras += 1
images.append(img)
img_masks.append(mask)
return images, img_masks
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
return out
def fast_tokenizer_wrapper(self, actions_norm):
"""
A wrapper for self.fast_tokenizer that ensures batch processing,
conversion to PyTorch tensors, and returns a dictionary without padding.
"""
batch_tokens = self.fast_tokenizer(actions_norm)
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
return fast_out
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len
token_type_ids = suffix_mask
return token_type_ids
def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0]
device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1
discretized = discretized[:, :32]
prefix_texts = []
state_text = []
for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc)
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n")
prefix_out = self.paligemma_tokenizer(
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
)
prefix_ids = prefix_out["input_ids"].to(device)
prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
if actions is not None:
actions_norm = self.normalize_actions(actions)
actions_pad = F.pad(
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
)[:, :, : self.config.max_action_dim]
fast_out = self.fast_tokenizer_wrapper(
actions_pad.cpu(),
)
act_ids = fast_out["input_ids"]
act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
# Replace action with 0 to pad tokens
act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id,
act_ids,
)
eos_token = torch.tensor(
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device)
else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
# Use tokenizer pad function
padded_output = self.paligemma_tokenizer.pad(
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
)
padded_mask = padded_output["attention_mask"]
# define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
padded_output["padded_mask"] = padded_output.pop("attention_mask")
padded_output["attention_mask"] = att_mask
# loss is computed not on prefix, and not on padding
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
padded_output["token_type_ids"] = token_type_ids
return padded_output
def shift_padding_side(
self,
tokens: torch.Tensor,
ar_mask: torch.Tensor,
padding_mask: torch.Tensor,
loss_mask: torch.Tensor,
targets: torch.Tensor,
token_type_ids: torch.Tensor,
padding_side: str = "right",
) -> tuple[torch.Tensor]:
if padding_side not in ["right", "left"]:
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
new_tokens = torch.empty_like(tokens)
new_ar_masks = torch.empty_like(ar_mask)
new_padding_mask = torch.empty_like(padding_mask)
new_loss_mask = torch.empty_like(loss_mask)
new_targets = torch.empty_like(targets)
new_token_type_ids = torch.empty_like(token_type_ids)
batch_size = tokens.shape[0]
for i in range(batch_size):
padding_indices = torch.where(padding_mask[i] == 0)[0]
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
if padding_side == "left":
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
else:
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
new_tokens[i] = tokens[i].index_select(0, new_indices)
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
new_targets[i] = targets[i].index_select(0, new_indices)
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
def forward(self, batch: dict[str, Tensor]):
device = batch[OBS_STATE].device
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_STATE],
lang_text=batch["task"],
actions=batch[ACTION],
)
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side=self.padding_side,
)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
token_type_ids = token_type_ids.to(dtype=torch.int64)
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
pad_masks = block_causal_update_causal_mask(
attention_mask=pad_masks,
past_key_values=None,
cache_position=cache_position,
input_tensor=embs,
token_type_ids=token_type_ids,
dtype=self.pi0_paligemma.dtype,
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
)
outputs = self.pi0_paligemma.forward(
input_ids=None,
token_type_ids=None,
attention_mask=pad_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=False,
labels=None,
)
logits = outputs.logits
loss_fct = nn.CrossEntropyLoss(reduction="none")
# Shift left for next-step prediction
logits = logits[:, :-1, :]
targets = targets[:, 1:].to(device) # Shift targets
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
# Compute per-token loss
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
# Apply loss mask
token_loss = token_loss * loss_mask.reshape(-1)
# Compute final loss
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
# Return loss dictionary
loss_dict = {"ce_loss": loss.item(), "loss": loss}
return loss_dict
def decode_actions_with_fast(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
relaxed_decoding: bool = True,
) -> np.array:
"""
Adapt original decoding in FAST to always return actions instead of zeros.
"""
self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
)
self.action_dim = (
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
)
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
if relaxed_decoding:
# Expected sequence length
expected_seq_len = self.time_horizon * self.action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# Apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
# Apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Extracts actions from predicted output tokens using the FAST model.
Args:
tokens (torch.Tensor): The input tensor of tokenized outputs.
action_horizon (int): The number of timesteps for actions.
action_dim (int): The dimensionality of each action.
Returns:
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
"""
# Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
]
raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
] # something like this should be robust #looks good
action_tokens = [
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
]
# returns the tensor of decoded actions per sample in a list
decoded_actions = [
torch.tensor(
self.decode_actions_with_fast(
tok.tolist(),
time_horizon=action_horizon,
action_dim=action_dim,
relaxed_decoding=self.config.relaxed_action_decoding,
),
device=tokens.device,
).squeeze(0)
for tok in action_tokens
]
return torch.stack(
decoded_actions,
dim=0,
)
def generate_actions(self, batch: dict[str, Tensor]):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
padded_outs["padded_mask"],
padded_outs["attention_mask"],
padded_outs["loss_mask"],
padded_outs["token_type_ids"],
padding_side="left",
)
token_type_ids = token_type_ids.to(dtype=torch.int64)
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
output_tokens = self.pi0_paligemma.generate(
input_ids=None,
attention_mask=pad_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=embs,
use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps,
do_sample=False,
num_beams=1,
token_type_ids=token_type_ids,
)
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.pi0_paligemma, "get_image_features"):
return self.pi0_paligemma.get_image_features(image)
else:
return self.pi0_paligemma.model.get_image_features(image)
def embed_inputs(
self,
images,
img_masks,
tokens,
pad_mask,
ar_mask,
loss_mask,
token_type_ids,
padding_side: str = "right",
):
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
# images are a list of same size
# vectorizing everything!
device = images[0].device
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
all_images = torch.stack(images, dim=1).to(device)
b, n, c, h, w = all_images.shape
all_images = all_images.view(b * n, c, h, w)
embedded = self.embed_image(all_images).to(device)
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
m = b_n // b # Compute the number of images per sample dynamically
# Reshape dynamically
embedded = embedded.view(b, m, p, image_embedding_dim)
tokens_embs = self.embed_tokens(tokens.to(device))
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
num_img_emb = embedded.shape[2]
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
image_target_tokens = (
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
).reshape(b, -1)
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
# Shift pad tokens to the left (.generate()) or right (.train())
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
)
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
if interpolate_like_pi:
img = (img * 255.0).to(dtype=torch.uint8)
img = img.permute(0, 2, 3, 1)
original_device = img.device
img = img.to(device="cpu").numpy()
imgs = []
for sub_img in img:
sub_img = Image.fromarray(sub_img)
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
resized_img = torch.from_numpy(np.array(resized_img))
imgs.append(resized_img)
img = torch.stack(imgs, dim=0)
img = img.permute(0, 3, 1, 2)
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
else:
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
@@ -0,0 +1,171 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
def make_pi05_pre_post_processors(
config: PI05Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+14
View File
@@ -0,0 +1,14 @@
## Paper
https://arxiv.org/abs/2509.25358
## Citation
```bibtex
@article{chen2025sarm,
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
journal={arXiv preprint arXiv:2509.25358},
year={2025}
}
```
@@ -0,0 +1,870 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting.
This script processes all frames in a dataset with SARM to compute progress values [0, 1].
The results are saved as a parquet file that can be loaded during training for RA-BC weighting.
Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only
need ~num_frames/30 queries instead of one per frame (~30x speedup).
Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5
The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
"""
import argparse
import logging
from pathlib import Path
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
from lerobot.policies.sarm.sarm_utils import normalize_stage_tau
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
"""Read reward_model_path from parquet metadata if available."""
if not parquet_path.exists():
return None
try:
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
if metadata and b"reward_model_path" in metadata:
return metadata[b"reward_model_path"].decode()
except Exception: # nosec B110
return None
return None
def load_sarm_resources(
dataset_repo_id: str,
reward_model_path: str,
device: str = "cuda",
) -> tuple[LeRobotDataset, SARMRewardModel, any]:
"""
Load SARM model, dataset, and preprocessor.
Returns:
Tuple of (dataset, reward_model, preprocessor)
"""
logging.info(f"Loading model: {reward_model_path}")
reward_model = SARMRewardModel.from_pretrained(reward_model_path)
reward_model.config.device = device
reward_model.to(device).eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
delta_indices = reward_model.config.observation_delta_indices
logging.info(f"Loading dataset: {dataset_repo_id}")
temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
fps = temp_dataset.fps
delta_timestamps = {
image_key: [idx / fps for idx in delta_indices],
state_key: [idx / fps for idx in delta_indices],
}
dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
preprocess, _ = make_sarm_pre_post_processors(
config=reward_model.config,
dataset_stats=dataset.meta.stats,
dataset_meta=dataset.meta,
)
return dataset, reward_model, preprocess
def to_numpy_image(img) -> np.ndarray:
"""Convert image tensor to numpy uint8 (H, W, C)."""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
if img.ndim == 4:
# Take center frame for bidirectional sampling
img = img[img.shape[0] // 2]
if img.shape[0] in [1, 3]:
img = np.transpose(img, (1, 2, 0))
if img.dtype != np.uint8:
# Handle normalized images (may have negative values or values > 1)
img = img.astype(np.float32)
img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1]
img = (img * 255).astype(np.uint8)
return img
def visualize_episode(
frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
):
"""Create visualization with progress plot, stage probabilities, and sample frames.
Same as sarm_inference_visualization.py
"""
num_stages = stage_preds.shape[1]
colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
frame_indices = np.arange(len(progress_preds))
fig = plt.figure(figsize=(14, 12))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])
# Progress plot
ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
if gt_progress is not None:
ax_progress.plot(
frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
)
ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
ax_progress.set_ylabel("Progress")
ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc="upper left")
ax_progress.grid(True, alpha=0.3)
# Stage predictions
ax_stages.stackplot(
frame_indices,
*[stage_preds[:, i] for i in range(num_stages)],
colors=colors,
alpha=0.8,
labels=stage_labels,
)
if gt_stages is not None:
for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
ax_stages.set_xlabel("Frame")
ax_stages.set_ylabel("Stage Probability")
ax_stages.set_ylim(0, 1)
ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
ax_stages.grid(True, alpha=0.3)
# Sample frames
ax_frames.axis("off")
num_sample = 8
sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
h, w = frames[0].shape[:2]
combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
for i, idx in enumerate(sample_indices):
frame = frames[idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
combined[:, i * w : (i + 1) * w] = frame
stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
ax_frames.text(
i * w + w / 2,
-10,
f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
ha="center",
va="top",
fontsize=7,
)
ax_frames.imshow(combined)
ax_frames.set_title("Sample Frames", pad=20)
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
print(f"Saved: {output_path}")
def visualize_sarm_predictions(
dataset: LeRobotDataset,
reward_model: SARMRewardModel,
preprocess,
episode_indices: list[int],
head_mode: str,
output_dir: Path,
num_display_frames: int = 5,
stride: int = 1,
):
"""
Visualize SARM predictions for multiple episodes.
Computes predictions for every frame by default. With stride > 1, computes predictions
every N frames and interpolates (progress + stage probabilities) for visualization.
Args:
dataset: LeRobotDataset with delta_timestamps configured
reward_model: Loaded SARM model
preprocess: Preprocessor from make_sarm_pre_post_processors
episode_indices: List of episode indices to visualize
head_mode: "sparse", "dense", or "both"
output_dir: Directory to save visualizations
num_display_frames: Number of frames to display in thumbnail strip (default: 5)
stride: Compute predictions every N frames, interpolate the rest (default: 1)
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
dual_mode = reward_model.config.uses_dual_heads
device = reward_model.device
# Center frame index for bidirectional sampling
target_idx = reward_model.config.n_obs_steps // 2
# Determine which heads to visualize
schemes_to_viz = []
if head_mode in ("sparse", "both") or not dual_mode:
schemes_to_viz.append("sparse")
if head_mode in ("dense", "both") and dual_mode:
schemes_to_viz.append("dense")
# Set preprocessor to eval mode to disable augmentations
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
for episode_idx in episode_indices:
ep = dataset.meta.episodes[episode_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
task = dataset[ep_start].get("task", "perform the task")
num_frames = ep_end - ep_start
# Select frames for display thumbnails (evenly sampled from begin to end)
display_indices = set(
[
ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
for i in range(num_display_frames)
]
if num_frames >= num_display_frames
else list(range(ep_start, ep_end))
)
viz_frames = {}
# Load display frames up-front (stride mode might skip them otherwise).
for frame_idx in display_indices:
sample = dataset[frame_idx]
viz_frames[frame_idx] = to_numpy_image(sample[image_key])
# Initialize storage for each scheme
scheme_data = {}
for scheme in schemes_to_viz:
num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
scheme_data[scheme] = {
"viz_progress": np.full(num_frames, np.nan),
"viz_stages": np.full((num_frames, num_stages), np.nan),
"viz_gt_progress": np.full(num_frames, np.nan),
"viz_gt_stages": np.full(num_frames, np.nan),
"target_key": f"{scheme}_targets",
"num_stages": num_stages,
"temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
"subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
}
if stride > 1:
logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating")
# Process frames one at a time to avoid memory buildup
frame_indices = list(range(ep_start, ep_end, stride))
if (ep_end - 1) not in frame_indices:
frame_indices.append(ep_end - 1)
frame_indices = sorted(set(frame_indices))
for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
local_idx = frame_idx - ep_start
sample = dataset[frame_idx]
batch = {
image_key: sample[image_key],
"task": task,
"index": frame_idx,
"episode_index": episode_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
video_features = processed["video_features"].to(device)
text_features = processed["text_features"].to(device)
state_features = processed.get("state_features")
if state_features is not None:
state_features = state_features.to(device)
lengths = processed.get("lengths")
for scheme in schemes_to_viz:
sd = scheme_data[scheme]
# Ground truth
# In stride visualization mode, ground-truth plots can be misleading
# (only sparse points are available), so we skip GT.
if stride == 1 and sd["target_key"] in processed:
gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
sd["viz_gt_stages"][local_idx] = int(gt_target)
sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
gt_target,
num_stages=sd["num_stages"],
temporal_proportions=sd["temporal_props"],
subtask_names=sd["subtask_names"],
)
# Predictions
reward, stage_probs = reward_model.calculate_rewards(
text_embeddings=text_features,
video_embeddings=video_features,
state_features=state_features,
lengths=lengths,
return_all_frames=True,
return_stages=True,
head_mode=scheme,
)
# Handle both tensor and numpy outputs
if isinstance(reward, torch.Tensor):
reward = reward.cpu().numpy()
stage_probs = stage_probs.cpu().numpy()
if reward.ndim == 2:
sd["viz_progress"][local_idx] = reward[0, target_idx]
sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
else:
sd["viz_progress"][local_idx] = reward[target_idx]
sd["viz_stages"][local_idx] = stage_probs[target_idx, :]
# Clear GPU memory after each frame
del processed, video_features, text_features
if state_features is not None:
del state_features
torch.cuda.empty_cache()
# Interpolate predictions back to per-frame arrays for smooth visualization.
if stride > 1:
all_local = np.arange(num_frames)
for scheme in schemes_to_viz:
sd = scheme_data[scheme]
valid = np.isfinite(sd["viz_progress"])
valid_idx = np.where(valid)[0]
if valid_idx.size >= 1:
sd["viz_progress"] = interpolate_progress(
valid_idx, sd["viz_progress"][valid_idx], all_local
)
stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
for s in range(sd["num_stages"]):
stage_interp[:, s] = interpolate_progress(
valid_idx, sd["viz_stages"][valid_idx, s], all_local
)
stage_interp = np.clip(stage_interp, 0.0, 1.0)
row_sums = stage_interp.sum(axis=1, keepdims=True)
nz = row_sums.squeeze(-1) > 0
stage_interp[nz] = stage_interp[nz] / row_sums[nz]
sd["viz_stages"] = stage_interp
else:
# No valid points: keep NaNs/zeros; visualization will be empty.
sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)
# Generate visualization for each head
ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
for scheme in schemes_to_viz:
sd = scheme_data[scheme]
stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"
visualize_episode(
frames=np.array(ordered_viz_frames),
progress_preds=sd["viz_progress"],
stage_preds=sd["viz_stages"],
title=f"{task} (Episode {episode_idx})",
output_path=viz_path,
stage_labels=stage_labels,
gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
)
# Clear memory between episodes
torch.cuda.empty_cache()
logging.info(f"Visualizations saved to: {output_dir.absolute()}")
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
"""Generate all frame indices, ordered by offset for cache-friendly access.
Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
This groups frames that share similar temporal windows together.
"""
num_frames = ep_end - ep_start
indices = []
for offset in range(frame_gap):
for frame_rel in range(offset, num_frames, frame_gap):
indices.append(ep_start + frame_rel)
return indices
def interpolate_progress(
computed_indices: np.ndarray,
computed_values: np.ndarray,
all_indices: np.ndarray,
) -> np.ndarray:
"""Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
computed_indices = np.asarray(computed_indices)
computed_values = np.asarray(computed_values)
all_indices = np.asarray(all_indices)
mask = np.isfinite(computed_values)
if mask.sum() == 0:
return np.full(all_indices.shape, np.nan, dtype=np.float32)
if mask.sum() == 1:
return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)
out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
return out.astype(np.float32)
def compute_sarm_progress(
dataset_repo_id: str,
reward_model_path: str,
output_path: str | None = None,
head_mode: str = "sparse",
device: str = "cuda",
num_visualizations: int = 5,
output_dir: str = "./sarm_viz",
stride: int = 1,
):
"""
Compute SARM progress predictions for all frames in a dataset.
Args:
dataset_repo_id: HuggingFace dataset repo ID or local path
reward_model_path: Path to pretrained SARM model
output_path: Path to save results. If None, saves to dataset's cache directory
head_mode: SARM head to use ("sparse", "dense", or "both")
device: Device to use for inference
num_visualizations: Number of episodes to visualize (0 to skip)
output_dir: Directory to save visualizations
stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
"""
dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)
# Set preprocessor to eval mode to disable augmentations
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
frame_gap = reward_model.config.frame_gap
num_episodes = dataset.num_episodes
total_frames = dataset.num_frames
logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")
# Determine which heads to compute
dual_mode = reward_model.config.uses_dual_heads
compute_sparse = head_mode in ("sparse", "both") or not dual_mode
compute_dense = head_mode in ("dense", "both") and dual_mode
# Storage arrays
all_indices = []
all_episode_indices = []
all_frame_indices = []
all_progress_sparse = [] if compute_sparse else None
all_progress_dense = [] if compute_dense else None
if stride > 1:
logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")
# Process all episodes
for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
ep = dataset.meta.episodes[episode_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
# Get task description
task = dataset[ep_start].get("task", "perform the task")
# Generate frames to compute (with stride applied)
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
if stride > 1:
# Only compute every stride-th frame (relative to episode start)
compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
# Always include last frame for better interpolation at episode end
last_frame = ep_end - 1
if last_frame not in compute_indices:
compute_indices.append(last_frame)
compute_indices = sorted(set(compute_indices))
else:
compute_indices = all_ep_indices
center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window
# Dictionary to collect results
frame_results = {}
for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
try:
sample = dataset[query_idx]
batch = {
image_key: sample[image_key],
"task": task,
"index": query_idx,
"episode_index": episode_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
video_features = processed["video_features"].to(device)
text_features = processed["text_features"].to(device)
state_features = processed.get("state_features")
if state_features is not None:
state_features = state_features.to(device)
lengths = processed.get("lengths")
sparse_val = np.nan
dense_val = np.nan
# Compute sparse prediction for center frame
if compute_sparse:
sparse_progress = reward_model.calculate_rewards(
text_embeddings=text_features,
video_embeddings=video_features,
state_features=state_features,
lengths=lengths,
return_all_frames=True,
head_mode="sparse",
)
sparse_val = float(
sparse_progress[0, center_idx]
if sparse_progress.ndim == 2
else sparse_progress[center_idx]
)
# Compute dense prediction for center frame
if compute_dense:
dense_progress = reward_model.calculate_rewards(
text_embeddings=text_features,
video_embeddings=video_features,
state_features=state_features,
lengths=lengths,
return_all_frames=True,
head_mode="dense",
)
dense_val = float(
dense_progress[0, center_idx]
if dense_progress.ndim == 2
else dense_progress[center_idx]
)
frame_results[query_idx] = (sparse_val, dense_val)
except Exception as e:
logging.warning(f"Failed to process frame {query_idx}: {e}")
# Interpolate to get values for all frames
computed_indices = np.array(sorted(frame_results.keys()))
computed_sparse = (
np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
)
computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None
# All frame indices for this episode
all_frame_idx_array = np.arange(ep_start, ep_end)
if stride > 1 and len(computed_indices) > 1:
# Interpolate progress values
if compute_sparse:
interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
if compute_dense:
interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
else:
# No interpolation needed
interp_sparse = computed_sparse if compute_sparse else None
interp_dense = computed_dense if compute_dense else None
# Store results for all frames
for i, frame_idx in enumerate(all_frame_idx_array):
local_idx = frame_idx - ep_start
all_indices.append(frame_idx)
all_episode_indices.append(episode_idx)
all_frame_indices.append(local_idx)
if compute_sparse:
if stride > 1 and len(computed_indices) > 1:
all_progress_sparse.append(float(interp_sparse[i]))
elif frame_idx in frame_results:
all_progress_sparse.append(frame_results[frame_idx][0])
else:
all_progress_sparse.append(np.nan)
if compute_dense:
if stride > 1 and len(computed_indices) > 1:
all_progress_dense.append(float(interp_dense[i]))
elif frame_idx in frame_results:
all_progress_dense.append(frame_results[frame_idx][1])
else:
all_progress_dense.append(np.nan)
# Create output table
table_data = {
"index": np.array(all_indices, dtype=np.int64),
"episode_index": np.array(all_episode_indices, dtype=np.int64),
"frame_index": np.array(all_frame_indices, dtype=np.int64),
}
if compute_sparse:
table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
if compute_dense:
table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)
# Sort by index
df = pa.table(table_data).to_pandas()
df = df.sort_values("index").reset_index(drop=True)
final_table = pa.Table.from_pandas(df, preserve_index=False)
# Add metadata with reward model path
metadata = {b"reward_model_path": reward_model_path.encode()}
final_table = final_table.replace_schema_metadata(metadata)
# Determine output path
output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)
# Save
output_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(final_table, output_path)
logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")
# Print statistics
if "progress_sparse" in df.columns:
valid = df["progress_sparse"].dropna()
logging.info(
f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
f"min={valid.min():.4f}, max={valid.max():.4f}"
)
if "progress_dense" in df.columns:
valid = df["progress_dense"].dropna()
logging.info(
f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
f"min={valid.min():.4f}, max={valid.max():.4f}"
)
# Visualize episodes after processing
if num_visualizations > 0:
viz_episodes = list(range(min(num_visualizations, num_episodes)))
logging.info(f"Generating {len(viz_episodes)} visualizations...")
visualize_sarm_predictions(
dataset=dataset,
reward_model=reward_model,
preprocess=preprocess,
episode_indices=viz_episodes,
head_mode=head_mode,
output_dir=Path(output_dir),
stride=stride,
)
return output_path
def main():
parser = argparse.ArgumentParser(
description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",
)
parser.add_argument(
"--dataset-repo-id",
type=str,
required=True,
help="HuggingFace dataset repo ID or local path",
)
parser.add_argument(
"--reward-model-path",
type=str,
default=None,
help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
)
parser.add_argument(
"--output-path",
type=str,
default=None,
help="Output path for parquet. If not set, saves to dataset's cache directory",
)
parser.add_argument(
"--head-mode",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="SARM head to use (default: sparse)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use (default: cuda)",
)
# Visualization options
parser.add_argument(
"--visualize-only",
action="store_true",
help="Only visualize SARM predictions (no RA-BC computation)",
)
parser.add_argument(
"--num-visualizations",
type=int,
default=5,
help="Number of episodes to visualize (default: 5, set to 0 to skip)",
)
parser.add_argument(
"--output-dir",
type=str,
default="./sarm_viz",
help="Output directory for visualizations (default: ./sarm_viz)",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload progress file to the dataset repo on HuggingFace Hub",
default=True,
)
parser.add_argument(
"--stride",
type=int,
default=1,
help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Try to get reward_model_path from parquet metadata if not provided
reward_model_path = args.reward_model_path
if reward_model_path is None:
# Load dataset to find parquet path
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
if reward_model_path:
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
else:
raise ValueError(
"--reward-model-path is required (no existing parquet with model metadata found)"
)
# Handle visualize-only mode
if args.visualize_only:
dataset, reward_model, preprocess = load_sarm_resources(
args.dataset_repo_id, reward_model_path, args.device
)
logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
visualize_sarm_predictions(
dataset=dataset,
reward_model=reward_model,
preprocess=preprocess,
episode_indices=viz_episodes,
head_mode=args.head_mode,
output_dir=Path(args.output_dir),
stride=args.stride,
)
print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
return
# Full RABC computation (compute_sarm_progress loads model/dataset itself)
output_path = compute_sarm_progress(
dataset_repo_id=args.dataset_repo_id,
reward_model_path=reward_model_path,
output_path=args.output_path,
head_mode=args.head_mode,
device=args.device,
num_visualizations=args.num_visualizations,
output_dir=args.output_dir,
stride=args.stride,
)
print(f"\nSARM progress values saved to: {output_path}")
# Upload to Hub if requested
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = "sarm_progress.parquet"
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(output_path),
path_in_repo=hub_path,
repo_id=args.dataset_repo_id,
repo_type="dataset",
)
print(
f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
)
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
print(" rabc_head_mode: sparse # or dense")
else:
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: {output_path}")
print(" rabc_head_mode: sparse # or dense")
if __name__ == "__main__":
main()
@@ -0,0 +1,248 @@
#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
Paper: https://arxiv.org/abs/2509.25358
"""
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling).
Supports three annotation modes:
1. single_stage (default): No annotations needed. Uses the episode's task description
as a single stage covering the entire episode.
2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated
single sparse "task" stage covering the full episode. The dense head learns detailed
subtask progression while sparse provides overall task completion.
3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
annotations from VLM. Both heads are trained on their respective annotations.
The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
are loaded/generated during model initialization.
"""
annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual"
n_obs_steps: int = 8 # Number of observation history steps
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation
# Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
# During training with rewind: [obs_frames] + [rewind_frames]
# During inference: [obs_frames] only
# Architecture params
image_dim: int = 512
text_dim: int = 512
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
max_state_dim: int = 32
drop_n_last_frames: int = 1
batch_size: int = 64
clip_batch_size: int = 64
dropout: float = 0.1
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
rewind_probability: float = 0.8
language_perturbation_probability: float = 0.2
# Sparse annotations (high-level stages)
num_sparse_stages: int = 1
sparse_subtask_names: list | None = None
sparse_temporal_proportions: list | None = None
# Dense annotations (fine-grained stages)
num_dense_stages: int | None = None
dense_subtask_names: list | None = None
dense_temporal_proportions: list | None = None
pretrained_model_path: str | None = None
device: str | None = None
image_key: str = "observation.images.top" # Key for image used from the dataset
state_key: str = "observation.state"
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features (updated in __post_init__)
output_features: dict = field(
default_factory=lambda: {
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
}
)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self):
super().__post_init__()
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
raise ValueError(
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
)
if self.annotation_mode == "single_stage":
# Use task description as stage name, full episode as one stage
self.num_sparse_stages = 1
self.sparse_subtask_names = ["task"]
self.sparse_temporal_proportions = [1.0]
self.num_dense_stages = None
self.dense_subtask_names = None
self.dense_temporal_proportions = None
elif self.annotation_mode == "dense_only":
self.num_sparse_stages = 1
self.sparse_subtask_names = ["task"]
self.sparse_temporal_proportions = [1.0]
self.input_features = {}
self.output_features = {}
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,),
type=FeatureType.STATE,
)
# Update output features based on annotation_mode
if self.annotation_mode in ["dense_only", "dual"]:
self.output_features["sparse_stage"] = PolicyFeature(
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
)
self.output_features["sparse_progress"] = PolicyFeature(
shape=(self.num_frames, 1), type=FeatureType.REWARD
)
dense_stages = self.num_dense_stages or self.num_sparse_stages
self.output_features["dense_stage"] = PolicyFeature(
shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
)
self.output_features["dense_progress"] = PolicyFeature(
shape=(self.num_frames, 1), type=FeatureType.REWARD
)
else:
self.output_features["sparse_stage"] = PolicyFeature(
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
)
self.output_features["sparse_progress"] = PolicyFeature(
shape=(self.num_frames, 1), type=FeatureType.REWARD
)
if self.max_rewind_steps >= self.n_obs_steps:
raise ValueError(
f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
)
if self.num_sparse_stages < 1:
raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
if (
self.annotation_mode in ["dense_only", "dual"]
and self.num_dense_stages is not None
and self.num_dense_stages < 2
):
raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for SARM training."""
return AdamWConfig(
lr=5e-5,
weight_decay=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=5e-5,
decay_lr=5e-6,
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
pass
@property
def uses_dual_heads(self) -> bool:
"""Whether the model uses dual heads (dense_only or dual annotation modes)."""
return self.annotation_mode in ["dense_only", "dual"]
@property
def num_frames(self) -> int:
"""Total number of frames in sequence.
For training: 1 + n_obs_steps + max_rewind_steps
The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
"""
return 1 + self.n_obs_steps + self.max_rewind_steps
@property
def max_length(self) -> int:
return self.num_frames
@property
def observation_delta_indices(self) -> list[int]:
"""Bidirectional frame sampling centered on target frame.
Example with n_obs_steps=8, gap=30:
Before: [-120, -90, -60, -30] (4 frames)
Current: [0] (1 frame)
After: [30, 60, 90, 120] (4 frames)
Total: 9 frames
"""
half_steps = self.n_obs_steps // 2
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
obs_deltas = past_deltas + [0] + future_deltas
# Rewind placeholders
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
return obs_deltas + rewind_deltas
@property
def action_delta_indices(self) -> None:
"""SARM is a reward model, not an action policy."""
return None
@property
def reward_delta_indices(self) -> None:
return None
+793
View File
@@ -0,0 +1,793 @@
#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
Paper: https://arxiv.org/abs/2509.25358
- StageTransformer: Predicts stage classification (sparse/dense)
- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
"""
import json
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
normalize_stage_tau,
pad_state_to_max_dim,
)
class StageTransformer(nn.Module):
"""
Stage classification transformer for SARM.
Predicts which stage/subtask the current frame belongs to.
Supports both sparse (high-level) and dense (fine-grained) annotation schemes.
Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
Output: stage logits (B, T, num_classes)
"""
def __init__(
self,
d_model: int = 512,
vis_emb_dim: int = 512,
text_emb_dim: int = 512,
state_dim: int = 32,
n_layers: int = 6,
n_heads: int = 8,
dropout: float = 0.1,
num_cameras: int = 1,
num_classes_sparse: int = 4,
num_classes_dense: int = 8,
):
super().__init__()
self.d_model = d_model
self.num_cameras = num_cameras
# Projections
self.lang_proj = nn.Linear(text_emb_dim, d_model)
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
self.state_proj = nn.Linear(state_dim, d_model)
# Encoder
enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
# Positional bias on first visual frame
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
# Shared fusion MLP
# Fuses (num_cameras + 2) streams: cameras + lang + state
fused_in = d_model * (num_cameras + 2)
self.fusion_backbone = nn.Sequential(
nn.LayerNorm(fused_in),
nn.Linear(fused_in, d_model),
nn.ReLU(),
)
# Scheme-specific heads
self.heads = nn.ModuleDict(
{
"sparse": nn.Linear(d_model, num_classes_sparse),
"dense": nn.Linear(d_model, num_classes_dense),
}
)
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
"""
Prepare language embeddings for fusion.
Accepts lang_emb of shape:
- (B, text_emb_dim) -> broadcast across time
- (B, T, text_emb_dim) -> per-timestep (dense annotation mode)
Returns: (B, 1, T, D)
"""
if lang_emb.dim() == 3:
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
else:
# (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
return lang_proj
def forward(
self,
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
state: torch.Tensor, # (B, T, state_dim)
lengths: torch.Tensor, # (B,) - valid sequence lengths
scheme: str = "sparse", # "sparse" or "dense"
) -> torch.Tensor:
"""
Forward pass for stage classification.
Args:
img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
lang_emb: Language embeddings (B, E) or (B, T, E) for dense
state: State features (B, T, state_dim)
lengths: Valid sequence lengths (B,) for masking
scheme: "sparse" or "dense" for head selection
Returns:
Stage logits (B, T, num_classes)
"""
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
B, N, T, _ = img_seq.shape # noqa: N806
D = self.d_model # noqa: N806
device = img_seq.device
# Project inputs
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
# Concatenate streams
# cameras + lang + state -> (B, N+2, T, D)
x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)
# Add positional bias to first visual frame
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
# Flatten to tokens for Transformer
x_tokens = x.view(B, (N + 2) * T, D)
L = x_tokens.size(1) # noqa: N806
# Create padding mask
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T)
mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)
# Create causal mask
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
# Encode
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
# Reshape and fuse
h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
fused = self.fusion_backbone(h) # (B, T, D)
# Scheme-specific logits
logits = self.heads[scheme](fused) # (B, T, num_classes)
return logits
class SubtaskTransformer(nn.Module):
"""
Subtask progress regression transformer for SARM.
Predicts within-stage normalized progress (tau) conditioned on stage prior.
The stage prior is a one-hot encoding passed from StageTransformer predictions.
Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
Output: tau predictions (B, T) in [0, 1]
"""
def __init__(
self,
d_model: int = 512,
vis_emb_dim: int = 512,
text_emb_dim: int = 512,
state_dim: int = 32,
n_layers: int = 6,
n_heads: int = 8,
dropout: float = 0.1,
num_cameras: int = 1,
):
super().__init__()
self.d_model = d_model
self.num_cameras = num_cameras
# Projections
self.lang_proj = nn.Linear(text_emb_dim, d_model)
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
self.state_proj = nn.Linear(state_dim, d_model)
# Encoder
enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
self.transformer = nn.TransformerEncoder(enc, n_layers)
# Learned bias on first visual frame
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
# Shared fusion backbone
# Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
fused_in = d_model * (num_cameras + 3)
self.fusion_backbone = nn.Sequential(
nn.LayerNorm(fused_in),
nn.Linear(fused_in, d_model),
nn.ReLU(),
)
# Scheme-specific regression heads
self.heads = nn.ModuleDict(
{
"sparse": nn.Linear(d_model, 1),
"dense": nn.Linear(d_model, 1),
}
)
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
"""
Prepare language embeddings for fusion.
"""
if lang_emb.dim() == 3:
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
return self.lang_proj(lang_emb).unsqueeze(1)
else:
# (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
"""
Deterministic projection of one-hot stage to d_model by pad/truncate.
Args:
stage_prior: One-hot stage embedding (B, 1, T, C)
Returns:
Projected stage embedding (B, 1, T, d_model)
"""
B, one, T, C = stage_prior.shape # noqa: N806
D = self.d_model # noqa: N806
if D == C:
return stage_prior
elif D > C:
pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
return torch.cat([stage_prior, pad], dim=-1)
else:
return stage_prior[..., :D]
def forward(
self,
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
state: torch.Tensor, # (B, T, state_dim)
lengths: torch.Tensor, # (B,) - valid sequence lengths
stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb
scheme: str = "sparse", # "sparse" or "dense"
) -> torch.Tensor:
"""
Forward pass for subtask progress regression.
Args:
img_seq: Image embeddings (B, N, T, vis_emb_dim)
lang_emb: Language embeddings (B, E) or (B, T, E)
state: State features (B, T, state_dim)
lengths: Valid sequence lengths (B,) for masking
stage_prior: One-hot stage prior (B, 1, T, num_classes)
scheme: "sparse" or "dense" for head selection
Returns:
Tau predictions (B, T) in [0, 1] via sigmoid
"""
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
B, N, T, _ = img_seq.shape # noqa: N806
D = self.d_model # noqa: N806
device = img_seq.device
# Project inputs
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D)
# Concatenate all streams
# cameras + lang + state + stage_emb -> (B, N+3, T, D)
x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)
# Add positional bias to first visual frame
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
# Flatten to tokens
x_tokens = x.view(B, (N + 3) * T, D)
L = x_tokens.size(1) # noqa: N806
# Create padding mask
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)
# Create causal mask
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
# Encode
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
# Reshape and fuse
h = h.view(B, N + 3, T, D)
h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
fused = self.fusion_backbone(h_flat) # (B, T, D)
# Scheme-specific regression head -> sigmoid
r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T)
return r
def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
"""
Generate one-hot stage embeddings from targets.
Args:
num_classes: Number of stage classes
targets: Target values (B, T) where integer part is stage index
Returns:
One-hot stage embedding (B, 1, T, num_classes)
"""
# Integer part of float targets -> [0, C-1]
idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T)
C = num_classes # noqa: N806
# Identity-lookup one-hot
stage_onehot = torch.eye(C, device=targets.device)[idx] # (B, T, C)
stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C)
return stage_onehot
class SARMRewardModel(PreTrainedPolicy):
"""
SARM Reward Model for stage-aware task completion rewards.
Uses two separate transformer models:
- StageTransformer: Classifies which stage/subtask
- SubtaskTransformer: Predicts within-stage progress (tau)
Training uses 75%/25% GT/predicted stage conditioning (teacher forcing).
"""
name = "sarm"
config_class = SARMConfig
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(
config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
)
# Load temporal proportions based on annotation_mode
if config.annotation_mode == "single_stage":
logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
elif dataset_meta is not None:
self._load_temporal_proportions(dataset_meta)
# Create two separate models
self.stage_model = StageTransformer(
d_model=config.hidden_dim,
vis_emb_dim=config.image_dim,
text_emb_dim=config.text_dim,
state_dim=config.max_state_dim,
n_layers=config.num_layers,
n_heads=config.num_heads,
dropout=config.dropout,
num_cameras=1, # Single camera for now
num_classes_sparse=config.num_sparse_stages,
num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
)
self.subtask_model = SubtaskTransformer(
d_model=config.hidden_dim,
vis_emb_dim=config.image_dim,
text_emb_dim=config.text_dim,
state_dim=config.max_state_dim,
n_layers=config.num_layers,
n_heads=config.num_heads,
dropout=config.dropout,
num_cameras=1,
)
self.stage_model.to(self.device)
self.subtask_model.to(self.device)
# GT/predicted stage ratio for teacher forcing
self.gt_stage_ratio = 0.75
if config.uses_dual_heads:
logging.info(
f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
f"{config.num_dense_stages} dense stages"
)
else:
logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")
logging.info(f"SARM initialized on {self.device}")
def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
"""Load temporal proportions from a JSON file (preserving order)."""
if not path.exists():
raise ValueError(
f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
)
with open(path) as f:
proportions_dict = json.load(f)
names = list(proportions_dict.keys())
logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
return names, [proportions_dict[name] for name in names]
def _load_temporal_proportions(self, dataset_meta) -> None:
"""Load temporal proportions based on annotation_mode."""
meta_path = dataset_meta.root / "meta"
if self.config.annotation_mode == "dual":
names, props = self._load_proportions_from_json(
meta_path / "temporal_proportions_sparse.json", "sparse"
)
(
self.config.num_sparse_stages,
self.config.sparse_subtask_names,
self.config.sparse_temporal_proportions,
) = len(names), names, props
if self.config.annotation_mode in ["dense_only", "dual"]:
names, props = self._load_proportions_from_json(
meta_path / "temporal_proportions_dense.json", "dense"
)
(
self.config.num_dense_stages,
self.config.dense_subtask_names,
self.config.dense_temporal_proportions,
) = len(names), names, props
if self.config.annotation_mode == "dense_only":
logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.stage_model.to(device)
self.subtask_model.to(device)
return self
@torch.no_grad()
def calculate_rewards(
self,
text_embeddings: np.ndarray | torch.Tensor,
video_embeddings: np.ndarray | torch.Tensor,
state_features: np.ndarray | torch.Tensor | None = None,
lengths: np.ndarray | torch.Tensor | None = None,
return_all_frames: bool = False,
return_stages: bool = False,
return_confidence: bool = False,
head_mode: str | None = "sparse",
frame_index: int | None = None,
) -> np.ndarray | tuple:
"""
Calculate rewards for given text, video, and state representations.
This is the canonical method for SARM reward computation, used for:
- Inference/visualization
- RA-BC weight computation
Args:
text_embeddings: Encoded text representations (batch_size, 512)
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
state_features: Joint state features (batch_size, num_frames, state_dim)
lengths: Valid sequence lengths (batch_size,)
return_all_frames: If True, return rewards for all frames
return_stages: If True, also return stage predictions
return_confidence: If True, also return stage confidence
head_mode: Which head to use ("sparse" or "dense")
frame_index: Index of the target frame to extract (default: n_obs_steps).
Returns:
Rewards and optionally stage probs/confidence.
"""
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
if state_features is not None and isinstance(state_features, np.ndarray):
state_features = torch.tensor(state_features, dtype=torch.float32)
# Handle single sample case
if text_embeddings.dim() == 1:
text_embeddings = text_embeddings.unsqueeze(0)
video_embeddings = video_embeddings.unsqueeze(0)
if state_features is not None:
state_features = state_features.unsqueeze(0)
single_sample = True
else:
single_sample = False
batch_size = video_embeddings.shape[0]
seq_len = video_embeddings.shape[1]
scheme = head_mode
# Default lengths if not provided
if lengths is None:
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
elif isinstance(lengths, np.ndarray):
lengths = torch.tensor(lengths, dtype=torch.int32)
# Reshape video to (B, N, T, D) for multi-camera format
# Currently single camera: (B, T, D) -> (B, 1, T, D)
img_seq = video_embeddings.unsqueeze(1).to(self.device)
lang_emb = text_embeddings.to(self.device)
state = (
state_features.to(self.device)
if state_features is not None
else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
)
lens = lengths.to(self.device)
# Pad state to max_state_dim
state = pad_state_to_max_dim(state, self.config.max_state_dim)
# Get num_classes for this scheme
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
# Run stage model
stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes)
stage_idx = stage_probs.argmax(dim=-1) # (B, T)
stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T)
# Create one-hot stage prior
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
# Run subtask model
tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)
# Compute final reward: stage + tau
raw_reward = stage_idx.float() + tau_pred # (B, T)
# Normalize to [0, 1] using temporal proportions for proper weighting
if scheme == "sparse":
normalized_reward = normalize_stage_tau(
raw_reward,
num_stages=num_classes,
temporal_proportions=self.config.sparse_temporal_proportions,
subtask_names=self.config.sparse_subtask_names,
)
else:
normalized_reward = normalize_stage_tau(
raw_reward,
num_stages=num_classes,
temporal_proportions=self.config.dense_temporal_proportions,
subtask_names=self.config.dense_subtask_names,
)
# Default frame index is n_obs_steps (last observation frame)
if frame_index is None:
frame_index = self.config.n_obs_steps
# Prepare outputs (batch mode or no smoothing)
if return_all_frames:
rewards = normalized_reward.cpu().numpy()
else:
rewards = normalized_reward[:, frame_index].cpu().numpy()
if single_sample:
rewards = rewards[0] if not return_all_frames else rewards[0]
outputs = [rewards]
if return_stages:
probs = stage_probs.cpu().numpy()
if single_sample:
probs = probs[0]
outputs.append(probs)
if return_confidence:
conf = stage_conf.cpu().numpy()
if single_sample:
conf = conf[0]
outputs.append(conf)
return outputs[0] if len(outputs) == 1 else tuple(outputs)
def train(self, mode: bool = True):
"""Set training mode for both models."""
super().train(mode)
self.stage_model.train(mode)
self.subtask_model.train(mode)
return self
def eval(self):
"""Set evaluation mode for both models."""
return self.train(False)
def parameters(self):
"""Override to return trainable parameters from both models."""
from itertools import chain
return chain(self.stage_model.parameters(), self.subtask_model.parameters())
def get_optim_params(self):
"""Override to return optimizer parameters from both models."""
return self.parameters()
def reset(self):
"""Required by PreTrainedPolicy but not used for reward models."""
pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for reward models."""
raise NotImplementedError("SARM model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _train_step(
self,
img_emb: torch.Tensor, # (B, N, T, D)
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
state: torch.Tensor, # (B, T, state_dim)
lengths: torch.Tensor, # (B,)
targets: torch.Tensor, # (B, T) - format: stage.tau
scheme: str,
) -> dict[str, torch.Tensor]:
"""
Single training step for one annotation scheme.
Implements 75%/25% GT/predicted stage conditioning.
Args:
img_emb: Image embeddings (B, N, T, D)
lang_emb: Language embeddings
state: State features
lengths: Valid sequence lengths
targets: Target values where floor=stage, remainder=tau
scheme: "sparse" or "dense"
Returns:
Dict with stage_loss, subtask_loss, total_loss
"""
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
# Ground truth: stage (integer) and tau (fractional)
# Clamp stage indices to valid range [0, num_classes-1] to handle edge cases
# where targets may exceed expected range (e.g., frames between subtasks)
gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T)
gt_tau = torch.remainder(targets, 1.0) # (B, T)
# Run stage model
stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)
# 75%/25% GT/predicted stage conditioning
if random.random() < self.gt_stage_ratio:
# Mode 1: Use ground truth stage -> one-hot
stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C)
else:
# Mode 2: Use predicted stage argmax -> one-hot
stage_idx = stage_pred.argmax(dim=-1) # (B, T)
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
# Run subtask model with stage prior
tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)
# Compute losses
stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")
return {
"stage_loss": stage_loss,
"subtask_loss": subtask_loss,
"total_loss": stage_loss + subtask_loss,
}
def forward(self, batch):
"""
Forward pass for SARM reward model training.
Uses stage+tau target format where:
- Integer part = stage index
- Fractional part = within-stage progress (tau)
Training uses 75%/25% GT/predicted stage conditioning.
Args:
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 512) or (B, T, 512) text features
- 'state_features': (B, T, state_dim) joint state features
- 'lengths': (B,) valid sequence lengths
- 'sparse_targets': (B, T) sparse targets (stage.tau format)
- 'dense_targets': (B, T) dense targets (optional, for dual mode)
Returns:
Tuple of (total_loss, output_dict with loss components)
"""
observation = batch.get("observation", batch)
# Extract features
video_features = observation["video_features"].to(self.device)
text_features = observation["text_features"].to(self.device)
state_features = observation.get("state_features")
if state_features is not None:
state_features = state_features.to(self.device)
batch_size = video_features.shape[0]
seq_len = video_features.shape[1]
# Get lengths (default to full sequence)
lengths = observation.get("lengths")
if lengths is None:
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
else:
lengths = lengths.to(self.device)
# Reshape video to (B, N, T, D) - single camera
img_emb = video_features.unsqueeze(1)
# Pad state to max_state_dim
if state_features is None:
state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
else:
state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)
output_dict = {}
total_loss = torch.tensor(0.0, device=self.device)
# Sparse training (always)
sparse_targets = observation.get("sparse_targets")
if sparse_targets is None:
# Try legacy format
sparse_targets = observation.get("targets")
if sparse_targets is None:
raise ValueError("sparse_targets (or targets) is required for SARM training")
sparse_targets = sparse_targets.to(self.device)
sparse_result = self._train_step(
img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
)
output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
total_loss = total_loss + sparse_result["total_loss"]
# Dense training (if dual mode)
if self.config.uses_dual_heads:
dense_targets = observation.get("dense_targets")
if dense_targets is not None:
dense_targets = dense_targets.to(self.device)
dense_result = self._train_step(
img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
)
output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
total_loss = total_loss + dense_result["total_loss"]
output_dict["total_loss"] = total_loss.item()
return total_loss, output_dict
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
"""Compute cross-entropy loss for stage classification."""
_, _, num_stages = stage_logits.shape
stage_logits_flat = stage_logits.reshape(-1, num_stages)
# Clamp target stage indices to valid range [0, num_stages-1]
target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
return F.cross_entropy(stage_logits_flat, target_stages_flat)
+518
View File
@@ -0,0 +1,518 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SARM Processor for encoding images/text and generating stage+tau targets."""
import random
from typing import Any
import numpy as np
import pandas as pd
import torch
from faker import Faker
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
find_stage_and_tau,
pad_state_to_max_dim,
)
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
RenameObservationsProcessorStep,
)
from lerobot.processor.converters import (
from_tensor_to_numpy,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
"""ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""
def __init__(
self,
config: SARMConfig,
image_key: str | None = None,
dataset_meta=None,
dataset_stats: dict | None = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.annotation_mode = config.annotation_mode
# Helper to create temporal proportions dict
def make_props_dict(names, props):
return dict(zip(names, props, strict=True)) if names and props else None
# Sparse annotations (always needed)
self.sparse_temporal_proportions = make_props_dict(
config.sparse_subtask_names, config.sparse_temporal_proportions
)
self.sparse_subtask_names = config.sparse_subtask_names
# Dense annotations (only for dual mode)
self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
self.dense_temporal_proportions = (
make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
if config.uses_dual_heads
else None
)
self.device = torch.device(
self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
)
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
self.fake = Faker()
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index."""
if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
return episode_indices
def _generate_perturbed_task(self) -> str:
"""Generate a random perturbed task string for language perturbation."""
num_words = random.randint(1, 5)
verb = random.choice(self.verbs)
phrase = " ".join([verb] + self.fake.words(nb=num_words))
return phrase
def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
"""Get global subtask names and temporal proportions for an annotation type."""
if annotation_type == "dense":
return self.dense_subtask_names, self.dense_temporal_proportions
return self.sparse_subtask_names, self.sparse_temporal_proportions
def _load_episode_annotations(
self,
ep_idx: int,
episodes_df: pd.DataFrame | None,
annotation_type: str,
global_names: list[str],
) -> tuple[list | None, list | None, list | None]:
"""Load subtask annotations for an episode from DataFrame."""
# Single-stage mode: (linear progress 0→1)
if episodes_df is None or len(global_names) == 1:
return None, None, None
# Resolve column name with fallback
def col(suffix):
prefixed = f"{annotation_type}_{suffix}"
return prefixed if prefixed in episodes_df.columns else suffix
col_names = col("subtask_names")
if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
return None, None, None
subtask_names = episodes_df.loc[ep_idx, col_names]
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None, None
return (
subtask_names,
episodes_df.loc[ep_idx, col("subtask_start_frames")],
episodes_df.loc[ep_idx, col("subtask_end_frames")],
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Encode images, text, and normalize states in the transition.
Implements SARM training data preparation:
- Applies language perturbation (20% probability)
- Applies rewind augmentation (80% probability)
- Generates stage+tau targets for all frames
- Outputs lengths tensor for valid sequence masking
"""
new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
frame_index = comp_data.get("index")
episode_index = comp_data.get("episode_index")
if frame_index is None:
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
image = observation.get(self.image_key)
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
# If 4D (T, C, H, W) from delta_timestamps, add batch dim
# If 3D (C, H, W) single frame, add batch and time dims
if image.ndim == 4:
image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W)
elif image.ndim == 3:
image = image[np.newaxis, np.newaxis, ...] # (C, H, W) -> (1, 1, C, H, W)
batch_size = image.shape[0]
total_frames = image.shape[1] # Should be 13: 9 obs + 4 rewind placeholders
n_obs_steps = self.config.n_obs_steps
max_rewind_steps = self.config.max_rewind_steps
n_obs_frames = 1 + n_obs_steps # 9 observation frames (including current)
# Rewind augmentation
rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
apply_rewind = self.training and random.random() < self.config.rewind_probability
if apply_rewind and self.dataset_meta is not None:
for b_idx, (ep_idx, frame_idx) in enumerate(
zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
):
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
rewind_step, _ = apply_rewind_augmentation(
frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
)
rewind_steps[b_idx] = rewind_step
# Compute valid lengths: n_obs_frames + rewind_steps
lengths = n_obs_frames + rewind_steps # (B,)
# Apply rewind masking to images
# For frames beyond valid length, we mask with zeros (or copy last valid frame)
for b_idx in range(batch_size):
valid_len = lengths[b_idx].item()
if valid_len < total_frames:
image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
# Encode images with CLIP
video_features = self._encode_images_batch(image)
observation["video_features"] = video_features
state_key = self.config.state_key
state_data = observation.get(state_key)
if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D)
elif state_tensor.ndim == 1:
state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D)
# Apply same rewind masking to state
for b_idx in range(batch_size):
valid_len = lengths[b_idx].item()
if valid_len < state_tensor.shape[1]:
state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
task = comp_data.get("task")
if isinstance(task, list):
task = task[0] if task else ""
# Apply language perturbation during training (20% probability)
# When perturbed, targets will be zeroed to train model to output low values for irrelevant text
apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
if apply_perturbation:
task = self._generate_perturbed_task()
# Encode text with CLIP
observation["text_features"] = self._encode_text_clip(task, batch_size)
# Store lengths for model
observation["lengths"] = lengths
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:
if apply_perturbation:
# Zero targets when language is perturbed
sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
else:
sparse_targets = self._compute_batch_targets(
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
)
observation["sparse_targets"] = sparse_targets
# Generate dense targets (for dual mode)
if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
if apply_perturbation:
# Zero targets when language is perturbed
dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
else:
dense_targets = self._compute_batch_targets(
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
)
observation["dense_targets"] = dense_targets
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
def _compute_batch_targets(
self,
frame_indices: np.ndarray,
episode_indices: np.ndarray,
lengths: torch.Tensor,
rewind_steps: torch.Tensor,
episodes_df: pd.DataFrame | None,
annotation_type: str,
) -> torch.Tensor:
"""Compute stage+tau targets for a batch of samples."""
batch_size = len(frame_indices)
n_obs_steps = self.config.n_obs_steps
max_rewind_steps = self.config.max_rewind_steps
total_frames = 1 + n_obs_steps + max_rewind_steps
frame_gap = self.config.frame_gap
global_names, temporal_props = self._get_annotation_config(annotation_type)
targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
for b_idx in range(batch_size):
ep_idx = int(episode_indices[b_idx])
frame_idx = int(frame_indices[b_idx])
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
ep_length = ep_end - ep_start
subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
ep_idx, episodes_df, annotation_type, global_names
)
# Compute observation frame indices
obs_indices, _ = compute_absolute_indices(
frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
)
obs_indices = obs_indices.tolist()
# Compute targets for observation frames
for t_idx, abs_idx in enumerate(obs_indices):
rel_frame = abs_idx - ep_start
targets[b_idx, t_idx] = find_stage_and_tau(
rel_frame,
ep_length,
subtask_names,
subtask_start_frames,
subtask_end_frames,
global_names,
temporal_props,
return_combined=True,
)
# Compute targets for rewind frames (if any)
rewind_step = rewind_steps[b_idx].item()
if rewind_step > 0:
_, rewind_indices = apply_rewind_augmentation(
frame_idx,
ep_start,
n_obs_steps,
max_rewind_steps,
frame_gap=frame_gap,
rewind_step=rewind_step,
)
for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
rel_frame = max(0, abs_idx - ep_start)
targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
rel_frame,
ep_length,
subtask_names,
subtask_start_frames,
subtask_end_frames,
global_names,
temporal_props,
return_combined=True,
)
return targets
@property
def training(self) -> bool:
return getattr(self, "_training_mode", True)
def train(self, mode: bool = True):
"""Set training mode for augmentation decisions."""
self._training_mode = mode
return self
def eval(self):
"""Set evaluation mode (disable augmentations)."""
return self.train(False)
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
"""Encode a batch of images using CLIP.
Args:
images: Batched images with shape: (B, T, C, H, W)
Returns:
Encoded feature vectors with shape (B, T, 512)
"""
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
num_frames = images.shape[0]
images_list = []
for i in range(num_frames):
img = images[i]
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
img = img.transpose(1, 2, 0)
# Handle single channel
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
images_list.append(Image.fromarray(img))
all_embeddings = []
for i in range(0, num_frames, self.config.clip_batch_size):
batch_imgs = images_list[i : i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
all_embeddings.append(embeddings)
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings
@torch.no_grad()
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Task description text to encode
batch_size: Batch size to replicate for
Returns:
Encoded text features with shape (B, 512)
"""
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
)
features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
)
features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
)
return features
def make_sarm_pre_post_processors(
config: SARMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta=None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create pre-processor and post-processor pipelines for SARM."""
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
RenameObservationsProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+295
View File
@@ -0,0 +1,295 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
def find_stage_and_tau(
current_frame: int,
episode_length: int,
subtask_names: list | None,
subtask_start_frames: list | None,
subtask_end_frames: list | None,
global_subtask_names: list,
temporal_proportions: dict,
return_combined: bool = False,
) -> tuple[int, float] | float:
"""Find stage and within-stage progress (tau) for a frame.
Args:
current_frame: Frame index relative to episode start
episode_length: Total frames in episode
subtask_names: Subtask names for this episode (None for single_stage)
subtask_start_frames: Subtask start frames
subtask_end_frames: Subtask end frames
global_subtask_names: Global list of all subtask names
temporal_proportions: Dict of temporal proportions
return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple
Returns:
Float (stage.tau) if return_combined, else (stage_idx, tau) tuple
"""
stage_idx, tau = 0, 0.0
num_stages = len(global_subtask_names)
# Single-stage mode: linear progress from 0 to 1
if num_stages == 1:
tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
elif subtask_names is None:
pass # stage_idx=0, tau=0.0
elif current_frame < subtask_start_frames[0]:
pass # Before first subtask: stage_idx=0, tau=0.0
elif current_frame > subtask_end_frames[-1]:
stage_idx, tau = num_stages - 1, 0.999 # After last subtask
else:
# Find which subtask this frame belongs to
found = False
for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
if start <= current_frame <= end:
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
tau = compute_tau(current_frame, start, end)
found = True
break
# Frame between subtasks - use previous subtask's end state
if not found:
for j in range(len(subtask_names) - 1):
if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
tau = 1.0
break
if return_combined:
# Clamp to avoid overflow at end
if stage_idx >= num_stages - 1 and tau >= 1.0:
return num_stages - 1 + 0.999
return stage_idx + tau
return stage_idx, tau
def compute_absolute_indices(
frame_idx: int,
ep_start: int,
ep_end: int,
n_obs_steps: int,
frame_gap: int = 30,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute absolute frame indices with clamping for bidirectional observation sequence.
Bidirectional sampling centered on target frame:
- Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
- Current: [0] (1 frame)
- After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
- Total: n_obs_steps + 1 frames
Out-of-bounds frames are clamped (duplicated from boundary).
Args:
frame_idx: Target frame index (center frame of sequence)
ep_start: Episode start index
ep_end: Episode end index (exclusive)
n_obs_steps: Number of observation steps (must be even for symmetric sampling)
frame_gap: Gap between observation frames
Returns:
Tuple of (indices, out_of_bounds_flags)
"""
half_steps = n_obs_steps // 2
# Bidirectional deltas: past + current + future
past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
delta_indices = past_deltas + [0] + future_deltas
frames = []
out_of_bounds = []
for delta in delta_indices:
target_idx = frame_idx + delta
# Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
frames.append(clamped_idx)
# Flag as out-of-bounds if clamping occurred
out_of_bounds.append(1 if target_idx != clamped_idx else 0)
return torch.tensor(frames), torch.tensor(out_of_bounds)
def apply_rewind_augmentation(
frame_idx: int,
ep_start: int,
n_obs_steps: int,
max_rewind_steps: int,
frame_gap: int = 30,
rewind_step: int | None = None,
) -> tuple[int, list[int]]:
"""
Generate rewind frame indices for temporal augmentation.
Rewind simulates going backwards through previously seen frames,
starting from before the earliest observation frame (for bidirectional sampling).
Appends reversed frames after the observation sequence.
Args:
frame_idx: Target frame index (center of bidirectional observation window)
ep_start: Episode start index
n_obs_steps: Number of observation steps
max_rewind_steps: Maximum rewind steps
frame_gap: Gap between frames
rewind_step: If provided, use this exact rewind step (for deterministic behavior).
If None, sample randomly.
Returns:
Tuple of (rewind_step, rewind_indices)
"""
# For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap
half_steps = n_obs_steps // 2
earliest_obs_frame = frame_idx - half_steps * frame_gap
# Required history: frames before earliest observation frame
if earliest_obs_frame <= ep_start:
return 0, [] # No history before observation window
# Max valid rewind steps based on available history before earliest obs frame
available_history = earliest_obs_frame - ep_start
max_valid_step = available_history // frame_gap
max_rewind = min(max_rewind_steps, max(0, max_valid_step))
if max_rewind <= 0:
return 0, []
# Sample rewind steps if not provided
rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)
if rewind_step == 0:
return 0, []
# Generate rewind indices going backwards from earliest obs frame
# rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back
rewind_indices = []
for i in range(1, rewind_step + 1):
idx = earliest_obs_frame - i * frame_gap
idx = max(ep_start, idx) # Clamp to episode start
rewind_indices.append(idx)
return rewind_step, rewind_indices
def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
"""Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
duration = subtask_end - subtask_start
if duration <= 0:
return 1.0
return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
current_dim = state.shape[-1]
if current_dim >= max_state_dim:
return state[..., :max_state_dim] # Truncate if larger
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode="constant", value=0)
def temporal_proportions_to_breakpoints(
temporal_proportions: dict[str, float] | list[float] | None,
subtask_names: list[str] | None = None,
) -> list[float] | None:
"""Convert temporal proportions to cumulative breakpoints for normalization."""
if temporal_proportions is None:
return None
if isinstance(temporal_proportions, dict):
if subtask_names is not None:
proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
else:
proportions = list(temporal_proportions.values())
else:
proportions = list(temporal_proportions)
total = sum(proportions)
if total > 0 and abs(total - 1.0) > 1e-6:
proportions = [p / total for p in proportions]
breakpoints = [0.0]
cumsum = 0.0
for prop in proportions:
cumsum += prop
breakpoints.append(cumsum)
breakpoints[-1] = 1.0
return breakpoints
def normalize_stage_tau(
x: float | torch.Tensor,
num_stages: int | None = None,
breakpoints: list[float] | None = None,
temporal_proportions: dict[str, float] | list[float] | None = None,
subtask_names: list[str] | None = None,
) -> float | torch.Tensor:
"""
Normalize stage+tau reward to [0, 1] with custom breakpoints.
Maps stage index + within-stage tau to normalized progress [0, 1].
The breakpoints are designed to give appropriate weight to each stage
based on their importance in the task (using temporal proportions).
Priority: breakpoints > temporal_proportions > linear fallback
Args:
x: Raw reward value (stage index + tau) where stage [0, num_stages-1] and tau [0, 1)
num_stages: Number of stages (required if breakpoints/proportions not provided)
breakpoints: Optional custom breakpoints list of length num_stages + 1.
temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
subtask_names: Optional ordered list of subtask names (for dict proportions)
Returns:
Normalized progress value [0, 1]
"""
if breakpoints is not None:
num_stages = len(breakpoints) - 1
elif temporal_proportions is not None:
breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
num_stages = len(breakpoints) - 1
elif num_stages is not None:
breakpoints = [i / num_stages for i in range(num_stages + 1)]
else:
raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")
if isinstance(x, torch.Tensor):
result = torch.zeros_like(x)
for i in range(num_stages):
mask = (x >= i) & (x < i + 1)
tau_in_stage = x - i
result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
result[x >= num_stages] = 1.0
return result.clamp(0.0, 1.0)
else:
if x < 0:
return 0.0
if x >= num_stages:
return 1.0
stage = int(x)
tau = x - stage
return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
@@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
**kwargs,
):
"""
Args:
@@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy):
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
def forward(
self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss.
Args:
batch: Training batch containing observations and actions.
noise: Optional noise tensor for flow matching.
time: Optional time tensor for flow matching.
reduction: How to reduce the loss. Options:
- "mean": Return scalar mean loss (default, backward compatible)
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
@@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss.item()
return loss, loss_dict
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss_dict["loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
@@ -527,6 +544,7 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device if self.config.device is not None else "auto",
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -783,18 +801,15 @@ class VLAFlowMatching(nn.Module):
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
num_steps = self.config.num_steps
dt = -1.0 / num_steps
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
@@ -818,15 +833,11 @@ class VLAFlowMatching(nn.Module):
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
@@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy):
def __init__(
self,
config: TDMPCConfig,
**kwargs,
):
"""
Args:
+10 -1
View File
@@ -231,11 +231,20 @@ def validate_visual_features_consistency(
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
Validation passes if EITHER:
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
if not provided_visuals.issubset(expected_visuals):
# Accept if either direction is a subset
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
if not (policy_subset_of_dataset or dataset_subset_of_policy):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
@@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
**kwargs,
):
"""
Args:
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_walloss_README.md
+19
View File
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_wall_x import WallXConfig
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
@@ -0,0 +1,165 @@
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("wall_x")
@dataclass
class WallXConfig(PreTrainedConfig):
"""
Configuration class for Wall-X policy.
Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
It supports cross-embodiment robotic control through unified action representations.
This config supports multi-modal learning with vision, language, and action data.
"""
# ==================== Input / Output Structure ====================
n_obs_steps: int = 1
chunk_size: int = 32 # action_horizon in wall-x
n_action_steps: int = 32
# Action dimension - wall-x uses 20
max_action_dim: int = 20
max_state_dim: int = 20 # For proprioception
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# ==================== Action Prediction ====================
# Pretrained model paths
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Tokenizer settings
action_tokenizer_path: str | None = "physical-intelligence/fast"
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"
# Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
# NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
attn_implementation: str = "eager"
# ==================== Optimizer Presets ====================
optimizer_lr: float = 2e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 1000
scheduler_decay_steps: int = 100000
scheduler_decay_lr: float = 1e-6
def __post_init__(self):
super().__post_init__()
# Input validation
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
# Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast":
self.use_fast_tokenizer = True
elif self.prediction_mode == "diffusion":
self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else:
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Wall-X policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
else:
action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
f"Either reduce action dimension or increase max_action_dim in config."
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
+41
View File
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wall-X Constants and Configuration Data.
"""
CAMERA_NAME_MAPPING = {
"face_view": "front view",
"left_wrist_view": "left wrist view",
"right_wrist_view": "right wrist view",
"move1_view": "move view",
"move2_view": "move view",
"wall_view": "wall view",
"top_view": "top view",
}
RESOLUTION = 256
# Parameters for preprocessing
MAX_PIXELS = 16384 * 28 * 28
MIN_PIXELS = 4 * 28 * 28
IMAGE_FACTOR = 28
PRIORITY_ORDER = None
GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5"
TOKENIZER_MAX_LENGTH = 768
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors(
config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the Wall-X policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations
2. Adding a batch dimension
4. Normalizing input and output features based on dataset statistics
5. Moving all data to the specified device
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output actions to their original scale
2. Moving data to the CPU
Args:
config: The configuration object for the Wall-X policy
dataset_stats: A dictionary of statistics for normalization
Returns:
A tuple containing the configured pre-processor and post-processor pipelines
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
WallXTaskProcessor(), # Process task description
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@ProcessorStepRegistry.register(name="wall_x_task_processor")
class WallXTaskProcessor(ComplementaryDataProcessorStep):
"""
A processor step that ensures the task description is properly formatted for Wall-X.
This step handles task preprocessing similar to Qwen-VL requirements.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
# Provide default task if none specified
complementary_data["task"] = "Execute the robot action."
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: ensure proper formatting
if not task.endswith("."):
new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@@ -0,0 +1,248 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=32,
hidden_size=3584,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
tokens_per_second=4,
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 29568):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 80):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2_5_VL`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
num_experts=4,
experts=None,
dof_config=None,
noise_scheduler=None,
dim_inputs=(1536, 1536),
attention_moe=False,
mlp_moe=False,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.layer_types = ["dense"] * num_hidden_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
self.num_experts = num_experts
self.experts = experts
self.dof_config = dof_config
self.noise_scheduler = noise_scheduler
self.dim_inputs = tuple(dim_inputs)
self.attention_moe = attention_moe
self.mlp_moe = mlp_moe
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
@property
def text_config(self):
return self
__all__ = ["Qwen2_5_VLConfig"]

Some files were not shown because too many files have changed in this diff Show More