Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ee24f64ae5 | |||
| 123b9f7851 |
@@ -12,83 +12,57 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "🚀 Issue / Bug / Request"
|
||||
description: Report a bug, suggest an improvement, or ask a technical question.
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### Thanks for contributing to LeRobot! 🙌
|
||||
Please choose the most relevant sections below. If this is a general "how-to" question, consider our [Discord](https://discord.gg/s3KuuzsPFb) for faster community support.
|
||||
|
||||
- type: dropdown
|
||||
id: issue-type
|
||||
attributes:
|
||||
label: Ticket Type
|
||||
description: What kind of ticket are you opening?
|
||||
options:
|
||||
- "🐛 Bug Report (Something isn't working)"
|
||||
- "💡 Feature Request / Improvement"
|
||||
- "❓ Technical Question"
|
||||
- "🧹 Maintenance / Documentation"
|
||||
validations:
|
||||
required: true
|
||||
Thanks for taking the time to submit a bug report! 🐛
|
||||
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: Environment & System Info
|
||||
description: |
|
||||
For bugs or technical questions, please run `lerobot-info` and paste the output.
|
||||
(Optional for feature requests).
|
||||
label: System Info
|
||||
description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below.
|
||||
render: Shell
|
||||
placeholder: lerobot version, OS, python version, etc.
|
||||
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "One of the scripts in the examples/ folder of LeRobot"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Description
|
||||
label: Reproduction
|
||||
description: |
|
||||
Provide a clear summary of the issue or your proposal.
|
||||
- **Bugs:** What is happening?
|
||||
- **Features:** What is the goal/use case?
|
||||
- **Questions:** What are you trying to achieve?
|
||||
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
Sharing error messages or stack traces could be useful as well!
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
|
||||
|
||||
placeholder: |
|
||||
A clear and concise description of the issue or suggestion.
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: context-repro
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Context & Reproduction
|
||||
description: |
|
||||
Provide a code snippet, steps to reproduce a bug, or technical details about your proposal.
|
||||
Please use code blocks for scripts and CLI commands.
|
||||
placeholder: |
|
||||
Steps to reproduce / Usage example:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant logs or stack trace
|
||||
description: If applicable, paste relevant error logs here.
|
||||
render: Shell
|
||||
|
||||
- type: checkboxes
|
||||
id: extras
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have searched existing tickets to ensure this isn't a duplicate.
|
||||
- label: I am using the latest version of the `main` branch.
|
||||
- label: I have verified this is not an environment-specific problem.
|
||||
|
||||
- type: textarea
|
||||
id: workaround
|
||||
attributes:
|
||||
label: Additional Info / Workarounds
|
||||
description: Anything else we should know? If you have a workaround, please share it!
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
|
||||
@@ -1,54 +1,41 @@
|
||||
## Title
|
||||
## What this does
|
||||
|
||||
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
|
||||
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
|
||||
|
||||
## Type / Scope
|
||||
Examples:
|
||||
| Title | Label |
|
||||
|----------------------|-----------------|
|
||||
| Fixes #[issue] | (🐛 Bug) |
|
||||
| Adds new dataset | (🗃️ Dataset) |
|
||||
| Optimizes something | (⚡️ Performance) |
|
||||
|
||||
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
|
||||
- **Scope**: (optional — name of module or package affected)
|
||||
## How it was tested
|
||||
|
||||
## Summary / Motivation
|
||||
Explain/show how you tested your changes.
|
||||
|
||||
- One-paragraph description of what changes and why.
|
||||
- Why this change is needed and any trade-offs or design notes.
|
||||
Examples:
|
||||
|
||||
## Related issues
|
||||
- Added `test_something` in `tests/test_stuff.py`.
|
||||
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
|
||||
- Optimized `some_function`, it now runs X times faster than previously.
|
||||
|
||||
- Fixes / Closes: # (if any)
|
||||
- Related: # (if any)
|
||||
## How to checkout & try? (for the reviewer)
|
||||
|
||||
## What changed
|
||||
Provide a simple way for the reviewer to try out your changes.
|
||||
|
||||
- Short, concrete bullets of the modifications (files/behaviour).
|
||||
- Short note if this introduces breaking changes and migration steps.
|
||||
Examples:
|
||||
|
||||
## How was this tested
|
||||
```bash
|
||||
pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
|
||||
- Tests added: list new tests or test files.
|
||||
- Manual checks / dataset runs performed.
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## How to run locally (reviewer)
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
|
||||
- Run the relevant tests:
|
||||
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
||||
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
|
||||
|
||||
```bash
|
||||
pytest -q tests/ -k <keyword>
|
||||
```
|
||||
|
||||
- Run a quick example or CLI (if applicable):
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## Checklist (required before merge)
|
||||
|
||||
- [ ] Linting/formatting run (`pre-commit run -a`)
|
||||
- [ ] All tests pass locally (`pytest`)
|
||||
- [ ] Documentation updated
|
||||
- [ ] CI is green
|
||||
|
||||
## Reviewer notes
|
||||
|
||||
- Anything the reviewer should focus on (performance, edge-cases, specific files) or general notes.
|
||||
- Anyone in the community is free to review the PR.
|
||||
**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
CI:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '.github/**'
|
||||
- 'docker/**'
|
||||
|
||||
github_actions:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '.github/**'
|
||||
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*.md'
|
||||
- '**/*.mdx'
|
||||
- 'docs/**'
|
||||
|
||||
examples:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'examples/**'
|
||||
|
||||
tests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
sensors:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/cameras/**'
|
||||
|
||||
configuration:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/configs/**'
|
||||
|
||||
dataset:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/datasets/**'
|
||||
|
||||
evaluation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/envs/**'
|
||||
|
||||
robots:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- 'src/lerobot/teleoperators/**'
|
||||
- 'src/lerobot/robots/**'
|
||||
- 'src/lerobot/motors/**'
|
||||
|
||||
policies:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/policies/**'
|
||||
|
||||
processor:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/processor/**'
|
||||
@@ -33,9 +33,6 @@ on:
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
@@ -46,7 +43,7 @@ jobs:
|
||||
build_main_docs:
|
||||
name: Build Main Docs
|
||||
if: >
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch' || github.event_name == 'release') &&
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -54,7 +51,7 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
additional_args: --not_python_module
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
@@ -127,7 +127,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow automatically labels issues based on their content.
|
||||
name: Issue Labeler
|
||||
on:
|
||||
# Trigger on new issues and edits to existing issues
|
||||
issues:
|
||||
types: [opened, edited]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
label-issue:
|
||||
name: Auto Label Issue
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
steps:
|
||||
- uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
// Setup Input Text
|
||||
const body = (context.payload.issue.body || '');
|
||||
const title = (context.payload.issue.title || '');
|
||||
const cleanBody = body.replace(/```[\s\S]*?```/g, '');
|
||||
const text = `${title}\n${cleanBody}`.toLowerCase();
|
||||
const labelsToAdd = new Set();
|
||||
const matches = (re) => re.test(text);
|
||||
|
||||
// Keyword Heuristics
|
||||
|
||||
if (matches(/\b(bug|error|crash|exception)\b/i)) labelsToAdd.add('bug');
|
||||
if (matches(/\b(new feature|enhancement|improvement|proposal|feature request)\b/i)) labelsToAdd.add('enhancement');
|
||||
if (matches(/\b(question|how to|clarify|explain|how do i|help me|question about)\b/i)) labelsToAdd.add('question');
|
||||
if (matches(/\b(documentation|docs?|readme|tutorial|wiki|typo|docstring)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(example|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
|
||||
if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset');
|
||||
if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation');
|
||||
if (matches(/\b(train|training|optimizer|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
|
||||
if (matches(/\b(rerun|plot|render|rendering|visualizer)/i)) labelsToAdd.add('visualization');
|
||||
if (matches(/\b(cameras?|opencv|realsense|lidars?|sensors?|imus?|microphones?|rgbd|encoders?)\b/i)) labelsToAdd.add('sensors');
|
||||
if (matches(/\b(urdf|actuators?|calibration|end-effector|kinematics)\b/i)) labelsToAdd.add('robots');
|
||||
if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators');
|
||||
if (matches(/\b(policy|policies|model?)\b/i)) labelsToAdd.add('policies');
|
||||
if (matches(/\b(processor|pipeline|preprocessor|postprocessor)s?\b/i)) labelsToAdd.add('processor');
|
||||
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmarks?)\b/i)) labelsToAdd.add('evaluation');
|
||||
if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests');
|
||||
if (matches(/\b(ci|github actions?|github workflows?|gha|docker|pypi)\b/i)) labelsToAdd.add('CI');
|
||||
if (matches(/\b(perf|latency|throughput|fps|speed|performance|slow|fast|slower|faster|memory usage)\b/i)) labelsToAdd.add('performance');
|
||||
if (matches(/\b(dependency|dependencies|pip|install error|importerror|package not found|pyproject)\b/i)) labelsToAdd.add('dependencies');
|
||||
if (matches(/\b(configuration|config|arguments?|input feature|dracuss)\b/i)) labelsToAdd.add('configuration');
|
||||
|
||||
// Apply Labels
|
||||
const labels = Array.from(labelsToAdd).filter(Boolean);
|
||||
|
||||
if (labels.length > 0) {
|
||||
console.log(`Adding labels: ${labels.join(', ')}`);
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels,
|
||||
});
|
||||
}
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -87,7 +87,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow labels pull requests based on the files that were changed.
|
||||
name: Pull Request Labeler
|
||||
|
||||
on:
|
||||
# Allows labeling pull requests when they are opened or updated
|
||||
# zizmor: ignore[dangerous-triggers] Needed to label PRs from forks
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
name: Label PR
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot' && !github.event.pull_request.draft
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
sync-labels: true # Removes labels if files are removed from the PR
|
||||
@@ -43,12 +43,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
@@ -38,12 +38,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -177,3 +177,4 @@ jobs:
|
||||
|
||||
# TODO(Steven): Publish draft/pre-release and to test pypi weekly
|
||||
# TODO(Steven): Separate build and publish job
|
||||
# TODO(Steven): Tag documentation with the same version as the package
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
|
||||
uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
@@ -101,7 +101,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -87,7 +87,7 @@ repos:
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.1
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
|
||||
@@ -52,7 +52,7 @@ decisions when appropriate.
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
Examples of representing our community include using an official email address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
@@ -60,7 +60,7 @@ representative at an online or offline event.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
feedback@huggingface.co.
|
||||
[feedback@huggingface.co](mailto:feedback@huggingface.co).
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
|
||||
@@ -1,83 +1,323 @@
|
||||
# How to contribute to 🤗 LeRobot
|
||||
# How to contribute to 🤗 LeRobot?
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
is thus not the only way to help the community. Answering questions, helping
|
||||
others, reaching out and improving the documentations are immensely valuable to
|
||||
the community.
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
It also helps us if you spread the word: reference the library from blog posts
|
||||
on the awesome projects it made possible, shout out on Twitter when it has
|
||||
helped you, or simply ⭐️ the repo to say "thank you".
|
||||
|
||||
## Ways to Contribute
|
||||
Whichever way you choose to contribute, please be mindful to respect our
|
||||
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
You can contribute in many ways:
|
||||
## You can contribute in so many ways!
|
||||
|
||||
- **Fixing issues:** Resolve bugs or improve existing code.
|
||||
- **New features:** Develop new features.
|
||||
- **Extend:** Implement new models/policies, robots, or simulation environments and upload datasets to the Hugging Face Hub.
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
Some of the ways you can contribute to 🤗 LeRobot:
|
||||
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
- Fixing outstanding issues with the existing code.
|
||||
- Implementing new models, datasets or simulation environments.
|
||||
- Contributing to the examples or to the documentation.
|
||||
- Submitting issues related to bugs or desired new features.
|
||||
|
||||
## Development Setup
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
|
||||
|
||||
To contribute code, you need to set up a development environment.
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
### 1. Fork and Clone
|
||||
## Submitting a new issue or feature request
|
||||
|
||||
Fork the repository on GitHub, then clone your fork:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/<your-handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
### 2. Environment Installation
|
||||
|
||||
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
|
||||
|
||||
## Running Tests & Quality Checks
|
||||
|
||||
### Code Style (Pre-commit)
|
||||
|
||||
Install `pre-commit` hooks to run checks automatically before you commit:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
To run checks manually on all files:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use `pytest`. First, ensure you have test artifacts by installing **git-lfs**:
|
||||
Do your best to follow these guidelines when submitting an issue or a feature
|
||||
request. It will make it easier for us to come back to you quickly and with good
|
||||
feedback.
|
||||
|
||||
### Did you find a bug?
|
||||
|
||||
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
|
||||
the problems they encounter. So thank you for reporting an issue.
|
||||
|
||||
First, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on Github under Issues).
|
||||
|
||||
Did not find it? :( So we can act quickly on it, please follow these steps:
|
||||
|
||||
- Include your **OS type and version**, the versions of **Python** and **PyTorch**.
|
||||
- A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
- The full traceback if an exception is raised.
|
||||
- Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
### Do you want a new feature?
|
||||
|
||||
A good feature request addresses the following points:
|
||||
|
||||
1. Motivation first:
|
||||
|
||||
- Is it related to a problem/frustration with the library? If so, please explain
|
||||
why. Providing a code snippet that demonstrates the problem is best.
|
||||
- Is it related to something you would need for a project? We'd love to hear
|
||||
about it!
|
||||
- Is it something you worked on and think could benefit the community?
|
||||
Awesome! Tell us what problem it solved for you.
|
||||
|
||||
2. Write a _paragraph_ describing the feature.
|
||||
3. Provide a **code snippet** that demonstrates its future use.
|
||||
4. In case this is related to a paper, please attach a link.
|
||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you
|
||||
post it.
|
||||
|
||||
## Adding new policies, datasets or environments
|
||||
|
||||
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/lerobot) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
git clone git@github.com:<your Github handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
|
||||
```bash
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
6. Follow our style.
|
||||
|
||||
`lerobot` relies on `ruff` to format its source code
|
||||
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
|
||||
automatically as Git commit hooks.
|
||||
|
||||
Install `pre-commit` hooks:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
You can run these hooks whenever you need on staged files with:
|
||||
|
||||
```bash
|
||||
pre-commit
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
7. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
8. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
Run the full suite (this may require extras installed):
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
pytest -sv ./tests
|
||||
python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
Or run a specific test file during development:
|
||||
|
||||
```bash
|
||||
pytest -sv tests/test_specific_feature.py
|
||||
```
|
||||
|
||||
## Submitting Issues & Pull Requests
|
||||
|
||||
Use the templates for required fields and examples.
|
||||
|
||||
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
|
||||
|
||||
One member of the LeRobot team will then review your contribution.
|
||||
|
||||
Thank you for contributing to LeRobot!
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
<p align="center">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="./media/readme/lerobot-logo-thumbnail.png" width="100%">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lerobot-logo-thumbnail.png" width="100%">
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -10,130 +12,323 @@
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
<!-- [](https://codecov.io/gh/huggingface/lerobot) -->
|
||||
|
||||
</div>
|
||||
|
||||
**LeRobot** aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry so that everyone can contribute to and benefit from shared datasets and pretrained models.
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
Build Your Own HopeJR Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
🤗 A hardware-agnostic, Python-native interface that standardizes control across diverse platforms, from low-cost arms (SO-100) to humanoids.
|
||||
<div align="center">
|
||||
<img
|
||||
src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/hope_jr/hopejr.png"
|
||||
alt="HopeJR robot"
|
||||
title="HopeJR robot"
|
||||
width="60%"
|
||||
/>
|
||||
|
||||
🤗 A standardized, scalable LeRobotDataset format (Parquet + MP4 or images) hosted on the Hugging Face Hub, enabling efficient storage, streaming and visualization of massive robotic datasets.
|
||||
<p><strong>Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!</strong></p>
|
||||
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
|
||||
<p>Perfect for advanced manipulation tasks! 🤖</p>
|
||||
|
||||
🤗 State-of-the-art policies that have been shown to transfer to the real-world ready for training and deployment.
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
See the full HopeJR tutorial here.</a></p>
|
||||
</div>
|
||||
|
||||
🤗 Comprehensive support for the open-source ecosystem to democratize physical AI.
|
||||
<br/>
|
||||
|
||||
## Quick Start
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
Build Your Own SO-101 Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
LeRobot can be installed directly from PyPI.
|
||||
<div align="center">
|
||||
<table>
|
||||
<tr>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101.webp" alt="SO-101 follower arm" title="SO-101 follower arm" width="90%"/></td>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101-leader.webp" alt="SO-101 leader arm" title="SO-101 leader arm" width="90%"/></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<p><strong>Meet the updated SO100, the SO-101 – Just €114 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
See the full SO-101 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://huggingface.co/docs/lerobot/lekiwi">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lekiwi/kiwi.webp" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
|
||||
<h3 align="center">
|
||||
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
#### Examples of pretrained models on simulation environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Installation
|
||||
|
||||
LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
>
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
>
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
>
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
### Install LeRobot 🤗
|
||||
|
||||
#### From Source
|
||||
|
||||
First, clone the repository and navigate into the directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
|
||||
```bash
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
lerobot-info
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> For detailed installation guide, please see the [Installation Documentation](https://huggingface.co/docs/lerobot/installation).
|
||||
_This installs only the default dependencies._
|
||||
|
||||
## Robots & Control
|
||||
|
||||
<div align="center">
|
||||
<img src="./media/readme/robots_control_video.webp" width="640px" alt="Reachy 2 Demo">
|
||||
</div>
|
||||
|
||||
LeRobot provides a unified `Robot` class interface that decouples control logic from hardware specifics. It supports a wide range of robots and teleoperation devices.
|
||||
|
||||
```python
|
||||
from lerobot.robots.myrobot import MyRobot
|
||||
|
||||
# Connect to a robot
|
||||
robot = MyRobot(config=...)
|
||||
robot.connect()
|
||||
|
||||
# Read observation and send action
|
||||
obs = robot.get_observation()
|
||||
action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
For detailed hardware setup guides, see the [Hardware Documentation](https://huggingface.co/docs/lerobot/integrate_hardware).
|
||||
|
||||
## LeRobot Dataset
|
||||
|
||||
To solve the data fragmentation problem in robotics, we utilize the **LeRobotDataset** format.
|
||||
|
||||
- **Structure:** Synchronized MP4 videos (or images) for vision and Parquet files for state/action data.
|
||||
- **HF Hub Integration:** Explore thousands of robotics datasets on the [Hugging Face Hub](https://huggingface.co/lerobot).
|
||||
- **Tools:** Seamlessly delete episodes, split by indices/fractions, add/remove features, and merge multiple datasets.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset from the Hub
|
||||
dataset = LeRobotDataset("lerobot/aloha_mobile_cabinet")
|
||||
|
||||
# Access data (automatically handles video decoding)
|
||||
episode_index=0
|
||||
print(f"{dataset[episode_index]['action'].shape=}\n")
|
||||
```
|
||||
|
||||
Learn more about it in the [LeRobotDataset Documentation](https://huggingface.co/docs/lerobot/lerobot-dataset-v3)
|
||||
|
||||
## SoTA Models
|
||||
|
||||
LeRobot implements state-of-the-art policies in pure PyTorch, covering Imitation Learning, Reinforcement Learning, and Vision-Language-Action (VLA) models, with more coming soon. It also provides you with the tools to instrument and inspect your training process.
|
||||
|
||||
<p align="center">
|
||||
<img alt="Gr00t Architecture" src="./media/readme/VLA_architecture.jpg" width="640px">
|
||||
</p>
|
||||
|
||||
Training a policy is as simple as running a script configuration:
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy=act \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
pip install 'lerobot[all]' # All available features
|
||||
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
|
||||
pip install 'lerobot[feetech]' # Feetech motor support
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
_Replace `[...]` with your desired features._
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
**Available Tags:**
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Inference & Evaluation
|
||||
### Weights & Biases
|
||||
|
||||
Evaluate your policies in simulation or on real hardware using the unified evaluation script. LeRobot supports standard benchmarks like **LIBERO**, **MetaWorld** and more to come.
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
|
||||
```bash
|
||||
# Evaluate a policy on the LIBERO benchmark
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.n_episodes=10
|
||||
wandb login
|
||||
```
|
||||
|
||||
Learn how to implement your own simulation environment or benchmark and distribute it from the HF Hub by following the [EnvHub Documentation](https://huggingface.co/docs/lerobot/envhub)
|
||||
(note: you will also need to enable WandB in the configuration. See below.)
|
||||
|
||||
## Resources
|
||||
### Visualize datasets
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
|
||||
|
||||
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
|
||||
|
||||
Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
|
||||
|
||||
### The `LeRobotDataset` format
|
||||
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
|
||||
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
|
||||
|
||||
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
|
||||
|
||||
```
|
||||
dataset attributes:
|
||||
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
|
||||
│ ├ observation.images.cam_high (VideoFrame):
|
||||
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
|
||||
│ ├ observation.state (list of float32): position of an arm joints (for instance)
|
||||
│ ... (more observations)
|
||||
│ ├ action (list of float32): goal position of an arm joints (for instance)
|
||||
│ ├ episode_index (int64): index of the episode for this sample
|
||||
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp (float32): timestamp in the episode
|
||||
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
|
||||
│ └ index (int64): general index in the whole dataset
|
||||
├ meta: a LeRobotDatasetMetadata object containing:
|
||||
│ ├ info: a dictionary of metadata on the dataset
|
||||
│ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
|
||||
│ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
|
||||
│ │ ├ features (dict): all features contained in the dataset with their shapes and types
|
||||
│ │ ├ total_episodes (int): total number of episodes in the dataset
|
||||
│ │ ├ total_frames (int): total number of frames in the dataset
|
||||
│ │ ├ robot_type (str): robot type used for recording
|
||||
│ │ ├ data_path (str): formattable string for the parquet files
|
||||
│ │ └ video_path (str): formattable string for the video files (if using videos)
|
||||
│ ├ episodes: a DataFrame containing episode metadata with columns:
|
||||
│ │ ├ episode_index (int): index of the episode
|
||||
│ │ ├ tasks (list): list of tasks for this episode
|
||||
│ │ ├ length (int): number of frames in this episode
|
||||
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
|
||||
│ │ └ dataset_to_index (int): end index of this episode in the dataset
|
||||
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
|
||||
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
|
||||
│ │ └ ...
|
||||
│ └ tasks: a DataFrame containing task information with task names as index and task_index as values
|
||||
├ root (Path): local directory where the dataset is stored
|
||||
├ image_transforms (Callable): optional image transformations to apply to visual modalities
|
||||
└ delta_timestamps (dict): optional delta timestamps for temporal queries
|
||||
```
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
|
||||
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
## Citation
|
||||
|
||||
If you use LeRobot in your research, please cite:
|
||||
If you want, you can cite this work with:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -144,14 +339,6 @@ If you use LeRobot in your research, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
## Contribute
|
||||
## Star History
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
<p align="center">
|
||||
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<sub>Built by the <a href="https://huggingface.co/lerobot">LeRobot</a> team at <a href="https://huggingface.co">Hugging Face</a> with ❤️</sub>
|
||||
</div>
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -41,13 +41,7 @@
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: walloss
|
||||
title: WALL-OSS
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
|
||||
@@ -201,8 +201,7 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -210,19 +209,12 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
)
|
||||
teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
@@ -251,9 +243,6 @@ init_rerun(session_name="recording")
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
@@ -262,9 +251,6 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
@@ -279,9 +265,6 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# WALL-OSS
|
||||
|
||||
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| ------------------ | ----------------------------------------------------- | --- |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zhai2025igniting,
|
||||
title = {Igniting VLMs Toward the Embodied Space},
|
||||
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
|
||||
journal = {arXiv preprint arXiv:2509.11766},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**.
|
||||
@@ -1,586 +0,0 @@
|
||||
# SARM: Stage-Aware Reward Modeling
|
||||
|
||||
SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).
|
||||
|
||||
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
|
||||
|
||||
## Why Reward Models?
|
||||
|
||||
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
|
||||
|
||||
## Overview
|
||||
|
||||
SARM has following features:
|
||||
|
||||
1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
|
||||
2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
|
||||
3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
|
||||
|
||||
SARM trains on a compact **stage+tau** target for each frame:
|
||||
|
||||
- **stage**: integer stage index `k ∈ {0, ..., K-1}`
|
||||
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
|
||||
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)
|
||||
|
||||
At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).
|
||||
|
||||
This matches **Formula (2)** from the paper:
|
||||
|
||||
```
|
||||
progress_t = P_{k-1} + α̅_k × τ_t
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time
|
||||
- `P_{k-1}` is cumulative prior (sum of previous subtask proportions)
|
||||
- `α̅_k` is the temporal proportion for subtask k
|
||||
|
||||
This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
|
||||
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)
|
||||
|
||||
At minimum, each training sample needs:
|
||||
|
||||
- `task` (string): task description
|
||||
- `policy.image_key` images and `policy.state_key` states from the dataset
|
||||
|
||||
---
|
||||
|
||||
## Annotation Modes
|
||||
|
||||
You can choose from **3 annotation modes** that determine how progress labels are computed:
|
||||
|
||||
| Mode | Annotations Required | Heads | Use Case |
|
||||
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
|
||||
| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
|
||||
| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
|
||||
| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |
|
||||
|
||||
### Mode Details
|
||||
|
||||
<hfoptions id="mode_explanation">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress
|
||||
- **Dense head**: Not used
|
||||
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available
|
||||
|
||||
## Set Up Your Environment
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install SARM dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[sarm]"
|
||||
```
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
|
||||
- **Dense head**: Multiple fine-grained stages from VLM annotations
|
||||
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.
|
||||
|
||||
- **Sparse head**: High-level stages from VLM annotations
|
||||
- **Dense head**: Fine-grained stages from VLM annotations
|
||||
- **Best for**: Complex multi-stage tasks where both granularities are useful
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Subtask Annotation
|
||||
|
||||
<hfoptions id="annotation_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--dense-only \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Annotation Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | ------------------------------------------------------------------------------- |
|
||||
| `--repo-id` | HuggingFace dataset repository ID |
|
||||
| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
|
||||
| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
|
||||
| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
|
||||
| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
|
||||
| `--num-workers` | Number of parallel GPU workers (default: 1) |
|
||||
| `--episodes` | Specific episode indices to annotate (default: all) |
|
||||
| `--skip-existing` | Skip episodes that already have annotations |
|
||||
| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
|
||||
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |
|
||||
|
||||
> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Verify Annotations
|
||||
|
||||
<hfoptions id="verify_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No verification needed!** Skip this step.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type dense \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type both \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This generates visualizations showing video frames with subtask boundaries overlaid and timeline of subtasks.
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | -------------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize existing annotations (no generation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |
|
||||
|
||||
**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Train SARM
|
||||
|
||||
<hfoptions id="train_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_single \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dense \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dual \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Multi-GPU Training
|
||||
|
||||
Add `accelerate launch --multi_gpu --num_processes=4` to use multiple GPUs for training.
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description | Default |
|
||||
| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
|
||||
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
|
||||
| `--policy.image_key` | Camera key for images | `observation.images.top` |
|
||||
| `--policy.state_key` | Key for joint states | `observation.state` |
|
||||
| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
|
||||
| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Visualize Predictions
|
||||
|
||||
Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.
|
||||
|
||||
<hfoptions id="viz_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode sparse \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode dense \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode both \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The visualization shows:
|
||||
|
||||
- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`)
|
||||
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
|
||||
- **Sample frames**: Key frames from the episode with progress/stage labels
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | --------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize predictions (no RABC computation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
|
||||
| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
|
||||
|
||||
---
|
||||
|
||||
## Step 5 (Optional): Train Policy with RA-BC
|
||||
|
||||
Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:
|
||||
|
||||
1. **Precompute progress values** for all frames using the trained SARM model
|
||||
2. **Train policy** with RA-BC weighting using the precomputed values
|
||||
|
||||
### How RA-BC Works
|
||||
|
||||
For each training sample, RA-BC computes the progress delta:
|
||||
|
||||
```
|
||||
r_i = φ(o_{t+Δ}) - φ(o_t)
|
||||
```
|
||||
|
||||
Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.
|
||||
|
||||
The weighting follows **Equations 8-9** from the paper:
|
||||
|
||||
- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
|
||||
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
|
||||
|
||||
### Step 5a: Compute SARM Progress Values
|
||||
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
--num-visualizations 5 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
This script:
|
||||
|
||||
- Processes all frames and computes progress values
|
||||
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
|
||||
- Generates visualizations of the first N episodes (default: 5)
|
||||
|
||||
**Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | -------------------------------------------------------------- | ---------- |
|
||||
| `--reward-model-path` | Path to trained SARM model | (required) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
|
||||
| `--device` | Device for inference | `cuda` |
|
||||
| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` |
|
||||
|
||||
**Output format** (`sarm_progress.parquet`):
|
||||
|
||||
| Column | Description |
|
||||
| ----------------- | ---------------------------------------------- |
|
||||
| `index` | Global frame index in dataset |
|
||||
| `episode_index` | Episode number |
|
||||
| `frame_index` | Local frame index within episode |
|
||||
| `progress_sparse` | Sparse head progress value [0, 1] |
|
||||
| `progress_dense` | Dense head progress value [0, 1] (if computed) |
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
The training script automatically:
|
||||
|
||||
- Loads the precomputed progress values from the parquet file
|
||||
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
|
||||
- Computes sample weights based on progress improvement
|
||||
- Applies weighted loss during training
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.
|
||||
|
||||
**How the weighting works:**
|
||||
|
||||
| Condition | Weight |
|
||||
| ------------------- | ----------------------- |
|
||||
| `delta > kappa` | 1.0 (hard threshold) |
|
||||
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
|
||||
| `delta < 0` | 0.0 (negative progress) |
|
||||
|
||||
**Diagnosing kappa issues:**
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
|
||||
If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
|
||||
|
||||
### Multi-GPU Training with RA-BC
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tips & Best Practices
|
||||
|
||||
### Choosing a Mode
|
||||
|
||||
- **Start with `single_stage`** for quick experiments - no annotation overhead
|
||||
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
|
||||
- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful
|
||||
|
||||
### Annotation Quality
|
||||
|
||||
1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
|
||||
2. **Verify with visualization**: Always check a few episodes before training
|
||||
3. **Consistent naming**: Use the same subtask names across all episodes
|
||||
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -163,41 +163,3 @@ lerobot-edit-dataset \
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
|
||||
# Dataset Visualization
|
||||
|
||||
## Online Visualization
|
||||
|
||||
When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at:
|
||||
https://huggingface.co/spaces/lerobot/visualize_dataset
|
||||
|
||||
## Local Visualization
|
||||
|
||||
You can also visualize episodes from a dataset locally using our command-line tool.
|
||||
|
||||
**From the Hugging Face Hub:**
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
**From a local folder:**
|
||||
Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
|
||||
|
||||
For advanced usage—including visualizing datasets stored on a remote server—run:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz --help
|
||||
```
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
# WALL-OSS
|
||||
|
||||
WALL-OSS is an open-source foundation model for embodied intelligence, proposed by the [XSquare Robot](https://x2robot.com/en/research/68bc2cde8497d7f238dde690) team in 2025. The LeRobot implementation is adapted from their open-source [WallX](https://github.com/X-Square-Robot/wall-x) repository.
|
||||
|
||||
X Square Robot’s WALL-OSS is now integrated into Hugging Face’s LeRobot ecosystem. This is an exciting collaborative project between the LeRobot and X Square Robot teams. You can now post-train, evaluate, and deploy WALL-OSS directly through LeRobot. With this, we’re aiming to make it easier for the open-source robotics community to customize and deploy WALL-OSS foundation models. Read and explore WALL-OSS [paper](https://arxiv.org/pdf/2509.11766) and [code](https://github.com/X-Square-Robot/wall-x).
|
||||
|
||||
## Model Overview
|
||||
|
||||
The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it.
|
||||
|
||||
Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include:
|
||||
|
||||
- **Embodied perception–enhanced multimodal pretraining**: Large-scale training on unified vision–language–action data to strengthen spatial, causal, and manipulation understanding.
|
||||
- **Unified Cross-Level Chain-of-Thought (Uni-CoT)**: A single differentiable framework that unifies high-level instruction reasoning, sub-task decomposition, and fine-grained action synthesis, forming a continuous chain from “understanding” to “execution.”
|
||||
- **Mixture-of-Experts (MoE) action heads**: Dynamically activating experts depending on the task phase and modeling actions in discrete or continuous space to maintain stable VLM priors.
|
||||
- **Two-stage training paradigm**:
|
||||
- **Inspiration stage**: Injecting discrete action priors to strengthen spatial understanding and semantic-action alignment.
|
||||
- **Integration stage**: Using flow matching to achieve high-frequency continuous control.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install WallX dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[wallx]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use WallX in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=wall_x
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
--job_name=wallx_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow \
|
||||
--policy.prediction_mode=diffusion \
|
||||
--policy.attn_implementation=eager \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--dataset.repo_id` | The Hugging Face Hub repository ID for your training dataset (e.g., `lerobot/aloha_sim_insertion_human`) |
|
||||
| `--policy.type` | Specifies using the WallX policy architecture |
|
||||
| `--output_dir` | Local directory where training checkpoints and logs will be saved |
|
||||
| `--job_name` | A name identifier for this training run (used in logging/tracking) |
|
||||
| `--policy.repo_id` | Your Hugging Face Hub repo ID where the trained model will be pushed |
|
||||
| `--policy.pretrained_path` | Path to pretrained WallX weights to initialize from (the official WALL-OSS checkpoint) |
|
||||
| `--policy.prediction_mode` | The action prediction strategy: `diffusion` or `fast` - `diffusion` uses iterative denoising for action generation, `fast` uses next token prediction instead |
|
||||
| `--policy.attn_implementation` | Attention implementation backend - `eager` uses standard PyTorch attention (alternatives include `flash_attention_2` or `sdpa`) |
|
||||
| `--steps` | Total number of training steps to run |
|
||||
| `--policy.device` | Device to train on (`cuda` for GPU, `cpu` for CPU) |
|
||||
| `--batch_size` | Number of samples per training batch |
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x).
|
||||
@@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
WBT (Whole Body Tracking) Dance Policy for Unitree G1
|
||||
|
||||
Uses ONNX model with motion data baked in.
|
||||
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
|
||||
|
||||
Usage:
|
||||
python examples/unitree_g1/dance.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import pinocchio as pin
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||
CONTROL_DT = 0.02 # 50 Hz
|
||||
NUM_DOFS = 29
|
||||
|
||||
# Default joint positions (holosoma training defaults)
|
||||
DEFAULT_DOF_POS = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Left leg (6)
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Right leg (6)
|
||||
0.0, 0.0, 0.0, # Waist (3)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Left arm (7)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Right arm (7)
|
||||
], dtype=np.float32)
|
||||
|
||||
# Stiff hold KP/KD (for initialization)
|
||||
STIFF_KP = np.array([
|
||||
150, 150, 200, 200, 40, 40,
|
||||
150, 150, 200, 200, 40, 40,
|
||||
200, 200, 100,
|
||||
100, 100, 100, 100, 50, 50, 50,
|
||||
100, 100, 100, 100, 50, 50, 50,
|
||||
], dtype=np.float32)
|
||||
|
||||
STIFF_KD = np.array([
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
5.0, 5.0, 5.0,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
], dtype=np.float32)
|
||||
|
||||
# Joints to freeze at 0 with high KP
|
||||
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||
FROZEN_KP = 500.0
|
||||
FROZEN_KD = 5.0
|
||||
|
||||
# =============================================================================
|
||||
# QUATERNION UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
def quat_inverse(q):
|
||||
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||
|
||||
def quat_mul(a, b):
|
||||
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||
ww = (z1 + x1) * (x2 + y2)
|
||||
yy = (w1 - y1) * (w2 + z2)
|
||||
zz = (w1 + y1) * (w2 - z2)
|
||||
xx = ww + yy + zz
|
||||
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||
|
||||
def subtract_frame_transforms(q01, q02):
|
||||
return quat_mul(quat_inverse(q01), q02)
|
||||
|
||||
def matrix_from_quat(q):
|
||||
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||
two_s = 2.0 / (q * q).sum(-1)
|
||||
o = np.stack((
|
||||
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||
), -1)
|
||||
return o.reshape(q.shape[:-1] + (3, 3))
|
||||
|
||||
def xyzw_to_wxyz(xyzw):
|
||||
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||
|
||||
def quat_to_rpy(q):
|
||||
w, x, y, z = q
|
||||
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||
return roll, pitch, yaw
|
||||
|
||||
def rpy_to_quat(rpy):
|
||||
roll, pitch, yaw = rpy
|
||||
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||
|
||||
# =============================================================================
|
||||
# PINOCCHIO FK
|
||||
# =============================================================================
|
||||
|
||||
DOF_NAMES = (
|
||||
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||
)
|
||||
|
||||
|
||||
class PinocchioFK:
|
||||
"""Pinocchio forward kinematics for torso_link orientation."""
|
||||
|
||||
def __init__(self, urdf_text: str):
|
||||
root = ElementTree.fromstring(urdf_text)
|
||||
for parent in root.iter():
|
||||
for child in list(parent):
|
||||
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||
parent.remove(child)
|
||||
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||
|
||||
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||
self.data = self.model.createData()
|
||||
|
||||
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||
logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")
|
||||
|
||||
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||
"""Get torso_link orientation in world frame."""
|
||||
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||
pin.framesForwardKinematics(self.model, self.data, config)
|
||||
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DANCE CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class DanceController:
|
||||
"""
|
||||
Handles WBT dance policy for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- 29-joint observation processing
|
||||
- Pinocchio FK for torso orientation
|
||||
- Policy inference with motion data from ONNX
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.pinocchio_fk = pinocchio_fk
|
||||
self.motor_kp = motor_kp
|
||||
self.motor_kd = motor_kd
|
||||
self.action_scale = action_scale
|
||||
|
||||
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||
self.motion_command = None
|
||||
self.ref_quat_xyzw = None
|
||||
self.timestep = 0
|
||||
self.yaw_offset = 0.0
|
||||
|
||||
# Get initial motion data from ONNX
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
self.motion_start_pose = outs[0].flatten()
|
||||
|
||||
# Thread management
|
||||
self.dance_running = False
|
||||
self.dance_thread = None
|
||||
|
||||
logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||
|
||||
def capture_yaw_offset(self):
|
||||
"""Capture robot's current yaw for relative tracking."""
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
if robot_state and self.pinocchio_fk:
|
||||
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||
dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||
logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||
|
||||
def _remove_yaw_offset(self, quat_wxyz):
|
||||
"""Remove stored yaw offset from orientation."""
|
||||
if abs(self.yaw_offset) < 1e-6:
|
||||
return quat_wxyz
|
||||
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||
return quat_mul(yaw_q, quat_wxyz)
|
||||
|
||||
def run_step(self):
|
||||
"""Single dance step - reads state, runs policy, sends commands."""
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Read robot state
|
||||
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
|
||||
# Compute motion_ref_ori_b using FK
|
||||
if self.pinocchio_fk:
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||
torso_q = self._remove_yaw_offset(torso_q)
|
||||
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||
else:
|
||||
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||
|
||||
dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)
|
||||
|
||||
# Build observation (alphabetical order)
|
||||
obs_dict = {
|
||||
"actions": self.last_action,
|
||||
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||
"dof_pos": dof_rel,
|
||||
"dof_vel": dof_vel.reshape(1, -1),
|
||||
"motion_command": self.motion_command,
|
||||
"motion_ref_ori_b": ori_b,
|
||||
}
|
||||
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||
obs = np.clip(obs, -100, 100)
|
||||
|
||||
# Run policy
|
||||
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
|
||||
action = np.clip(outs[0], -100, 100)
|
||||
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||
self.ref_quat_xyzw = outs[3]
|
||||
self.last_action = action.copy()
|
||||
|
||||
# Compute target positions
|
||||
target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale
|
||||
|
||||
# Send commands
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(target_pos[i])
|
||||
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
self.timestep += 1
|
||||
|
||||
def _dance_thread_loop(self):
|
||||
"""Background thread that runs the dance policy."""
|
||||
logger.info("Dance thread started")
|
||||
while self.dance_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.run_step()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in dance loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Dance thread stopped")
|
||||
|
||||
def start_dance_thread(self):
|
||||
"""Start the dance control thread."""
|
||||
if self.dance_running:
|
||||
logger.warning("Dance thread already running")
|
||||
return
|
||||
|
||||
# Reset state for fresh start
|
||||
self.timestep = 0
|
||||
self.last_action.fill(0)
|
||||
|
||||
# Re-get initial motion data
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
|
||||
self.capture_yaw_offset()
|
||||
|
||||
logger.info("Starting dance control thread...")
|
||||
self.dance_running = True
|
||||
self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
|
||||
self.dance_thread.start()
|
||||
|
||||
def stop_dance_thread(self):
|
||||
"""Stop the dance control thread."""
|
||||
if not self.dance_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping dance control thread...")
|
||||
self.dance_running = False
|
||||
if self.dance_thread:
|
||||
self.dance_thread.join(timeout=2.0)
|
||||
logger.info("Dance control thread stopped")
|
||||
|
||||
def reset_to_motion_pose(self, duration: float = 3.0):
|
||||
"""Move robot to initial motion pose over given duration."""
|
||||
logger.info(f"Moving to dance start pose ({duration}s)...")
|
||||
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
target_pos = self.motion_start_pose
|
||||
|
||||
num_steps = int(duration / CONTROL_DT)
|
||||
for step in range(num_steps):
|
||||
alpha = step / num_steps
|
||||
interp = init_pos * (1 - alpha) + target_pos * alpha
|
||||
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(interp[i])
|
||||
self.robot.msg.motor_cmd[i].kp = STIFF_KP[i]
|
||||
self.robot.msg.motor_cmd[i].kd = STIFF_KD[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(CONTROL_DT)
|
||||
|
||||
logger.info("At dance start pose!")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN
|
||||
# =============================================================================
|
||||
|
||||
def load_dance_policy(onnx_path: str):
|
||||
"""Load dance policy and extract metadata."""
|
||||
logger.info(f"Loading dance policy: {onnx_path}")
|
||||
|
||||
policy = ort.InferenceSession(onnx_path)
|
||||
model = onnx.load(onnx_path)
|
||||
metadata = {p.key: json.loads(p.value) for p in model.metadata_props}
|
||||
|
||||
motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
|
||||
motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
|
||||
action_scale = float(metadata.get("action_scale", 1.0))
|
||||
urdf_text = metadata.get("robot_urdf", None)
|
||||
|
||||
logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
|
||||
logger.info(f" Action scale: {action_scale}")
|
||||
logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")
|
||||
|
||||
# Build Pinocchio FK if URDF available
|
||||
pinocchio_fk = None
|
||||
if urdf_text:
|
||||
logger.info(" Building Pinocchio FK from URDF...")
|
||||
pinocchio_fk = PinocchioFK(urdf_text)
|
||||
else:
|
||||
logger.warning(" No URDF in metadata - FK will not work!")
|
||||
|
||||
return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
|
||||
parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
|
||||
parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 70)
|
||||
print("💃 WBT DANCE POLICY")
|
||||
print("=" * 70)
|
||||
|
||||
# Load policy
|
||||
policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)
|
||||
|
||||
# Initialize robot
|
||||
logger.info("Initializing robot...")
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
logger.info("Robot connected!")
|
||||
|
||||
# Create controller
|
||||
controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)
|
||||
|
||||
try:
|
||||
# Move to start pose
|
||||
controller.reset_to_motion_pose(duration=3.0)
|
||||
|
||||
# Start dancing
|
||||
controller.start_dance_thread()
|
||||
|
||||
logger.info("Dancing! Press Ctrl+C to stop.")
|
||||
print("-" * 70)
|
||||
|
||||
# Log status periodically
|
||||
while True:
|
||||
time.sleep(2.0)
|
||||
logger.info(f"timestep={controller.timestep}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
controller.stop_dance_thread()
|
||||
robot.disconnect()
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,479 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
|
||||
|
||||
This example demonstrates loading Holosoma whole-body locomotion policies
|
||||
and running them on the Unitree G1 robot.
|
||||
|
||||
Supports both:
|
||||
- 23-DOF native policies (82D observations, 23D actions)
|
||||
- 29-DOF policies (100D observations, 29D actions)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# 29-DOF Configuration
|
||||
# =============================================================================
|
||||
# fmt: off
|
||||
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_29DOF_KP = np.array([
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||
40.179238471, 28.501246196, 28.501246196, # waist
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # left arm
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_29DOF_KD = np.array([
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||
2.557889765, 1.814445687, 1.814445687, # waist
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # left arm
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
# =============================================================================
|
||||
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
|
||||
# Derived from 29-DOF Holosoma values
|
||||
# =============================================================================
|
||||
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
|
||||
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from 29-DOF)
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from 29-DOF)
|
||||
0.0, # waist_yaw only (from 29-DOF)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, # left arm first 5 joints (from 29-DOF)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, # right arm first 5 joints (from 29-DOF)
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_23DOF_KP = np.array([
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||
40.179238471, # waist_yaw
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # left arm
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_23DOF_KD = np.array([
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||
2.557889765, # waist_yaw
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # left arm
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
# Maps 23-DOF policy index → 29-DOF motor index
|
||||
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
|
||||
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
|
||||
DOF_23_TO_MOTOR_MAP = [
|
||||
0, 1, 2, 3, 4, 5, # left leg → motor 0-5
|
||||
6, 7, 8, 9, 10, 11, # right leg → motor 6-11
|
||||
12, # waist_yaw → motor 12
|
||||
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw) → motor 15-19
|
||||
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw) → motor 22-26
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# Control parameters
|
||||
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
|
||||
|
||||
def load_holosoma_policy(
|
||||
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
||||
policy_name: str = "fastsac",
|
||||
local_path: str | None = None,
|
||||
) -> tuple[ort.InferenceSession, int]:
|
||||
"""Load Holosoma policy and detect observation dimension.
|
||||
|
||||
Returns:
|
||||
(policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
|
||||
"""
|
||||
if local_path is not None:
|
||||
logger.info(f"Loading policy from local path: {local_path}")
|
||||
policy_path = local_path
|
||||
else:
|
||||
logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
|
||||
policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")
|
||||
|
||||
policy = ort.InferenceSession(policy_path)
|
||||
|
||||
# Detect observation dimension from model input shape
|
||||
input_shape = policy.get_inputs()[0].shape
|
||||
obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]
|
||||
|
||||
logger.info(f"Policy loaded successfully")
|
||||
logger.info(f" Input: {policy.get_inputs()[0].name}, shape: {input_shape} → obs_dim={obs_dim}")
|
||||
logger.info(f" Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")
|
||||
|
||||
return policy, obs_dim
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
|
||||
"""
|
||||
Handles Holosoma whole-body locomotion for Unitree G1.
|
||||
Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, config, obs_dim: int = 100):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
self.obs_dim = obs_dim
|
||||
|
||||
# Detect policy type from observation dimension
|
||||
self.is_23dof = (obs_dim == 82)
|
||||
self.num_dof = 23 if self.is_23dof else 29
|
||||
|
||||
# Velocity commands
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# State variables sized for policy type
|
||||
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)
|
||||
self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
|
||||
# Select config based on DOF
|
||||
if self.is_23dof:
|
||||
self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
|
||||
self.kp = HOLOSOMA_23DOF_KP
|
||||
self.kd = HOLOSOMA_23DOF_KD
|
||||
self.motor_map = DOF_23_TO_MOTOR_MAP
|
||||
else:
|
||||
self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
|
||||
self.kp = HOLOSOMA_29DOF_KP
|
||||
self.kd = HOLOSOMA_29DOF_KD
|
||||
self.motor_map = list(range(29)) # Identity map for 29-DOF
|
||||
|
||||
# Phase state for gait
|
||||
self.phase = np.zeros((1, 2), dtype=np.float32)
|
||||
self.phase[0, 0] = 0.0
|
||||
self.phase[0, 1] = np.pi
|
||||
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||
self.is_standing = False
|
||||
|
||||
self.counter = 0
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info(f"HolosomaLocomotionController initialized")
|
||||
logger.info(f" Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
|
||||
logger.info(f" Action dim: {self.num_dof}")
|
||||
|
||||
def holosoma_locomotion_run(self):
|
||||
"""Main locomotion loop - handles both 23-DOF and 29-DOF."""
|
||||
self.counter += 1
|
||||
|
||||
if self.counter == 1:
|
||||
print("\n" + "=" * 60)
|
||||
print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
|
||||
print(f" {self.obs_dim}D observations → {self.num_dof}D actions")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
robot_state = self.robot.get_observation()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
# Deadzone
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
|
||||
self.locomotion_cmd[0] = ly
|
||||
self.locomotion_cmd[1] = -lx
|
||||
self.locomotion_cmd[2] = -rx
|
||||
|
||||
# Read joint states using motor map
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||
|
||||
# IMU
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale observations
|
||||
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Phase update
|
||||
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
|
||||
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
|
||||
|
||||
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
|
||||
self.phase[0, :] = np.pi * np.ones(2)
|
||||
self.is_standing = True
|
||||
elif self.is_standing:
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = False
|
||||
else:
|
||||
phase_tp1 = self.phase + self.phase_dt
|
||||
self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
|
||||
|
||||
sin_phase = np.sin(self.phase[0, :])
|
||||
cos_phase = np.cos(self.phase[0, :])
|
||||
|
||||
# Build observation (format depends on DOF)
|
||||
if self.is_23dof:
|
||||
# 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
|
||||
self.locomotion_obs[0:23] = self.last_unscaled_action
|
||||
self.locomotion_obs[23:26] = ang_vel_scaled
|
||||
self.locomotion_obs[26] = self.locomotion_cmd[2]
|
||||
self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
|
||||
self.locomotion_obs[29:31] = cos_phase
|
||||
self.locomotion_obs[31:54] = qj_obs
|
||||
self.locomotion_obs[54:77] = dqj_obs
|
||||
self.locomotion_obs[77:80] = gravity_orientation
|
||||
self.locomotion_obs[80:82] = sin_phase
|
||||
else:
|
||||
# 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
|
||||
self.locomotion_obs[0:29] = self.last_unscaled_action
|
||||
self.locomotion_obs[29:32] = ang_vel_scaled
|
||||
self.locomotion_obs[32] = self.locomotion_cmd[2]
|
||||
self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
|
||||
self.locomotion_obs[35:37] = cos_phase
|
||||
self.locomotion_obs[37:66] = qj_obs
|
||||
self.locomotion_obs[66:95] = dqj_obs
|
||||
self.locomotion_obs[95:98] = gravity_orientation
|
||||
self.locomotion_obs[98:100] = sin_phase
|
||||
|
||||
# Policy inference
|
||||
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
|
||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
|
||||
ort_outs = self.policy.run(None, ort_inputs)
|
||||
|
||||
raw_action = ort_outs[0].squeeze()
|
||||
clipped_action = np.clip(raw_action, -100.0, 100.0)
|
||||
|
||||
self.last_unscaled_action = clipped_action.copy()
|
||||
self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# Debug
|
||||
if self.counter <= 3:
|
||||
print(f"\n[Holosoma Debug #{self.counter}]")
|
||||
print(f" Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
|
||||
print(f" Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||
print(f" Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
|
||||
|
||||
# Compute target positions
|
||||
target_dof_pos = self.default_angles + self.locomotion_action
|
||||
|
||||
# Send commands to motors via motor map
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
|
||||
if self.is_23dof:
|
||||
missing_motors = [13, 14, 20, 21, 27, 28] # waist_roll, waist_pitch, wrist_pitch/yaw
|
||||
for motor_idx in missing_motors:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.holosoma_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move joints to default position."""
|
||||
logger.info(f"Moving {self.num_dof} joints to default position...")
|
||||
|
||||
total_time = 3.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Record current positions
|
||||
init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
init_dof_pos[i] = robot_state.motor_state[motor_idx].q
|
||||
|
||||
# Interpolate to target
|
||||
for step in range(num_step):
|
||||
alpha = step / num_step
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
target = self.default_angles[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Zero missing joints for 23-DOF
|
||||
if self.is_23dof:
|
||||
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info(f"Reached default position ({self.num_dof} joints)")
|
||||
|
||||
# Hold for 2 seconds
|
||||
logger.info("Holding default position for 2 seconds...")
|
||||
hold_steps = int(2.0 / self.robot.control_dt)
|
||||
for _ in range(hold_steps):
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
if self.is_23dof:
|
||||
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||
parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||
parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
|
||||
parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy and detect dimensions
|
||||
policy, obs_dim = load_holosoma_policy(
|
||||
repo_id=args.repo_id,
|
||||
policy_name=args.policy,
|
||||
local_path=args.local_path,
|
||||
)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# Initialize controller with detected obs_dim
|
||||
controller = HolosomaLocomotionController(
|
||||
policy=policy,
|
||||
robot=robot,
|
||||
config=config,
|
||||
obs_dim=obs_dim,
|
||||
)
|
||||
|
||||
try:
|
||||
#controller.reset_robot()
|
||||
controller.start_locomotion_thread()
|
||||
|
||||
logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
|
||||
logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
@@ -0,0 +1,607 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Locomotion ↔ Dance Toggle for Unitree G1
|
||||
|
||||
Press Enter to instantly switch between locomotion and dance modes.
|
||||
- Starts in LOCOMOTION mode (joystick control)
|
||||
- Press Enter → DANCE mode (resets to frame 0)
|
||||
- Press Enter → LOCOMOTION mode
|
||||
- Repeat...
|
||||
|
||||
Auto-recovery feature:
|
||||
- If robot tilts beyond threshold during dance, auto-switches to locomotion
|
||||
- When robot recovers (tilt below recovery threshold), resumes dance from where it left off
|
||||
|
||||
Usage:
|
||||
python examples/unitree_g1/locomotion_to_dance.py
|
||||
python examples/unitree_g1/locomotion_to_dance.py --tilt-threshold 25 --recovery-threshold 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import select
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import pinocchio as pin
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
NUM_DOFS = 29
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
|
||||
# Locomotion config
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
# Dance config
|
||||
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||
FROZEN_KP = 500.0
|
||||
FROZEN_KD = 5.0
|
||||
|
||||
# fmt: off
|
||||
# 29-DOF defaults (holosoma training)
|
||||
DEFAULT_29DOF_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, 0.0, 0.0, # waist
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_29DOF_KP = np.array([
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 28.501, 28.501,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_29DOF_KD = np.array([
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 1.814, 1.814,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||
], dtype=np.float32)
|
||||
|
||||
# 23-DOF config (no waist_roll/pitch, no wrist_pitch/yaw)
|
||||
DEFAULT_23DOF_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, # waist_yaw only
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, # left arm (5 joints)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, # right arm (5 joints)
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_23DOF_KP = np.array([
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_23DOF_KD = np.array([
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||
], dtype=np.float32)
|
||||
|
||||
# 23-DOF policy index → 29-DOF motor index
|
||||
DOF_23_TO_MOTOR = [
|
||||
0, 1, 2, 3, 4, 5, # left leg
|
||||
6, 7, 8, 9, 10, 11, # right leg
|
||||
12, # waist_yaw
|
||||
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw)
|
||||
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw)
|
||||
]
|
||||
MISSING_23DOF_MOTORS = [13, 14, 20, 21, 27, 28]
|
||||
# fmt: on
|
||||
|
||||
# =============================================================================
|
||||
# QUATERNION UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
def quat_inverse(q):
|
||||
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||
|
||||
def quat_mul(a, b):
|
||||
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||
ww = (z1 + x1) * (x2 + y2)
|
||||
yy = (w1 - y1) * (w2 + z2)
|
||||
zz = (w1 + y1) * (w2 - z2)
|
||||
xx = ww + yy + zz
|
||||
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||
|
||||
def subtract_frame_transforms(q01, q02):
|
||||
return quat_mul(quat_inverse(q01), q02)
|
||||
|
||||
def matrix_from_quat(q):
|
||||
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||
two_s = 2.0 / (q * q).sum(-1)
|
||||
o = np.stack((
|
||||
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||
), -1)
|
||||
return o.reshape(q.shape[:-1] + (3, 3))
|
||||
|
||||
def xyzw_to_wxyz(xyzw):
|
||||
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||
|
||||
def quat_to_rpy(q):
|
||||
w, x, y, z = q
|
||||
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||
return roll, pitch, yaw
|
||||
|
||||
def rpy_to_quat(rpy):
|
||||
roll, pitch, yaw = rpy
|
||||
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||
|
||||
# =============================================================================
|
||||
# PINOCCHIO FK
|
||||
# =============================================================================
|
||||
|
||||
DOF_NAMES = (
|
||||
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||
)
|
||||
|
||||
|
||||
class PinocchioFK:
|
||||
def __init__(self, urdf_text: str):
|
||||
root = ElementTree.fromstring(urdf_text)
|
||||
for parent in root.iter():
|
||||
for child in list(parent):
|
||||
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||
parent.remove(child)
|
||||
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||
self.data = self.model.createData()
|
||||
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||
|
||||
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||
pin.framesForwardKinematics(self.model, self.data, config)
|
||||
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||
|
||||
def get_torso_tilt(self, pos, quat_wxyz, dof_pos):
|
||||
"""Get torso tilt angle from upright (degrees). Uses roll and pitch."""
|
||||
torso_q = self.get_torso_quat(pos, quat_wxyz, dof_pos)
|
||||
roll, pitch, _ = quat_to_rpy(torso_q.flatten())
|
||||
# Tilt is the angle from vertical - combine roll and pitch
|
||||
tilt_rad = np.sqrt(roll**2 + pitch**2)
|
||||
return np.degrees(tilt_rad), np.degrees(roll), np.degrees(pitch)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOCOMOTION CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class LocomotionController:
|
||||
"""Holosoma whole-body locomotion (23-DOF or 29-DOF)."""
|
||||
|
||||
def __init__(self, policy, robot, obs_dim: int):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.obs_dim = obs_dim
|
||||
|
||||
# Detect DOF mode
|
||||
self.is_23dof = (obs_dim == 82)
|
||||
self.num_dof = 23 if self.is_23dof else 29
|
||||
|
||||
if self.is_23dof:
|
||||
self.default_angles = DEFAULT_23DOF_ANGLES
|
||||
self.kp = DEFAULT_23DOF_KP
|
||||
self.kd = DEFAULT_23DOF_KD
|
||||
self.motor_map = DOF_23_TO_MOTOR
|
||||
logger.info("Locomotion: 23-DOF (82D obs)")
|
||||
else:
|
||||
self.default_angles = DEFAULT_29DOF_ANGLES
|
||||
self.kp = DEFAULT_29DOF_KP
|
||||
self.kd = DEFAULT_29DOF_KD
|
||||
self.motor_map = list(range(29))
|
||||
logger.info("Locomotion: 29-DOF (100D obs)")
|
||||
|
||||
self.cmd = np.zeros(3, dtype=np.float32)
|
||||
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.obs = np.zeros(obs_dim, dtype=np.float32)
|
||||
self.last_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||
self.is_standing = True
|
||||
|
||||
def run_step(self):
|
||||
"""Single locomotion step."""
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state is None:
|
||||
return
|
||||
|
||||
# Joystick
|
||||
if state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(state.wireless_remote)
|
||||
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
self.cmd[0], self.cmd[1], self.cmd[2] = ly, -lx, -rx
|
||||
|
||||
# Read joints via motor map
|
||||
for i in range(self.num_dof):
|
||||
self.qj[i] = state.motor_state[self.motor_map[i]].q
|
||||
self.dqj[i] = state.motor_state[self.motor_map[i]].dq
|
||||
|
||||
# IMU
|
||||
quat = state.imu_state.quaternion
|
||||
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale
|
||||
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_s = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Phase
|
||||
cmd_mag = np.linalg.norm(self.cmd[:2])
|
||||
ang_mag = abs(self.cmd[2])
|
||||
if cmd_mag < 0.01 and ang_mag < 0.01:
|
||||
self.phase[0, :] = np.pi
|
||||
self.is_standing = True
|
||||
elif self.is_standing:
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = False
|
||||
else:
|
||||
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2*np.pi) - np.pi
|
||||
|
||||
sin_ph, cos_ph = np.sin(self.phase[0]), np.cos(self.phase[0])
|
||||
|
||||
# Build obs
|
||||
if self.is_23dof:
|
||||
self.obs[0:23] = self.last_action
|
||||
self.obs[23:26] = ang_vel_s
|
||||
self.obs[26] = self.cmd[2]
|
||||
self.obs[27:29] = self.cmd[:2]
|
||||
self.obs[29:31] = cos_ph
|
||||
self.obs[31:54] = qj_obs
|
||||
self.obs[54:77] = dqj_obs
|
||||
self.obs[77:80] = gravity
|
||||
self.obs[80:82] = sin_ph
|
||||
else:
|
||||
self.obs[0:29] = self.last_action
|
||||
self.obs[29:32] = ang_vel_s
|
||||
self.obs[32] = self.cmd[2]
|
||||
self.obs[33:35] = self.cmd[:2]
|
||||
self.obs[35:37] = cos_ph
|
||||
self.obs[37:66] = qj_obs
|
||||
self.obs[66:95] = dqj_obs
|
||||
self.obs[95:98] = gravity
|
||||
self.obs[98:100] = sin_ph
|
||||
|
||||
# Inference
|
||||
obs_in = self.obs.reshape(1, -1).astype(np.float32)
|
||||
ort_in = {self.policy.get_inputs()[0].name: obs_in}
|
||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||
clipped = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = clipped.copy()
|
||||
scaled = clipped * LOCOMOTION_ACTION_SCALE
|
||||
target = self.default_angles + scaled
|
||||
|
||||
# Send commands
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = float(target[i])
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Zero missing joints for 23-DOF
|
||||
if self.is_23dof:
|
||||
for idx in MISSING_23DOF_MOTORS:
|
||||
self.robot.msg.motor_cmd[idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[idx].qd = 0
|
||||
self.robot.msg.motor_cmd[idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[idx].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def reset(self):
|
||||
"""Reset state for fresh start."""
|
||||
self.last_action.fill(0)
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DANCE CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class DanceController:
|
||||
"""WBT dance policy with FK for torso tracking."""
|
||||
|
||||
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.pinocchio_fk = pinocchio_fk
|
||||
self.motor_kp = motor_kp
|
||||
self.motor_kd = motor_kd
|
||||
self.action_scale = action_scale
|
||||
|
||||
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||
self.motion_command = None
|
||||
self.ref_quat_xyzw = None
|
||||
self.timestep = 0
|
||||
self.yaw_offset = 0.0
|
||||
|
||||
logger.info(f"Dance: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||
|
||||
def initialize(self, reset_to_frame_0: bool = True):
|
||||
"""Initialize dance. If reset_to_frame_0=True, starts from frame 0. Otherwise resumes."""
|
||||
if reset_to_frame_0:
|
||||
self.timestep = 0
|
||||
self.last_action.fill(0)
|
||||
|
||||
# Get initial motion data at frame 0
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
logger.info("Dance: reset to frame 0")
|
||||
else:
|
||||
# Resume from current timestep - just update motion command for current frame
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
logger.info(f"Dance: resuming from frame {self.timestep}")
|
||||
|
||||
# Capture yaw offset
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state and self.pinocchio_fk:
|
||||
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||
dof = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||
logger.info(f"Dance yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||
|
||||
def _remove_yaw_offset(self, quat_wxyz):
|
||||
if abs(self.yaw_offset) < 1e-6:
|
||||
return quat_wxyz
|
||||
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||
return quat_mul(yaw_q, quat_wxyz)
|
||||
|
||||
def run_step(self):
|
||||
"""Single dance step."""
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state is None:
|
||||
return
|
||||
|
||||
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||
dof_pos = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
dof_vel = np.array([state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
|
||||
# FK for torso orientation
|
||||
if self.pinocchio_fk:
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||
torso_q = self._remove_yaw_offset(torso_q)
|
||||
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||
else:
|
||||
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||
|
||||
dof_rel = (dof_pos - DEFAULT_29DOF_ANGLES).reshape(1, -1)
|
||||
|
||||
# Build obs (alphabetical)
|
||||
obs_dict = {
|
||||
"actions": self.last_action,
|
||||
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||
"dof_pos": dof_rel,
|
||||
"dof_vel": dof_vel.reshape(1, -1),
|
||||
"motion_command": self.motion_command,
|
||||
"motion_ref_ori_b": ori_b,
|
||||
}
|
||||
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||
obs = np.clip(obs, -100, 100)
|
||||
|
||||
# Inference
|
||||
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
action = np.clip(outs[0], -100, 100)
|
||||
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||
self.ref_quat_xyzw = outs[3]
|
||||
self.last_action = action.copy()
|
||||
|
||||
target = DEFAULT_29DOF_ANGLES + action.flatten() * self.action_scale
|
||||
|
||||
# Send commands
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(target[i])
|
||||
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
self.timestep += 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Locomotion ↔ Dance Toggle")
|
||||
parser.add_argument("--loco-repo", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||
parser.add_argument("--dance-onnx", type=str, default=DANCE_ONNX_PATH)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 70)
|
||||
print("🚶 LOCOMOTION ↔ 💃 DANCE")
|
||||
print("=" * 70)
|
||||
print("Press ENTER to toggle between modes")
|
||||
print("=" * 70)
|
||||
|
||||
# Load locomotion policy
|
||||
logger.info("Loading locomotion policy...")
|
||||
loco_path = hf_hub_download(repo_id=args.loco_repo, filename="fastsac_g1_29dof.onnx")
|
||||
loco_policy = ort.InferenceSession(loco_path)
|
||||
loco_obs_dim = loco_policy.get_inputs()[0].shape[1]
|
||||
logger.info(f"Locomotion: {loco_obs_dim}D obs")
|
||||
|
||||
# Load dance policy
|
||||
logger.info("Loading dance policy...")
|
||||
dance_policy = ort.InferenceSession(args.dance_onnx)
|
||||
dance_model = onnx.load(args.dance_onnx)
|
||||
dance_meta = {p.key: json.loads(p.value) for p in dance_model.metadata_props}
|
||||
dance_kp = np.array(dance_meta.get("kp", DEFAULT_29DOF_KP), dtype=np.float32)
|
||||
dance_kd = np.array(dance_meta.get("kd", DEFAULT_29DOF_KD), dtype=np.float32)
|
||||
dance_action_scale = float(dance_meta.get("action_scale", 1.0))
|
||||
logger.info(f"Dance: {dance_policy.get_inputs()[0].shape[1]}D obs, scale={dance_action_scale}")
|
||||
|
||||
# Build Pinocchio FK
|
||||
pinocchio_fk = None
|
||||
if "robot_urdf" in dance_meta:
|
||||
logger.info("Building Pinocchio FK...")
|
||||
pinocchio_fk = PinocchioFK(dance_meta["robot_urdf"])
|
||||
|
||||
# Initialize robot
|
||||
logger.info("Initializing robot...")
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
logger.info("Robot connected!")
|
||||
|
||||
# Create controllers
|
||||
loco_ctrl = LocomotionController(loco_policy, robot, loco_obs_dim)
|
||||
dance_ctrl = DanceController(dance_policy, robot, pinocchio_fk, dance_kp, dance_kd, dance_action_scale)
|
||||
|
||||
# State
|
||||
mode = "locomotion"
|
||||
toggle_event = threading.Event()
|
||||
shutdown = threading.Event()
|
||||
|
||||
# Input thread
|
||||
def input_loop():
|
||||
while not shutdown.is_set():
|
||||
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||
sys.stdin.readline()
|
||||
toggle_event.set()
|
||||
|
||||
input_thread = threading.Thread(target=input_loop, daemon=True)
|
||||
input_thread.start()
|
||||
|
||||
print("\n🚶 LOCOMOTION MODE - Use joystick to walk")
|
||||
print(" Press ENTER to switch to DANCE")
|
||||
print("-" * 70)
|
||||
|
||||
step = 0
|
||||
try:
|
||||
while not shutdown.is_set():
|
||||
t0 = time.time()
|
||||
|
||||
# Check toggle
|
||||
if toggle_event.is_set():
|
||||
toggle_event.clear()
|
||||
if mode == "locomotion":
|
||||
mode = "dance"
|
||||
dance_ctrl.initialize()
|
||||
print("\n" + "=" * 70)
|
||||
print("💃 DANCE MODE (frame 0)")
|
||||
print(" Press ENTER to switch to LOCOMOTION")
|
||||
print("=" * 70)
|
||||
else:
|
||||
mode = "locomotion"
|
||||
loco_ctrl.reset()
|
||||
print("\n" + "=" * 70)
|
||||
print("🚶 LOCOMOTION MODE")
|
||||
print(" Press ENTER to switch to DANCE")
|
||||
print("=" * 70)
|
||||
|
||||
# Run controller
|
||||
if mode == "locomotion":
|
||||
loco_ctrl.run_step()
|
||||
else:
|
||||
dance_ctrl.run_step()
|
||||
|
||||
# Log
|
||||
if step % 100 == 0:
|
||||
if mode == "locomotion":
|
||||
print(f"[LOCO ] step={step:5d} cmd=[{loco_ctrl.cmd[0]:.2f},{loco_ctrl.cmd[1]:.2f},{loco_ctrl.cmd[2]:.2f}]")
|
||||
else:
|
||||
print(f"[DANCE] step={step:5d} timestep={dance_ctrl.timestep}")
|
||||
|
||||
step += 1
|
||||
elapsed = time.time() - t0
|
||||
if elapsed < CONTROL_DT:
|
||||
time.sleep(CONTROL_DT - elapsed)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
shutdown.set()
|
||||
robot.disconnect()
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,447 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
|
||||
|
||||
This example demonstrates loading a 12-DOF legs-only locomotion policy
|
||||
(TorchScript .pt format) and running it on the Unitree G1 robot.
|
||||
|
||||
Key characteristics:
|
||||
- Single TorchScript policy (.pt)
|
||||
- 47D observations, 12D actions (legs only)
|
||||
- Phase-based gait timing
|
||||
- Arms and waist held at fixed positions
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 12-DOF leg joint configuration
|
||||
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
|
||||
# R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
|
||||
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||
|
||||
# Default leg angles for standing
|
||||
DEFAULT_LEG_ANGLES = np.array([
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg
|
||||
], dtype=np.float32)
|
||||
|
||||
# KP/KD for leg joints
|
||||
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
|
||||
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)
|
||||
|
||||
# Waist configuration (held at zero)
|
||||
WAIST_JOINT_INDICES = [12, 13, 14] # yaw, roll, pitch
|
||||
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
|
||||
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)
|
||||
|
||||
# Arm configuration (indices 15-28, held at initial position)
|
||||
ARM_JOINT_INDICES = list(range(15, 29))
|
||||
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40, # left arm (shoulder + wrist)
|
||||
80, 80, 80, 80, 40, 40, 40], dtype=np.float32) # right arm
|
||||
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
|
||||
3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)
|
||||
|
||||
# Control parameters
|
||||
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz control rate
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
|
||||
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32) # max vx, vy, yaw_rate
|
||||
|
||||
# Gait parameters
|
||||
GAIT_PERIOD = 0.8 # seconds
|
||||
|
||||
DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
|
||||
|
||||
|
||||
def load_torchscript_policy(
|
||||
repo_id: str = DEFAULT_REPO_ID,
|
||||
filename: str = "motion.pt",
|
||||
) -> torch.jit.ScriptModule:
|
||||
"""Load TorchScript locomotion policy from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the policy.
|
||||
filename: Policy filename (default: motion.pt).
|
||||
"""
|
||||
logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")
|
||||
|
||||
policy_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
policy = torch.jit.load(policy_path)
|
||||
policy.eval()
|
||||
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
class UnitreeRLLocomotionController:
|
||||
"""
|
||||
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Single TorchScript policy
|
||||
- 47D observations (single frame)
|
||||
- 12D action output (legs only)
|
||||
- Arms and waist held at fixed positions
|
||||
- Phase-based gait timing
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, config):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
# Velocity commands (vx, vy, yaw_rate)
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# State variables (12 DOF legs)
|
||||
self.qj = np.zeros(12, dtype=np.float32)
|
||||
self.dqj = np.zeros(12, dtype=np.float32)
|
||||
self.locomotion_action = np.zeros(12, dtype=np.float32)
|
||||
self.locomotion_obs = np.zeros(47, dtype=np.float32)
|
||||
|
||||
# Initial arm positions (captured on reset)
|
||||
self.initial_arm_positions = np.zeros(14, dtype=np.float32)
|
||||
|
||||
# Counter for phase calculation
|
||||
self.counter = 0
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("UnitreeRLLocomotionController initialized")
|
||||
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
|
||||
|
||||
def locomotion_run(self):
|
||||
"""12-DOF legs-only locomotion policy loop."""
|
||||
self.counter += 1
|
||||
|
||||
if self.counter == 1:
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
|
||||
print(" 47D observations → 12D actions (legs only)")
|
||||
print(" Arms and waist held at fixed positions")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
|
||||
|
||||
# Get leg joint positions and velocities (12 DOF)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||
|
||||
# Get IMU data
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
|
||||
# Scale observations
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Calculate phase
|
||||
count = self.counter * LOCOMOTION_CONTROL_DT
|
||||
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
# Build 47D observation vector
|
||||
# [0:3] - angular velocity (scaled)
|
||||
# [3:6] - gravity orientation
|
||||
# [6:9] - velocity command (scaled)
|
||||
# [9:21] - joint positions (12D, relative to default)
|
||||
# [21:33] - joint velocities (12D, scaled)
|
||||
# [33:45] - previous actions (12D)
|
||||
# [45] - sin_phase
|
||||
# [46] - cos_phase
|
||||
self.locomotion_obs[0:3] = ang_vel_scaled
|
||||
self.locomotion_obs[3:6] = gravity_orientation
|
||||
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
|
||||
self.locomotion_obs[9:21] = qj_obs
|
||||
self.locomotion_obs[21:33] = dqj_obs
|
||||
self.locomotion_obs[33:45] = self.locomotion_action
|
||||
self.locomotion_obs[45] = sin_phase
|
||||
self.locomotion_obs[46] = cos_phase
|
||||
|
||||
# Run policy inference (TorchScript)
|
||||
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
action_tensor = self.policy(obs_tensor)
|
||||
self.locomotion_action = action_tensor.squeeze().numpy()
|
||||
|
||||
# Transform action to target joint positions
|
||||
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# Debug logging (first 3 iterations)
|
||||
if self.counter <= 3:
|
||||
print(f"\n[Unitree RL Debug #{self.counter}]")
|
||||
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
|
||||
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
|
||||
|
||||
# Send commands to LEG motors (0-11)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold WAIST motors at zero (12, 13, 14)
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold ARM motors at initial position (15-28)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Send command
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
|
||||
logger.info("Moving legs to default position...")
|
||||
|
||||
total_time = 2.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Get current state
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Capture initial arm positions (to hold during locomotion)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
|
||||
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
|
||||
|
||||
# Record current leg positions
|
||||
init_leg_pos = np.zeros(12, dtype=np.float32)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
|
||||
|
||||
# Interpolate legs to default position
|
||||
for step in range(num_step):
|
||||
alpha = step / num_step
|
||||
|
||||
# Interpolate leg positions
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
target_pos = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Reached default leg position")
|
||||
|
||||
# Hold position for 2 seconds
|
||||
logger.info("Holding default position for 2 seconds...")
|
||||
hold_time = 2.0
|
||||
num_hold_steps = int(hold_time / self.robot.control_dt)
|
||||
|
||||
for _ in range(num_hold_steps):
|
||||
# Hold legs at default
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filename",
|
||||
type=str,
|
||||
default="motion.pt",
|
||||
help="Policy filename (default: motion.pt)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy
|
||||
policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# Initialize locomotion controller
|
||||
locomotion_controller = UnitreeRLLocomotionController(
|
||||
policy=policy,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Reset robot and start locomotion thread
|
||||
try:
|
||||
locomotion_controller.reset_robot()
|
||||
locomotion_controller.start_locomotion_thread()
|
||||
|
||||
# Log status
|
||||
logger.info("Robot initialized with Unitree RL locomotion policy")
|
||||
logger.info("Locomotion controller running in background thread")
|
||||
logger.info("Use remote controller to command velocity:")
|
||||
logger.info(" Left stick Y: forward/backward")
|
||||
logger.info(" Left stick X: left/right")
|
||||
logger.info(" Right stick X: rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Keep robot alive
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
locomotion_controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
|
||||
|
After Width: | Height: | Size: 2.9 MiB |
|
After Width: | Height: | Size: 185 KiB |
|
After Width: | Height: | Size: 464 KiB |
|
After Width: | Height: | Size: 72 KiB |
|
After Width: | Height: | Size: 219 KiB |
|
After Width: | Height: | Size: 199 KiB |
|
Before Width: | Height: | Size: 160 KiB After Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 774 KiB |
|
Before Width: | Height: | Size: 2.3 MiB |
|
Before Width: | Height: | Size: 481 KiB |
|
After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 151 KiB |
|
After Width: | Height: | Size: 130 KiB |
|
After Width: | Height: | Size: 407 KiB |
@@ -96,7 +96,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
@@ -120,13 +120,6 @@ intelrealsense = [
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
@@ -140,7 +133,6 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
@@ -148,7 +140,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
@@ -167,8 +159,7 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
# "lerobot[wallx]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
@@ -182,7 +173,6 @@ all = [
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -237,7 +227,6 @@ ignore = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["F401", "F403"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
@@ -329,9 +318,9 @@ disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.optim.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
@@ -381,77 +370,3 @@ ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
@@ -56,7 +56,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
tolerance_s: float = 1e-4
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
@@ -65,17 +64,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
@@ -139,14 +130,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -98,7 +98,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
@@ -109,7 +108,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
@@ -35,8 +35,6 @@ def make_optimizer_and_scheduler(
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
if cfg.optimizer is None:
|
||||
raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig")
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -30,17 +29,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
# Type alias for parameters accepted by optimizer build() methods.
|
||||
# This matches PyTorch's optimizer signature while also supporting:
|
||||
# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
|
||||
# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
|
||||
OptimizerParams = (
|
||||
Iterable[torch.nn.Parameter] # From model.parameters()
|
||||
| Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides
|
||||
| dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR
|
||||
| dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@@ -57,24 +45,13 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Args:
|
||||
params: Parameters to optimize. Accepts multiple formats depending on the optimizer:
|
||||
- Iterable[Parameter]: From model.parameters() - standard PyTorch usage
|
||||
- Iterable[dict]: List of param groups with 'params' key and optional
|
||||
'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
|
||||
- dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
|
||||
that apply differential learning rates by parameter name (e.g., XVLA)
|
||||
- dict[str, Iterable]: For multi-optimizer configs where each key maps to
|
||||
a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
@@ -90,7 +67,7 @@ class AdamConfig(OptimizerConfig):
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
@@ -105,7 +82,7 @@ class AdamWConfig(OptimizerConfig):
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
@@ -121,7 +98,7 @@ class SGDConfig(OptimizerConfig):
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
@@ -162,19 +139,21 @@ class XVLAAdamWConfig(OptimizerConfig):
|
||||
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
|
||||
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
"""
|
||||
Build AdamW optimizer with differential learning rates.
|
||||
|
||||
Expects `named_parameters()` as input (dict of name -> param).
|
||||
Applies:
|
||||
- lr * 0.1 for all VLM-related parameters
|
||||
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
|
||||
- full lr for all other parameters
|
||||
|
||||
Args:
|
||||
params: Must be a dict[str, Parameter] from dict(model.named_parameters())
|
||||
or equivalent.
|
||||
params: Dictionary of parameter names to parameters (from named_parameters())
|
||||
|
||||
Returns:
|
||||
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
|
||||
|
||||
Raises:
|
||||
AssertionError: If params is not a dict (e.g., from model.parameters())
|
||||
"""
|
||||
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
|
||||
|
||||
@@ -195,7 +174,7 @@ class XVLAAdamWConfig(OptimizerConfig):
|
||||
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
|
||||
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
|
||||
|
||||
param_groups: list[dict[str, Any]] = [
|
||||
param_groups = [
|
||||
{
|
||||
"params": vlm_group,
|
||||
"lr": self.lr * 0.1,
|
||||
@@ -245,25 +224,19 @@ class MultiAdamConfig(OptimizerConfig):
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
|
||||
to iterables of parameters. The keys should match the keys in optimizer_groups.
|
||||
Typically from policies that need separate optimizers (e.g., SAC with
|
||||
actor/critic/temperature).
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
|
||||
Raises:
|
||||
AssertionError: If params is not a dict
|
||||
"""
|
||||
assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."
|
||||
optimizers = {}
|
||||
|
||||
for name, group_params in params.items():
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
@@ -275,7 +248,7 @@ class MultiAdamConfig(OptimizerConfig):
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs)
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int | None
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
||||
@@ -21,7 +21,6 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -30,10 +29,8 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"SmolVLAConfig",
|
||||
"SARMConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
"XVLAConfig",
|
||||
"WallXConfig",
|
||||
]
|
||||
|
||||
@@ -50,7 +50,6 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -56,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -37,12 +37,10 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
@@ -63,7 +61,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -107,10 +105,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -119,10 +113,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
|
||||
return XVLAPolicy
|
||||
elif name == "wall_x":
|
||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -140,7 +130,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
"reward_classifier".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -171,8 +161,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -349,14 +337,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -364,7 +344,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
make_xvla_pre_post_processors,
|
||||
@@ -375,14 +354,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, WallXConfig):
|
||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
|
||||
|
||||
processors = make_wall_x_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
@@ -464,13 +435,6 @@ def make_policy(
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
if ds_meta is not None and hasattr(ds_meta, "stats"):
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
|
||||
if ds_meta is not None:
|
||||
kwargs["dataset_meta"] = ds_meta
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
|
||||
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
name = "groot"
|
||||
config_class = GrootConfig
|
||||
|
||||
def __init__(self, config: GrootConfig, **kwargs):
|
||||
def __init__(self, config: GrootConfig):
|
||||
"""Initialize Groot policy wrapper."""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
@@ -93,11 +93,10 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||
@@ -908,7 +907,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -1237,15 +1235,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
@@ -1259,17 +1251,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
}
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -880,7 +880,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI05Config,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -1210,15 +1209,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
@@ -1232,17 +1225,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
}
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi05 import PI05Config
|
||||
from .modeling_pi05 import PI05Policy
|
||||
from .processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
|
||||
@@ -1,164 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
class PI05Config(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,995 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ONLY AN EXAMPLE FILE, NEVER USED, IT IS OLD CODE
|
||||
"""
|
||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||
|
||||
[Paper](https://huggingface.co/papers/2501.09747)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
Disclaimer: It is not expected to perform as well as the original implementation.
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from scipy.fft import idct
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
||||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
PRECISION = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0FASTPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0FASTConfig
|
||||
name = "pi0fast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0FASTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""Override the from_pretrained method to display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
|
||||
" It is not expected to perform as well as the original implementation. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("Currently not implemented for PI0FAST")
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.model.generate_actions(batch)
|
||||
|
||||
actions = actions[:, : self.config.n_action_steps]
|
||||
|
||||
original_action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
def block_causal_update_causal_mask(
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
attn_implementation: str = "eager",
|
||||
dtype: torch.dtype = "float32",
|
||||
):
|
||||
"""
|
||||
Update the causal mask during training and generation. It can be customized to different attention masks.
|
||||
"""
|
||||
if attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
|
||||
if using_static_cache or isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
# Handle precomputed attention masks
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
return attention_mask
|
||||
|
||||
# Causal mask initialization
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Standard causal masking (triu ensures tokens can only attend to past)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
# Apply block causal mask
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.to(causal_mask.device).bool()
|
||||
cumsum = torch.cumsum(token_type_ids, dim=1)
|
||||
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
|
||||
# Combine causal_mask with block-wise attention mask
|
||||
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
|
||||
causal_mask = causal_mask[:, None, :, :]
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Apply padding mask
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
# self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
labels=None,
|
||||
self=None,
|
||||
**kwargs,
|
||||
):
|
||||
# create block causal attention
|
||||
if cache_position[0] > 0 and input_ids.shape[1] > 0:
|
||||
input_tensor = input_ids[:, -1:]
|
||||
new_positions = (
|
||||
torch.ones(
|
||||
(position_ids.shape[0], input_ids.shape[1]),
|
||||
dtype=position_ids.dtype,
|
||||
device=position_ids.device,
|
||||
).cumsum(-1)
|
||||
+ position_ids[:, -1:]
|
||||
)
|
||||
position_ids = torch.cat([position_ids, new_positions], dim=-1)
|
||||
else:
|
||||
input_tensor = inputs_embeds
|
||||
attention_mask = block_causal_update_causal_mask(
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
input_tensor=input_tensor,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.dtype,
|
||||
attn_implementation=self.config.text_config._attn_implementation,
|
||||
)
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
class PI0FAST(nn.Module):
|
||||
def __init__(self, config: PI0FASTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# TODO: move tokenizers in Policy
|
||||
fast_tokenizer_path = "physical-intelligence/fast"
|
||||
pi0_paligemma_path = "google/paligemma-3b-pt-224"
|
||||
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
|
||||
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
|
||||
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
|
||||
self.fast_skip_tokens = self.config.fast_skip_tokens
|
||||
self.max_input_seq_len = self.config.max_input_seq_len
|
||||
self.action_horizon = self.config.chunk_size
|
||||
self.action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
precision = config.precision
|
||||
torch_precision = PRECISION.get(precision, torch.float32)
|
||||
self.pad_token_id = (
|
||||
self.paligemma_tokenizer.pad_token_id
|
||||
if hasattr(self.paligemma_tokenizer, "pad_token_id")
|
||||
else self.paligemma_tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": precision,
|
||||
"vocab_size": 257152,
|
||||
"_attn_implementation": "eager",
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_pytorch_tanh",
|
||||
"torch_dtype": precision,
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||
|
||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
||||
)
|
||||
# change important stuff in bf16
|
||||
params_to_change_dtype = [
|
||||
"language_model",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.pi0_paligemma.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch_precision)
|
||||
self.set_requires_grad()
|
||||
self.image_keys = self.config.image_features.keys()
|
||||
# TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
|
||||
# AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
|
||||
self.ignore_index = self.pi0_paligemma.config.ignore_index
|
||||
self.padding_side = self.config.padding_side
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.pi0_paligemma.vision_tower.eval()
|
||||
for params in self.pi0_paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
if self.config.freeze_lm_head:
|
||||
for name, params in self.pi0_paligemma.named_parameters():
|
||||
if "embed_tokens" in name: # lm heads and embedding layer are tied
|
||||
params.requires_grad = False
|
||||
|
||||
def embed_tokens(self, tokens: torch.Tensor):
|
||||
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Preprocess LeRobot batch into Pi0 inputs"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.image_keys if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
num_empty_cameras = 0
|
||||
for key in self.image_keys:
|
||||
if key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(
|
||||
img,
|
||||
*self.config.resize_imgs_with_padding,
|
||||
pad_value=0,
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
else:
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
continue
|
||||
img = torch.ones_like(img) * -1
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
num_empty_cameras += 1
|
||||
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
|
||||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
return out
|
||||
|
||||
def fast_tokenizer_wrapper(self, actions_norm):
|
||||
"""
|
||||
A wrapper for self.fast_tokenizer that ensures batch processing,
|
||||
conversion to PyTorch tensors, and returns a dictionary without padding.
|
||||
"""
|
||||
batch_tokens = self.fast_tokenizer(actions_norm)
|
||||
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
|
||||
|
||||
return fast_out
|
||||
|
||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
# Compute cumulative sum mask
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
# Suffix block (everything after prefix_len)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
token_type_ids = suffix_mask
|
||||
return token_type_ids
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
||||
discretized = torch.bucketize(state, bins) - 1
|
||||
discretized = discretized[:, :32]
|
||||
|
||||
prefix_texts = []
|
||||
state_text = []
|
||||
for txt, disc in zip(lang_text, discretized, strict=False):
|
||||
cleaned = txt.lower().strip().replace("_", " ")
|
||||
state_str = " ".join(str(val.item()) for val in disc)
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
state_text.append(f"State: {state_str};\n")
|
||||
|
||||
prefix_out = self.paligemma_tokenizer(
|
||||
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
|
||||
)
|
||||
prefix_ids = prefix_out["input_ids"].to(device)
|
||||
prefix_mask = prefix_out["attention_mask"].to(device)
|
||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
||||
|
||||
if actions is not None:
|
||||
actions_norm = self.normalize_actions(actions)
|
||||
actions_pad = F.pad(
|
||||
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
|
||||
)[:, :, : self.config.max_action_dim]
|
||||
fast_out = self.fast_tokenizer_wrapper(
|
||||
actions_pad.cpu(),
|
||||
)
|
||||
act_ids = fast_out["input_ids"]
|
||||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||
# Replace action with 0 to pad tokens
|
||||
act_ids = torch.where(
|
||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||
self.pad_token_id,
|
||||
act_ids,
|
||||
)
|
||||
|
||||
eos_token = torch.tensor(
|
||||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
||||
).expand(bsize, -1)
|
||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
||||
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
|
||||
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
|
||||
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
act_mask = act_mask.to(device)
|
||||
else:
|
||||
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
|
||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
||||
|
||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
||||
|
||||
# Use tokenizer pad function
|
||||
padded_output = self.paligemma_tokenizer.pad(
|
||||
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
|
||||
)
|
||||
padded_mask = padded_output["attention_mask"]
|
||||
|
||||
# define tensor of padding lengths
|
||||
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
|
||||
|
||||
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
|
||||
|
||||
padded_output["padded_mask"] = padded_output.pop("attention_mask")
|
||||
padded_output["attention_mask"] = att_mask
|
||||
# loss is computed not on prefix, and not on padding
|
||||
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
|
||||
padded_output["token_type_ids"] = token_type_ids
|
||||
return padded_output
|
||||
|
||||
def shift_padding_side(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
ar_mask: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
token_type_ids: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> tuple[torch.Tensor]:
|
||||
if padding_side not in ["right", "left"]:
|
||||
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
|
||||
|
||||
new_tokens = torch.empty_like(tokens)
|
||||
new_ar_masks = torch.empty_like(ar_mask)
|
||||
new_padding_mask = torch.empty_like(padding_mask)
|
||||
new_loss_mask = torch.empty_like(loss_mask)
|
||||
new_targets = torch.empty_like(targets)
|
||||
new_token_type_ids = torch.empty_like(token_type_ids)
|
||||
batch_size = tokens.shape[0]
|
||||
for i in range(batch_size):
|
||||
padding_indices = torch.where(padding_mask[i] == 0)[0]
|
||||
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
|
||||
if padding_side == "left":
|
||||
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
|
||||
else:
|
||||
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
|
||||
new_tokens[i] = tokens[i].index_select(0, new_indices)
|
||||
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
|
||||
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
|
||||
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
|
||||
new_targets[i] = targets[i].index_select(0, new_indices)
|
||||
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
|
||||
|
||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]):
|
||||
device = batch[OBS_STATE].device
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(
|
||||
state=batch[OBS_STATE],
|
||||
lang_text=batch["task"],
|
||||
actions=batch[ACTION],
|
||||
)
|
||||
|
||||
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
past_seen_tokens = 0
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
|
||||
pad_masks = block_causal_update_causal_mask(
|
||||
attention_mask=pad_masks,
|
||||
past_key_values=None,
|
||||
cache_position=cache_position,
|
||||
input_tensor=embs,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.pi0_paligemma.dtype,
|
||||
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
|
||||
)
|
||||
outputs = self.pi0_paligemma.forward(
|
||||
input_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=False,
|
||||
labels=None,
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
|
||||
loss_fct = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Shift left for next-step prediction
|
||||
logits = logits[:, :-1, :]
|
||||
targets = targets[:, 1:].to(device) # Shift targets
|
||||
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
|
||||
|
||||
# Compute per-token loss
|
||||
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
|
||||
|
||||
# Apply loss mask
|
||||
token_loss = token_loss * loss_mask.reshape(-1)
|
||||
|
||||
# Compute final loss
|
||||
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
|
||||
|
||||
# Return loss dictionary
|
||||
loss_dict = {"ce_loss": loss.item(), "loss": loss}
|
||||
return loss_dict
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
relaxed_decoding: bool = True,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
"""
|
||||
self.time_horizon = (
|
||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
||||
)
|
||||
self.action_dim = (
|
||||
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
|
||||
)
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||
if relaxed_decoding:
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Extracts actions from predicted output tokens using the FAST model.
|
||||
|
||||
Args:
|
||||
tokens (torch.Tensor): The input tensor of tokenized outputs.
|
||||
action_horizon (int): The number of timesteps for actions.
|
||||
action_dim (int): The dimensionality of each action.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
|
||||
"""
|
||||
# Decode predicted output tokens
|
||||
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
||||
cleaned_tokens = [
|
||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
||||
for tokens_sequence in decoded_tokens
|
||||
]
|
||||
raw_action_tokens = [
|
||||
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
||||
for sample_tokens in cleaned_tokens
|
||||
] # something like this should be robust #looks good
|
||||
action_tokens = [
|
||||
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
||||
]
|
||||
# returns the tensor of decoded actions per sample in a list
|
||||
decoded_actions = [
|
||||
torch.tensor(
|
||||
self.decode_actions_with_fast(
|
||||
tok.tolist(),
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
relaxed_decoding=self.config.relaxed_action_decoding,
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
for tok in action_tokens
|
||||
]
|
||||
|
||||
return torch.stack(
|
||||
decoded_actions,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]):
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
|
||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side="left",
|
||||
)
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
output_tokens = self.pi0_paligemma.generate(
|
||||
input_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=self.config.use_cache,
|
||||
max_new_tokens=self.config.max_decoding_steps,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
||||
return actions
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Handle different transformers versions
|
||||
if hasattr(self.pi0_paligemma, "get_image_features"):
|
||||
return self.pi0_paligemma.get_image_features(image)
|
||||
else:
|
||||
return self.pi0_paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_inputs(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
pad_mask,
|
||||
ar_mask,
|
||||
loss_mask,
|
||||
token_type_ids,
|
||||
padding_side: str = "right",
|
||||
):
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
# images are a list of same size
|
||||
# vectorizing everything!
|
||||
device = images[0].device
|
||||
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
|
||||
all_images = torch.stack(images, dim=1).to(device)
|
||||
b, n, c, h, w = all_images.shape
|
||||
all_images = all_images.view(b * n, c, h, w)
|
||||
embedded = self.embed_image(all_images).to(device)
|
||||
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
|
||||
m = b_n // b # Compute the number of images per sample dynamically
|
||||
|
||||
# Reshape dynamically
|
||||
embedded = embedded.view(b, m, p, image_embedding_dim)
|
||||
tokens_embs = self.embed_tokens(tokens.to(device))
|
||||
|
||||
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
|
||||
num_img_emb = embedded.shape[2]
|
||||
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
|
||||
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
image_target_tokens = (
|
||||
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
|
||||
).reshape(b, -1)
|
||||
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
|
||||
|
||||
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
|
||||
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
|
||||
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
|
||||
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
|
||||
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
|
||||
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
|
||||
|
||||
# Shift pad tokens to the left (.generate()) or right (.train())
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
|
||||
)
|
||||
|
||||
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
|
||||
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
if interpolate_like_pi:
|
||||
img = (img * 255.0).to(dtype=torch.uint8)
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
original_device = img.device
|
||||
img = img.to(device="cpu").numpy()
|
||||
imgs = []
|
||||
for sub_img in img:
|
||||
sub_img = Image.fromarray(sub_img)
|
||||
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
|
||||
resized_img = torch.from_numpy(np.array(resized_img))
|
||||
imgs.append(resized_img)
|
||||
img = torch.stack(imgs, dim=0)
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
|
||||
else:
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
@@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_pre_post_processors(
|
||||
config: PI05Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
## Paper
|
||||
|
||||
https://arxiv.org/abs/2509.25358
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -1,870 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting.
|
||||
|
||||
This script processes all frames in a dataset with SARM to compute progress values [0, 1].
|
||||
The results are saved as a parquet file that can be loaded during training for RA-BC weighting.
|
||||
|
||||
Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only
|
||||
need ~num_frames/30 queries instead of one per frame (~30x speedup).
|
||||
|
||||
Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.gridspec as gridspec
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.policies.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read reward_model_path from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def load_sarm_resources(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
device: str = "cuda",
|
||||
) -> tuple[LeRobotDataset, SARMRewardModel, any]:
|
||||
"""
|
||||
Load SARM model, dataset, and preprocessor.
|
||||
|
||||
Returns:
|
||||
Tuple of (dataset, reward_model, preprocessor)
|
||||
"""
|
||||
logging.info(f"Loading model: {reward_model_path}")
|
||||
reward_model = SARMRewardModel.from_pretrained(reward_model_path)
|
||||
reward_model.config.device = device
|
||||
reward_model.to(device).eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
delta_indices = reward_model.config.observation_delta_indices
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
fps = temp_dataset.fps
|
||||
|
||||
delta_timestamps = {
|
||||
image_key: [idx / fps for idx in delta_indices],
|
||||
state_key: [idx / fps for idx in delta_indices],
|
||||
}
|
||||
dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
preprocess, _ = make_sarm_pre_post_processors(
|
||||
config=reward_model.config,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_meta=dataset.meta,
|
||||
)
|
||||
|
||||
return dataset, reward_model, preprocess
|
||||
|
||||
|
||||
def to_numpy_image(img) -> np.ndarray:
|
||||
"""Convert image tensor to numpy uint8 (H, W, C)."""
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
if img.ndim == 4:
|
||||
# Take center frame for bidirectional sampling
|
||||
img = img[img.shape[0] // 2]
|
||||
if img.shape[0] in [1, 3]:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
if img.dtype != np.uint8:
|
||||
# Handle normalized images (may have negative values or values > 1)
|
||||
img = img.astype(np.float32)
|
||||
img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1]
|
||||
img = (img * 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def visualize_episode(
|
||||
frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
|
||||
):
|
||||
"""Create visualization with progress plot, stage probabilities, and sample frames.
|
||||
|
||||
Same as sarm_inference_visualization.py
|
||||
"""
|
||||
num_stages = stage_preds.shape[1]
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
|
||||
frame_indices = np.arange(len(progress_preds))
|
||||
|
||||
fig = plt.figure(figsize=(14, 12))
|
||||
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
|
||||
ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])
|
||||
|
||||
# Progress plot
|
||||
ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
|
||||
ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
|
||||
if gt_progress is not None:
|
||||
ax_progress.plot(
|
||||
frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
|
||||
)
|
||||
ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
|
||||
ax_progress.set_ylabel("Progress")
|
||||
ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
|
||||
ax_progress.set_ylim(-0.05, 1.1)
|
||||
ax_progress.legend(loc="upper left")
|
||||
ax_progress.grid(True, alpha=0.3)
|
||||
|
||||
# Stage predictions
|
||||
ax_stages.stackplot(
|
||||
frame_indices,
|
||||
*[stage_preds[:, i] for i in range(num_stages)],
|
||||
colors=colors,
|
||||
alpha=0.8,
|
||||
labels=stage_labels,
|
||||
)
|
||||
if gt_stages is not None:
|
||||
for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
|
||||
ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
|
||||
ax_stages.set_xlabel("Frame")
|
||||
ax_stages.set_ylabel("Stage Probability")
|
||||
ax_stages.set_ylim(0, 1)
|
||||
ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
|
||||
ax_stages.grid(True, alpha=0.3)
|
||||
|
||||
# Sample frames
|
||||
ax_frames.axis("off")
|
||||
num_sample = 8
|
||||
sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
|
||||
h, w = frames[0].shape[:2]
|
||||
combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
|
||||
for i, idx in enumerate(sample_indices):
|
||||
frame = frames[idx]
|
||||
if frame.shape[-1] == 1:
|
||||
frame = np.repeat(frame, 3, axis=-1)
|
||||
combined[:, i * w : (i + 1) * w] = frame
|
||||
stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
|
||||
ax_frames.text(
|
||||
i * w + w / 2,
|
||||
-10,
|
||||
f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=7,
|
||||
)
|
||||
ax_frames.imshow(combined)
|
||||
ax_frames.set_title("Sample Frames", pad=20)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def visualize_sarm_predictions(
|
||||
dataset: LeRobotDataset,
|
||||
reward_model: SARMRewardModel,
|
||||
preprocess,
|
||||
episode_indices: list[int],
|
||||
head_mode: str,
|
||||
output_dir: Path,
|
||||
num_display_frames: int = 5,
|
||||
stride: int = 1,
|
||||
):
|
||||
"""
|
||||
Visualize SARM predictions for multiple episodes.
|
||||
|
||||
Computes predictions for every frame by default. With stride > 1, computes predictions
|
||||
every N frames and interpolates (progress + stage probabilities) for visualization.
|
||||
|
||||
Args:
|
||||
dataset: LeRobotDataset with delta_timestamps configured
|
||||
reward_model: Loaded SARM model
|
||||
preprocess: Preprocessor from make_sarm_pre_post_processors
|
||||
episode_indices: List of episode indices to visualize
|
||||
head_mode: "sparse", "dense", or "both"
|
||||
output_dir: Directory to save visualizations
|
||||
num_display_frames: Number of frames to display in thumbnail strip (default: 5)
|
||||
stride: Compute predictions every N frames, interpolate the rest (default: 1)
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
device = reward_model.device
|
||||
|
||||
# Center frame index for bidirectional sampling
|
||||
target_idx = reward_model.config.n_obs_steps // 2
|
||||
|
||||
# Determine which heads to visualize
|
||||
schemes_to_viz = []
|
||||
if head_mode in ("sparse", "both") or not dual_mode:
|
||||
schemes_to_viz.append("sparse")
|
||||
if head_mode in ("dense", "both") and dual_mode:
|
||||
schemes_to_viz.append("dense")
|
||||
|
||||
# Set preprocessor to eval mode to disable augmentations
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
for episode_idx in episode_indices:
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
num_frames = ep_end - ep_start
|
||||
|
||||
# Select frames for display thumbnails (evenly sampled from begin to end)
|
||||
display_indices = set(
|
||||
[
|
||||
ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
|
||||
for i in range(num_display_frames)
|
||||
]
|
||||
if num_frames >= num_display_frames
|
||||
else list(range(ep_start, ep_end))
|
||||
)
|
||||
viz_frames = {}
|
||||
|
||||
# Load display frames up-front (stride mode might skip them otherwise).
|
||||
for frame_idx in display_indices:
|
||||
sample = dataset[frame_idx]
|
||||
viz_frames[frame_idx] = to_numpy_image(sample[image_key])
|
||||
|
||||
# Initialize storage for each scheme
|
||||
scheme_data = {}
|
||||
for scheme in schemes_to_viz:
|
||||
num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
|
||||
scheme_data[scheme] = {
|
||||
"viz_progress": np.full(num_frames, np.nan),
|
||||
"viz_stages": np.full((num_frames, num_stages), np.nan),
|
||||
"viz_gt_progress": np.full(num_frames, np.nan),
|
||||
"viz_gt_stages": np.full(num_frames, np.nan),
|
||||
"target_key": f"{scheme}_targets",
|
||||
"num_stages": num_stages,
|
||||
"temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
|
||||
"subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
|
||||
}
|
||||
|
||||
if stride > 1:
|
||||
logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating")
|
||||
|
||||
# Process frames one at a time to avoid memory buildup
|
||||
frame_indices = list(range(ep_start, ep_end, stride))
|
||||
if (ep_end - 1) not in frame_indices:
|
||||
frame_indices.append(ep_end - 1)
|
||||
frame_indices = sorted(set(frame_indices))
|
||||
|
||||
for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
|
||||
local_idx = frame_idx - ep_start
|
||||
sample = dataset[frame_idx]
|
||||
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": frame_idx,
|
||||
"episode_index": episode_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
video_features = processed["video_features"].to(device)
|
||||
text_features = processed["text_features"].to(device)
|
||||
state_features = processed.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
|
||||
# Ground truth
|
||||
# In stride visualization mode, ground-truth plots can be misleading
|
||||
# (only sparse points are available), so we skip GT.
|
||||
if stride == 1 and sd["target_key"] in processed:
|
||||
gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
|
||||
sd["viz_gt_stages"][local_idx] = int(gt_target)
|
||||
sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
|
||||
gt_target,
|
||||
num_stages=sd["num_stages"],
|
||||
temporal_proportions=sd["temporal_props"],
|
||||
subtask_names=sd["subtask_names"],
|
||||
)
|
||||
|
||||
# Predictions
|
||||
reward, stage_probs = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
return_stages=True,
|
||||
head_mode=scheme,
|
||||
)
|
||||
|
||||
# Handle both tensor and numpy outputs
|
||||
if isinstance(reward, torch.Tensor):
|
||||
reward = reward.cpu().numpy()
|
||||
stage_probs = stage_probs.cpu().numpy()
|
||||
|
||||
if reward.ndim == 2:
|
||||
sd["viz_progress"][local_idx] = reward[0, target_idx]
|
||||
sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
|
||||
else:
|
||||
sd["viz_progress"][local_idx] = reward[target_idx]
|
||||
sd["viz_stages"][local_idx] = stage_probs[target_idx, :]
|
||||
|
||||
# Clear GPU memory after each frame
|
||||
del processed, video_features, text_features
|
||||
if state_features is not None:
|
||||
del state_features
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Interpolate predictions back to per-frame arrays for smooth visualization.
|
||||
if stride > 1:
|
||||
all_local = np.arange(num_frames)
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
|
||||
valid = np.isfinite(sd["viz_progress"])
|
||||
valid_idx = np.where(valid)[0]
|
||||
if valid_idx.size >= 1:
|
||||
sd["viz_progress"] = interpolate_progress(
|
||||
valid_idx, sd["viz_progress"][valid_idx], all_local
|
||||
)
|
||||
|
||||
stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
|
||||
for s in range(sd["num_stages"]):
|
||||
stage_interp[:, s] = interpolate_progress(
|
||||
valid_idx, sd["viz_stages"][valid_idx, s], all_local
|
||||
)
|
||||
|
||||
stage_interp = np.clip(stage_interp, 0.0, 1.0)
|
||||
row_sums = stage_interp.sum(axis=1, keepdims=True)
|
||||
nz = row_sums.squeeze(-1) > 0
|
||||
stage_interp[nz] = stage_interp[nz] / row_sums[nz]
|
||||
sd["viz_stages"] = stage_interp
|
||||
else:
|
||||
# No valid points: keep NaNs/zeros; visualization will be empty.
|
||||
sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)
|
||||
|
||||
# Generate visualization for each head
|
||||
ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
|
||||
viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"
|
||||
|
||||
visualize_episode(
|
||||
frames=np.array(ordered_viz_frames),
|
||||
progress_preds=sd["viz_progress"],
|
||||
stage_preds=sd["viz_stages"],
|
||||
title=f"{task} (Episode {episode_idx})",
|
||||
output_path=viz_path,
|
||||
stage_labels=stage_labels,
|
||||
gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
|
||||
gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
|
||||
)
|
||||
|
||||
# Clear memory between episodes
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logging.info(f"Visualizations saved to: {output_dir.absolute()}")
|
||||
|
||||
|
||||
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
|
||||
"""Generate all frame indices, ordered by offset for cache-friendly access.
|
||||
|
||||
Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
|
||||
This groups frames that share similar temporal windows together.
|
||||
"""
|
||||
num_frames = ep_end - ep_start
|
||||
indices = []
|
||||
for offset in range(frame_gap):
|
||||
for frame_rel in range(offset, num_frames, frame_gap):
|
||||
indices.append(ep_start + frame_rel)
|
||||
return indices
|
||||
|
||||
|
||||
def interpolate_progress(
|
||||
computed_indices: np.ndarray,
|
||||
computed_values: np.ndarray,
|
||||
all_indices: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
|
||||
computed_indices = np.asarray(computed_indices)
|
||||
computed_values = np.asarray(computed_values)
|
||||
all_indices = np.asarray(all_indices)
|
||||
|
||||
mask = np.isfinite(computed_values)
|
||||
if mask.sum() == 0:
|
||||
return np.full(all_indices.shape, np.nan, dtype=np.float32)
|
||||
if mask.sum() == 1:
|
||||
return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)
|
||||
|
||||
out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
def compute_sarm_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
output_path: str | None = None,
|
||||
head_mode: str = "sparse",
|
||||
device: str = "cuda",
|
||||
num_visualizations: int = 5,
|
||||
output_dir: str = "./sarm_viz",
|
||||
stride: int = 1,
|
||||
):
|
||||
"""
|
||||
Compute SARM progress predictions for all frames in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HuggingFace dataset repo ID or local path
|
||||
reward_model_path: Path to pretrained SARM model
|
||||
output_path: Path to save results. If None, saves to dataset's cache directory
|
||||
head_mode: SARM head to use ("sparse", "dense", or "both")
|
||||
device: Device to use for inference
|
||||
num_visualizations: Number of episodes to visualize (0 to skip)
|
||||
output_dir: Directory to save visualizations
|
||||
stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
|
||||
"""
|
||||
dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)
|
||||
|
||||
# Set preprocessor to eval mode to disable augmentations
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
frame_gap = reward_model.config.frame_gap
|
||||
num_episodes = dataset.num_episodes
|
||||
total_frames = dataset.num_frames
|
||||
logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")
|
||||
|
||||
# Determine which heads to compute
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
compute_sparse = head_mode in ("sparse", "both") or not dual_mode
|
||||
compute_dense = head_mode in ("dense", "both") and dual_mode
|
||||
|
||||
# Storage arrays
|
||||
all_indices = []
|
||||
all_episode_indices = []
|
||||
all_frame_indices = []
|
||||
all_progress_sparse = [] if compute_sparse else None
|
||||
all_progress_dense = [] if compute_dense else None
|
||||
|
||||
if stride > 1:
|
||||
logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")
|
||||
|
||||
# Process all episodes
|
||||
for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
|
||||
# Get task description
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
|
||||
# Generate frames to compute (with stride applied)
|
||||
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||
if stride > 1:
|
||||
# Only compute every stride-th frame (relative to episode start)
|
||||
compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
|
||||
# Always include last frame for better interpolation at episode end
|
||||
last_frame = ep_end - 1
|
||||
if last_frame not in compute_indices:
|
||||
compute_indices.append(last_frame)
|
||||
compute_indices = sorted(set(compute_indices))
|
||||
else:
|
||||
compute_indices = all_ep_indices
|
||||
|
||||
center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window
|
||||
|
||||
# Dictionary to collect results
|
||||
frame_results = {}
|
||||
|
||||
for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
|
||||
try:
|
||||
sample = dataset[query_idx]
|
||||
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": query_idx,
|
||||
"episode_index": episode_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
video_features = processed["video_features"].to(device)
|
||||
text_features = processed["text_features"].to(device)
|
||||
state_features = processed.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
sparse_val = np.nan
|
||||
dense_val = np.nan
|
||||
|
||||
# Compute sparse prediction for center frame
|
||||
if compute_sparse:
|
||||
sparse_progress = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="sparse",
|
||||
)
|
||||
sparse_val = float(
|
||||
sparse_progress[0, center_idx]
|
||||
if sparse_progress.ndim == 2
|
||||
else sparse_progress[center_idx]
|
||||
)
|
||||
|
||||
# Compute dense prediction for center frame
|
||||
if compute_dense:
|
||||
dense_progress = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="dense",
|
||||
)
|
||||
dense_val = float(
|
||||
dense_progress[0, center_idx]
|
||||
if dense_progress.ndim == 2
|
||||
else dense_progress[center_idx]
|
||||
)
|
||||
|
||||
frame_results[query_idx] = (sparse_val, dense_val)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to process frame {query_idx}: {e}")
|
||||
|
||||
# Interpolate to get values for all frames
|
||||
computed_indices = np.array(sorted(frame_results.keys()))
|
||||
computed_sparse = (
|
||||
np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
|
||||
)
|
||||
computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None
|
||||
|
||||
# All frame indices for this episode
|
||||
all_frame_idx_array = np.arange(ep_start, ep_end)
|
||||
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
# Interpolate progress values
|
||||
if compute_sparse:
|
||||
interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
|
||||
if compute_dense:
|
||||
interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
|
||||
else:
|
||||
# No interpolation needed
|
||||
interp_sparse = computed_sparse if compute_sparse else None
|
||||
interp_dense = computed_dense if compute_dense else None
|
||||
|
||||
# Store results for all frames
|
||||
for i, frame_idx in enumerate(all_frame_idx_array):
|
||||
local_idx = frame_idx - ep_start
|
||||
all_indices.append(frame_idx)
|
||||
all_episode_indices.append(episode_idx)
|
||||
all_frame_indices.append(local_idx)
|
||||
if compute_sparse:
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
all_progress_sparse.append(float(interp_sparse[i]))
|
||||
elif frame_idx in frame_results:
|
||||
all_progress_sparse.append(frame_results[frame_idx][0])
|
||||
else:
|
||||
all_progress_sparse.append(np.nan)
|
||||
if compute_dense:
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
all_progress_dense.append(float(interp_dense[i]))
|
||||
elif frame_idx in frame_results:
|
||||
all_progress_dense.append(frame_results[frame_idx][1])
|
||||
else:
|
||||
all_progress_dense.append(np.nan)
|
||||
|
||||
# Create output table
|
||||
table_data = {
|
||||
"index": np.array(all_indices, dtype=np.int64),
|
||||
"episode_index": np.array(all_episode_indices, dtype=np.int64),
|
||||
"frame_index": np.array(all_frame_indices, dtype=np.int64),
|
||||
}
|
||||
if compute_sparse:
|
||||
table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
|
||||
if compute_dense:
|
||||
table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)
|
||||
|
||||
# Sort by index
|
||||
df = pa.table(table_data).to_pandas()
|
||||
df = df.sort_values("index").reset_index(drop=True)
|
||||
final_table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
|
||||
# Add metadata with reward model path
|
||||
metadata = {b"reward_model_path": reward_model_path.encode()}
|
||||
final_table = final_table.replace_schema_metadata(metadata)
|
||||
|
||||
# Determine output path
|
||||
output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)
|
||||
|
||||
# Save
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(final_table, output_path)
|
||||
logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")
|
||||
|
||||
# Print statistics
|
||||
if "progress_sparse" in df.columns:
|
||||
valid = df["progress_sparse"].dropna()
|
||||
logging.info(
|
||||
f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
|
||||
f"min={valid.min():.4f}, max={valid.max():.4f}"
|
||||
)
|
||||
|
||||
if "progress_dense" in df.columns:
|
||||
valid = df["progress_dense"].dropna()
|
||||
logging.info(
|
||||
f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
|
||||
f"min={valid.min():.4f}, max={valid.max():.4f}"
|
||||
)
|
||||
|
||||
# Visualize episodes after processing
|
||||
if num_visualizations > 0:
|
||||
viz_episodes = list(range(min(num_visualizations, num_episodes)))
|
||||
logging.info(f"Generating {len(viz_episodes)} visualizations...")
|
||||
visualize_sarm_predictions(
|
||||
dataset=dataset,
|
||||
reward_model=reward_model,
|
||||
preprocess=preprocess,
|
||||
episode_indices=viz_episodes,
|
||||
head_mode=head_mode,
|
||||
output_dir=Path(output_dir),
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset repo ID or local path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output path for parquet. If not set, saves to dataset's cache directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--head-mode",
|
||||
type=str,
|
||||
default="sparse",
|
||||
choices=["sparse", "dense", "both"],
|
||||
help="SARM head to use (default: sparse)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device to use (default: cuda)",
|
||||
)
|
||||
# Visualization options
|
||||
parser.add_argument(
|
||||
"--visualize-only",
|
||||
action="store_true",
|
||||
help="Only visualize SARM predictions (no RA-BC computation)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-visualizations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of episodes to visualize (default: 5, set to 0 to skip)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./sarm_viz",
|
||||
help="Output directory for visualizations (default: ./sarm_viz)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload progress file to the dataset repo on HuggingFace Hub",
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
# Try to get reward_model_path from parquet metadata if not provided
|
||||
reward_model_path = args.reward_model_path
|
||||
if reward_model_path is None:
|
||||
# Load dataset to find parquet path
|
||||
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
|
||||
parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
|
||||
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
|
||||
if reward_model_path:
|
||||
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"--reward-model-path is required (no existing parquet with model metadata found)"
|
||||
)
|
||||
|
||||
# Handle visualize-only mode
|
||||
if args.visualize_only:
|
||||
dataset, reward_model, preprocess = load_sarm_resources(
|
||||
args.dataset_repo_id, reward_model_path, args.device
|
||||
)
|
||||
logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
|
||||
viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
|
||||
visualize_sarm_predictions(
|
||||
dataset=dataset,
|
||||
reward_model=reward_model,
|
||||
preprocess=preprocess,
|
||||
episode_indices=viz_episodes,
|
||||
head_mode=args.head_mode,
|
||||
output_dir=Path(args.output_dir),
|
||||
stride=args.stride,
|
||||
)
|
||||
print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
|
||||
return
|
||||
|
||||
# Full RABC computation (compute_sarm_progress loads model/dataset itself)
|
||||
output_path = compute_sarm_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=reward_model_path,
|
||||
output_path=args.output_path,
|
||||
head_mode=args.head_mode,
|
||||
device=args.device,
|
||||
num_visualizations=args.num_visualizations,
|
||||
output_dir=args.output_dir,
|
||||
stride=args.stride,
|
||||
)
|
||||
|
||||
print(f"\nSARM progress values saved to: {output_path}")
|
||||
|
||||
# Upload to Hub if requested
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = "sarm_progress.parquet"
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse # or dense")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse # or dense")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
|
||||
Paper: https://arxiv.org/abs/2509.25358
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
Supports three annotation modes:
|
||||
|
||||
1. single_stage (default): No annotations needed. Uses the episode's task description
|
||||
as a single stage covering the entire episode.
|
||||
|
||||
2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated
|
||||
single sparse "task" stage covering the full episode. The dense head learns detailed
|
||||
subtask progression while sparse provides overall task completion.
|
||||
|
||||
3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
|
||||
annotations from VLM. Both heads are trained on their respective annotations.
|
||||
|
||||
The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
|
||||
are loaded/generated during model initialization.
|
||||
"""
|
||||
|
||||
annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual"
|
||||
n_obs_steps: int = 8 # Number of observation history steps
|
||||
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
|
||||
max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation
|
||||
|
||||
# Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
|
||||
# During training with rewind: [obs_frames] + [rewind_frames]
|
||||
# During inference: [obs_frames] only
|
||||
|
||||
# Architecture params
|
||||
image_dim: int = 512
|
||||
text_dim: int = 512
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
max_state_dim: int = 32
|
||||
drop_n_last_frames: int = 1
|
||||
batch_size: int = 64
|
||||
clip_batch_size: int = 64
|
||||
dropout: float = 0.1
|
||||
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||
|
||||
rewind_probability: float = 0.8
|
||||
language_perturbation_probability: float = 0.2
|
||||
|
||||
# Sparse annotations (high-level stages)
|
||||
num_sparse_stages: int = 1
|
||||
sparse_subtask_names: list | None = None
|
||||
sparse_temporal_proportions: list | None = None
|
||||
|
||||
# Dense annotations (fine-grained stages)
|
||||
num_dense_stages: int | None = None
|
||||
dense_subtask_names: list | None = None
|
||||
dense_temporal_proportions: list | None = None
|
||||
|
||||
pretrained_model_path: str | None = None
|
||||
device: str | None = None
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
state_key: str = "observation.state"
|
||||
|
||||
# Populated by the processor (video_features, state_features, text_features)
|
||||
input_features: dict = field(default_factory=lambda: {})
|
||||
|
||||
# Output features (updated in __post_init__)
|
||||
output_features: dict = field(
|
||||
default_factory=lambda: {
|
||||
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
|
||||
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
|
||||
}
|
||||
)
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"LANGUAGE": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
|
||||
raise ValueError(
|
||||
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
|
||||
)
|
||||
|
||||
if self.annotation_mode == "single_stage":
|
||||
# Use task description as stage name, full episode as one stage
|
||||
self.num_sparse_stages = 1
|
||||
self.sparse_subtask_names = ["task"]
|
||||
self.sparse_temporal_proportions = [1.0]
|
||||
self.num_dense_stages = None
|
||||
self.dense_subtask_names = None
|
||||
self.dense_temporal_proportions = None
|
||||
|
||||
elif self.annotation_mode == "dense_only":
|
||||
self.num_sparse_stages = 1
|
||||
self.sparse_subtask_names = ["task"]
|
||||
self.sparse_temporal_proportions = [1.0]
|
||||
|
||||
self.input_features = {}
|
||||
self.output_features = {}
|
||||
|
||||
if self.image_key:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)
|
||||
|
||||
self.input_features[self.state_key] = PolicyFeature(
|
||||
shape=(self.max_state_dim,),
|
||||
type=FeatureType.STATE,
|
||||
)
|
||||
|
||||
# Update output features based on annotation_mode
|
||||
if self.annotation_mode in ["dense_only", "dual"]:
|
||||
self.output_features["sparse_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["sparse_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
dense_stages = self.num_dense_stages or self.num_sparse_stages
|
||||
self.output_features["dense_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["dense_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
else:
|
||||
self.output_features["sparse_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["sparse_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
|
||||
if self.max_rewind_steps >= self.n_obs_steps:
|
||||
raise ValueError(
|
||||
f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
|
||||
)
|
||||
if self.num_sparse_stages < 1:
|
||||
raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
|
||||
if (
|
||||
self.annotation_mode in ["dense_only", "dual"]
|
||||
and self.num_dense_stages is not None
|
||||
and self.num_dense_stages < 2
|
||||
):
|
||||
raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Get default optimizer configuration for SARM training."""
|
||||
return AdamWConfig(
|
||||
lr=5e-5,
|
||||
weight_decay=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Get default learning rate scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=5e-5,
|
||||
decay_lr=5e-6,
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def uses_dual_heads(self) -> bool:
|
||||
"""Whether the model uses dual heads (dense_only or dual annotation modes)."""
|
||||
return self.annotation_mode in ["dense_only", "dual"]
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Total number of frames in sequence.
|
||||
|
||||
For training: 1 + n_obs_steps + max_rewind_steps
|
||||
The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
|
||||
"""
|
||||
return 1 + self.n_obs_steps + self.max_rewind_steps
|
||||
|
||||
@property
|
||||
def max_length(self) -> int:
|
||||
return self.num_frames
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Bidirectional frame sampling centered on target frame.
|
||||
|
||||
Example with n_obs_steps=8, gap=30:
|
||||
Before: [-120, -90, -60, -30] (4 frames)
|
||||
Current: [0] (1 frame)
|
||||
After: [30, 60, 90, 120] (4 frames)
|
||||
Total: 9 frames
|
||||
"""
|
||||
half_steps = self.n_obs_steps // 2
|
||||
|
||||
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
|
||||
obs_deltas = past_deltas + [0] + future_deltas
|
||||
|
||||
# Rewind placeholders
|
||||
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
|
||||
|
||||
return obs_deltas + rewind_deltas
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
"""SARM is a reward model, not an action policy."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,793 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
|
||||
|
||||
Paper: https://arxiv.org/abs/2509.25358
|
||||
|
||||
- StageTransformer: Predicts stage classification (sparse/dense)
|
||||
- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
"""
|
||||
Stage classification transformer for SARM.
|
||||
|
||||
Predicts which stage/subtask the current frame belongs to.
|
||||
Supports both sparse (high-level) and dense (fine-grained) annotation schemes.
|
||||
|
||||
Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
|
||||
Output: stage logits (B, T, num_classes)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 512,
|
||||
vis_emb_dim: int = 512,
|
||||
text_emb_dim: int = 512,
|
||||
state_dim: int = 32,
|
||||
n_layers: int = 6,
|
||||
n_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
num_cameras: int = 1,
|
||||
num_classes_sparse: int = 4,
|
||||
num_classes_dense: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_cameras = num_cameras
|
||||
|
||||
# Projections
|
||||
self.lang_proj = nn.Linear(text_emb_dim, d_model)
|
||||
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
|
||||
self.state_proj = nn.Linear(state_dim, d_model)
|
||||
|
||||
# Encoder
|
||||
enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
|
||||
self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
|
||||
|
||||
# Positional bias on first visual frame
|
||||
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
|
||||
|
||||
# Shared fusion MLP
|
||||
# Fuses (num_cameras + 2) streams: cameras + lang + state
|
||||
fused_in = d_model * (num_cameras + 2)
|
||||
self.fusion_backbone = nn.Sequential(
|
||||
nn.LayerNorm(fused_in),
|
||||
nn.Linear(fused_in, d_model),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Scheme-specific heads
|
||||
self.heads = nn.ModuleDict(
|
||||
{
|
||||
"sparse": nn.Linear(d_model, num_classes_sparse),
|
||||
"dense": nn.Linear(d_model, num_classes_dense),
|
||||
}
|
||||
)
|
||||
|
||||
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
|
||||
"""
|
||||
Prepare language embeddings for fusion.
|
||||
|
||||
Accepts lang_emb of shape:
|
||||
- (B, text_emb_dim) -> broadcast across time
|
||||
- (B, T, text_emb_dim) -> per-timestep (dense annotation mode)
|
||||
|
||||
Returns: (B, 1, T, D)
|
||||
"""
|
||||
if lang_emb.dim() == 3:
|
||||
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
|
||||
lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
|
||||
else:
|
||||
# (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
|
||||
lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
|
||||
return lang_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,) - valid sequence lengths
|
||||
scheme: str = "sparse", # "sparse" or "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for stage classification.
|
||||
|
||||
Args:
|
||||
img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
|
||||
lang_emb: Language embeddings (B, E) or (B, T, E) for dense
|
||||
state: State features (B, T, state_dim)
|
||||
lengths: Valid sequence lengths (B,) for masking
|
||||
scheme: "sparse" or "dense" for head selection
|
||||
|
||||
Returns:
|
||||
Stage logits (B, T, num_classes)
|
||||
"""
|
||||
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
|
||||
|
||||
B, N, T, _ = img_seq.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
device = img_seq.device
|
||||
|
||||
# Project inputs
|
||||
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
|
||||
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
|
||||
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
|
||||
|
||||
# Concatenate streams
|
||||
# cameras + lang + state -> (B, N+2, T, D)
|
||||
x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)
|
||||
|
||||
# Add positional bias to first visual frame
|
||||
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
|
||||
|
||||
# Flatten to tokens for Transformer
|
||||
x_tokens = x.view(B, (N + 2) * T, D)
|
||||
L = x_tokens.size(1) # noqa: N806
|
||||
|
||||
# Create padding mask
|
||||
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T)
|
||||
mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)
|
||||
|
||||
# Create causal mask
|
||||
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
# Encode
|
||||
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
|
||||
|
||||
# Reshape and fuse
|
||||
h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
|
||||
fused = self.fusion_backbone(h) # (B, T, D)
|
||||
|
||||
# Scheme-specific logits
|
||||
logits = self.heads[scheme](fused) # (B, T, num_classes)
|
||||
return logits
|
||||
|
||||
|
||||
class SubtaskTransformer(nn.Module):
|
||||
"""
|
||||
Subtask progress regression transformer for SARM.
|
||||
|
||||
Predicts within-stage normalized progress (tau) conditioned on stage prior.
|
||||
The stage prior is a one-hot encoding passed from StageTransformer predictions.
|
||||
|
||||
Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
|
||||
Output: tau predictions (B, T) in [0, 1]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 512,
|
||||
vis_emb_dim: int = 512,
|
||||
text_emb_dim: int = 512,
|
||||
state_dim: int = 32,
|
||||
n_layers: int = 6,
|
||||
n_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
num_cameras: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_cameras = num_cameras
|
||||
|
||||
# Projections
|
||||
self.lang_proj = nn.Linear(text_emb_dim, d_model)
|
||||
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
|
||||
self.state_proj = nn.Linear(state_dim, d_model)
|
||||
|
||||
# Encoder
|
||||
enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
|
||||
self.transformer = nn.TransformerEncoder(enc, n_layers)
|
||||
|
||||
# Learned bias on first visual frame
|
||||
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
|
||||
|
||||
# Shared fusion backbone
|
||||
# Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
|
||||
fused_in = d_model * (num_cameras + 3)
|
||||
self.fusion_backbone = nn.Sequential(
|
||||
nn.LayerNorm(fused_in),
|
||||
nn.Linear(fused_in, d_model),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Scheme-specific regression heads
|
||||
self.heads = nn.ModuleDict(
|
||||
{
|
||||
"sparse": nn.Linear(d_model, 1),
|
||||
"dense": nn.Linear(d_model, 1),
|
||||
}
|
||||
)
|
||||
|
||||
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
|
||||
"""
|
||||
Prepare language embeddings for fusion.
|
||||
"""
|
||||
if lang_emb.dim() == 3:
|
||||
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
|
||||
return self.lang_proj(lang_emb).unsqueeze(1)
|
||||
else:
|
||||
# (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
|
||||
return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
|
||||
|
||||
def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Deterministic projection of one-hot stage to d_model by pad/truncate.
|
||||
|
||||
Args:
|
||||
stage_prior: One-hot stage embedding (B, 1, T, C)
|
||||
|
||||
Returns:
|
||||
Projected stage embedding (B, 1, T, d_model)
|
||||
"""
|
||||
B, one, T, C = stage_prior.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
if D == C:
|
||||
return stage_prior
|
||||
elif D > C:
|
||||
pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
|
||||
return torch.cat([stage_prior, pad], dim=-1)
|
||||
else:
|
||||
return stage_prior[..., :D]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,) - valid sequence lengths
|
||||
stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb
|
||||
scheme: str = "sparse", # "sparse" or "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for subtask progress regression.
|
||||
|
||||
Args:
|
||||
img_seq: Image embeddings (B, N, T, vis_emb_dim)
|
||||
lang_emb: Language embeddings (B, E) or (B, T, E)
|
||||
state: State features (B, T, state_dim)
|
||||
lengths: Valid sequence lengths (B,) for masking
|
||||
stage_prior: One-hot stage prior (B, 1, T, num_classes)
|
||||
scheme: "sparse" or "dense" for head selection
|
||||
|
||||
Returns:
|
||||
Tau predictions (B, T) in [0, 1] via sigmoid
|
||||
"""
|
||||
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
|
||||
|
||||
B, N, T, _ = img_seq.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
device = img_seq.device
|
||||
|
||||
# Project inputs
|
||||
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
|
||||
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
|
||||
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
|
||||
stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D)
|
||||
|
||||
# Concatenate all streams
|
||||
# cameras + lang + state + stage_emb -> (B, N+3, T, D)
|
||||
x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)
|
||||
|
||||
# Add positional bias to first visual frame
|
||||
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
|
||||
|
||||
# Flatten to tokens
|
||||
x_tokens = x.view(B, (N + 3) * T, D)
|
||||
L = x_tokens.size(1) # noqa: N806
|
||||
|
||||
# Create padding mask
|
||||
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
|
||||
mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)
|
||||
|
||||
# Create causal mask
|
||||
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
# Encode
|
||||
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
|
||||
|
||||
# Reshape and fuse
|
||||
h = h.view(B, N + 3, T, D)
|
||||
h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
|
||||
fused = self.fusion_backbone(h_flat) # (B, T, D)
|
||||
|
||||
# Scheme-specific regression head -> sigmoid
|
||||
r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T)
|
||||
return r
|
||||
|
||||
|
||||
def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Generate one-hot stage embeddings from targets.
|
||||
|
||||
Args:
|
||||
num_classes: Number of stage classes
|
||||
targets: Target values (B, T) where integer part is stage index
|
||||
|
||||
Returns:
|
||||
One-hot stage embedding (B, 1, T, num_classes)
|
||||
"""
|
||||
# Integer part of float targets -> [0, C-1]
|
||||
idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T)
|
||||
C = num_classes # noqa: N806
|
||||
# Identity-lookup one-hot
|
||||
stage_onehot = torch.eye(C, device=targets.device)[idx] # (B, T, C)
|
||||
stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
return stage_onehot
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
Uses two separate transformer models:
|
||||
- StageTransformer: Classifies which stage/subtask
|
||||
- SubtaskTransformer: Predicts within-stage progress (tau)
|
||||
|
||||
Training uses 75%/25% GT/predicted stage conditioning (teacher forcing).
|
||||
"""
|
||||
|
||||
name = "sarm"
|
||||
config_class = SARMConfig
|
||||
|
||||
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(
|
||||
config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
# Load temporal proportions based on annotation_mode
|
||||
if config.annotation_mode == "single_stage":
|
||||
logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
|
||||
elif dataset_meta is not None:
|
||||
self._load_temporal_proportions(dataset_meta)
|
||||
|
||||
# Create two separate models
|
||||
self.stage_model = StageTransformer(
|
||||
d_model=config.hidden_dim,
|
||||
vis_emb_dim=config.image_dim,
|
||||
text_emb_dim=config.text_dim,
|
||||
state_dim=config.max_state_dim,
|
||||
n_layers=config.num_layers,
|
||||
n_heads=config.num_heads,
|
||||
dropout=config.dropout,
|
||||
num_cameras=1, # Single camera for now
|
||||
num_classes_sparse=config.num_sparse_stages,
|
||||
num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
|
||||
)
|
||||
|
||||
self.subtask_model = SubtaskTransformer(
|
||||
d_model=config.hidden_dim,
|
||||
vis_emb_dim=config.image_dim,
|
||||
text_emb_dim=config.text_dim,
|
||||
state_dim=config.max_state_dim,
|
||||
n_layers=config.num_layers,
|
||||
n_heads=config.num_heads,
|
||||
dropout=config.dropout,
|
||||
num_cameras=1,
|
||||
)
|
||||
|
||||
self.stage_model.to(self.device)
|
||||
self.subtask_model.to(self.device)
|
||||
|
||||
# GT/predicted stage ratio for teacher forcing
|
||||
self.gt_stage_ratio = 0.75
|
||||
|
||||
if config.uses_dual_heads:
|
||||
logging.info(
|
||||
f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
|
||||
f"{config.num_dense_stages} dense stages"
|
||||
)
|
||||
else:
|
||||
logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")
|
||||
|
||||
logging.info(f"SARM initialized on {self.device}")
|
||||
|
||||
def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
|
||||
"""Load temporal proportions from a JSON file (preserving order)."""
|
||||
if not path.exists():
|
||||
raise ValueError(
|
||||
f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
|
||||
f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
|
||||
)
|
||||
with open(path) as f:
|
||||
proportions_dict = json.load(f)
|
||||
names = list(proportions_dict.keys())
|
||||
logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
|
||||
logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
|
||||
return names, [proportions_dict[name] for name in names]
|
||||
|
||||
def _load_temporal_proportions(self, dataset_meta) -> None:
|
||||
"""Load temporal proportions based on annotation_mode."""
|
||||
meta_path = dataset_meta.root / "meta"
|
||||
|
||||
if self.config.annotation_mode == "dual":
|
||||
names, props = self._load_proportions_from_json(
|
||||
meta_path / "temporal_proportions_sparse.json", "sparse"
|
||||
)
|
||||
(
|
||||
self.config.num_sparse_stages,
|
||||
self.config.sparse_subtask_names,
|
||||
self.config.sparse_temporal_proportions,
|
||||
) = len(names), names, props
|
||||
|
||||
if self.config.annotation_mode in ["dense_only", "dual"]:
|
||||
names, props = self._load_proportions_from_json(
|
||||
meta_path / "temporal_proportions_dense.json", "dense"
|
||||
)
|
||||
(
|
||||
self.config.num_dense_stages,
|
||||
self.config.dense_subtask_names,
|
||||
self.config.dense_temporal_proportions,
|
||||
) = len(names), names, props
|
||||
if self.config.annotation_mode == "dense_only":
|
||||
logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
super().to(device)
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
self.stage_model.to(device)
|
||||
self.subtask_model.to(device)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
text_embeddings: np.ndarray | torch.Tensor,
|
||||
video_embeddings: np.ndarray | torch.Tensor,
|
||||
state_features: np.ndarray | torch.Tensor | None = None,
|
||||
lengths: np.ndarray | torch.Tensor | None = None,
|
||||
return_all_frames: bool = False,
|
||||
return_stages: bool = False,
|
||||
return_confidence: bool = False,
|
||||
head_mode: str | None = "sparse",
|
||||
frame_index: int | None = None,
|
||||
) -> np.ndarray | tuple:
|
||||
"""
|
||||
Calculate rewards for given text, video, and state representations.
|
||||
|
||||
This is the canonical method for SARM reward computation, used for:
|
||||
- Inference/visualization
|
||||
- RA-BC weight computation
|
||||
|
||||
Args:
|
||||
text_embeddings: Encoded text representations (batch_size, 512)
|
||||
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
|
||||
state_features: Joint state features (batch_size, num_frames, state_dim)
|
||||
lengths: Valid sequence lengths (batch_size,)
|
||||
return_all_frames: If True, return rewards for all frames
|
||||
return_stages: If True, also return stage predictions
|
||||
return_confidence: If True, also return stage confidence
|
||||
head_mode: Which head to use ("sparse" or "dense")
|
||||
frame_index: Index of the target frame to extract (default: n_obs_steps).
|
||||
|
||||
Returns:
|
||||
Rewards and optionally stage probs/confidence.
|
||||
"""
|
||||
if isinstance(text_embeddings, np.ndarray):
|
||||
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
|
||||
if isinstance(video_embeddings, np.ndarray):
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
if state_features is not None and isinstance(state_features, np.ndarray):
|
||||
state_features = torch.tensor(state_features, dtype=torch.float32)
|
||||
|
||||
# Handle single sample case
|
||||
if text_embeddings.dim() == 1:
|
||||
text_embeddings = text_embeddings.unsqueeze(0)
|
||||
video_embeddings = video_embeddings.unsqueeze(0)
|
||||
if state_features is not None:
|
||||
state_features = state_features.unsqueeze(0)
|
||||
single_sample = True
|
||||
else:
|
||||
single_sample = False
|
||||
|
||||
batch_size = video_embeddings.shape[0]
|
||||
seq_len = video_embeddings.shape[1]
|
||||
|
||||
scheme = head_mode
|
||||
|
||||
# Default lengths if not provided
|
||||
if lengths is None:
|
||||
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
|
||||
elif isinstance(lengths, np.ndarray):
|
||||
lengths = torch.tensor(lengths, dtype=torch.int32)
|
||||
|
||||
# Reshape video to (B, N, T, D) for multi-camera format
|
||||
# Currently single camera: (B, T, D) -> (B, 1, T, D)
|
||||
img_seq = video_embeddings.unsqueeze(1).to(self.device)
|
||||
lang_emb = text_embeddings.to(self.device)
|
||||
state = (
|
||||
state_features.to(self.device)
|
||||
if state_features is not None
|
||||
else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
|
||||
)
|
||||
lens = lengths.to(self.device)
|
||||
|
||||
# Pad state to max_state_dim
|
||||
state = pad_state_to_max_dim(state, self.config.max_state_dim)
|
||||
|
||||
# Get num_classes for this scheme
|
||||
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
|
||||
|
||||
# Run stage model
|
||||
stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
|
||||
stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes)
|
||||
stage_idx = stage_probs.argmax(dim=-1) # (B, T)
|
||||
stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T)
|
||||
|
||||
# Create one-hot stage prior
|
||||
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
|
||||
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
|
||||
# Run subtask model
|
||||
tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)
|
||||
|
||||
# Compute final reward: stage + tau
|
||||
raw_reward = stage_idx.float() + tau_pred # (B, T)
|
||||
|
||||
# Normalize to [0, 1] using temporal proportions for proper weighting
|
||||
if scheme == "sparse":
|
||||
normalized_reward = normalize_stage_tau(
|
||||
raw_reward,
|
||||
num_stages=num_classes,
|
||||
temporal_proportions=self.config.sparse_temporal_proportions,
|
||||
subtask_names=self.config.sparse_subtask_names,
|
||||
)
|
||||
else:
|
||||
normalized_reward = normalize_stage_tau(
|
||||
raw_reward,
|
||||
num_stages=num_classes,
|
||||
temporal_proportions=self.config.dense_temporal_proportions,
|
||||
subtask_names=self.config.dense_subtask_names,
|
||||
)
|
||||
|
||||
# Default frame index is n_obs_steps (last observation frame)
|
||||
if frame_index is None:
|
||||
frame_index = self.config.n_obs_steps
|
||||
|
||||
# Prepare outputs (batch mode or no smoothing)
|
||||
if return_all_frames:
|
||||
rewards = normalized_reward.cpu().numpy()
|
||||
else:
|
||||
rewards = normalized_reward[:, frame_index].cpu().numpy()
|
||||
|
||||
if single_sample:
|
||||
rewards = rewards[0] if not return_all_frames else rewards[0]
|
||||
|
||||
outputs = [rewards]
|
||||
if return_stages:
|
||||
probs = stage_probs.cpu().numpy()
|
||||
if single_sample:
|
||||
probs = probs[0]
|
||||
outputs.append(probs)
|
||||
if return_confidence:
|
||||
conf = stage_conf.cpu().numpy()
|
||||
if single_sample:
|
||||
conf = conf[0]
|
||||
outputs.append(conf)
|
||||
|
||||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode for both models."""
|
||||
super().train(mode)
|
||||
self.stage_model.train(mode)
|
||||
self.subtask_model.train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode for both models."""
|
||||
return self.train(False)
|
||||
|
||||
def parameters(self):
|
||||
"""Override to return trainable parameters from both models."""
|
||||
from itertools import chain
|
||||
|
||||
return chain(self.stage_model.parameters(), self.subtask_model.parameters())
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Override to return optimizer parameters from both models."""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _train_step(
|
||||
self,
|
||||
img_emb: torch.Tensor, # (B, N, T, D)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,)
|
||||
targets: torch.Tensor, # (B, T) - format: stage.tau
|
||||
scheme: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Single training step for one annotation scheme.
|
||||
|
||||
Implements 75%/25% GT/predicted stage conditioning.
|
||||
|
||||
Args:
|
||||
img_emb: Image embeddings (B, N, T, D)
|
||||
lang_emb: Language embeddings
|
||||
state: State features
|
||||
lengths: Valid sequence lengths
|
||||
targets: Target values where floor=stage, remainder=tau
|
||||
scheme: "sparse" or "dense"
|
||||
|
||||
Returns:
|
||||
Dict with stage_loss, subtask_loss, total_loss
|
||||
"""
|
||||
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
|
||||
|
||||
# Ground truth: stage (integer) and tau (fractional)
|
||||
# Clamp stage indices to valid range [0, num_classes-1] to handle edge cases
|
||||
# where targets may exceed expected range (e.g., frames between subtasks)
|
||||
gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T)
|
||||
gt_tau = torch.remainder(targets, 1.0) # (B, T)
|
||||
|
||||
# Run stage model
|
||||
stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)
|
||||
|
||||
# 75%/25% GT/predicted stage conditioning
|
||||
if random.random() < self.gt_stage_ratio:
|
||||
# Mode 1: Use ground truth stage -> one-hot
|
||||
stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C)
|
||||
else:
|
||||
# Mode 2: Use predicted stage argmax -> one-hot
|
||||
stage_idx = stage_pred.argmax(dim=-1) # (B, T)
|
||||
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
|
||||
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
|
||||
# Run subtask model with stage prior
|
||||
tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)
|
||||
|
||||
# Compute losses
|
||||
stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
|
||||
subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")
|
||||
|
||||
return {
|
||||
"stage_loss": stage_loss,
|
||||
"subtask_loss": subtask_loss,
|
||||
"total_loss": stage_loss + subtask_loss,
|
||||
}
|
||||
|
||||
def forward(self, batch):
|
||||
"""
|
||||
Forward pass for SARM reward model training.
|
||||
|
||||
Uses stage+tau target format where:
|
||||
- Integer part = stage index
|
||||
- Fractional part = within-stage progress (tau)
|
||||
|
||||
Training uses 75%/25% GT/predicted stage conditioning.
|
||||
|
||||
Args:
|
||||
batch: Dictionary with 'observation' containing:
|
||||
- 'video_features': (B, T, 512) pre-encoded video features
|
||||
- 'text_features': (B, 512) or (B, T, 512) text features
|
||||
- 'state_features': (B, T, state_dim) joint state features
|
||||
- 'lengths': (B,) valid sequence lengths
|
||||
- 'sparse_targets': (B, T) sparse targets (stage.tau format)
|
||||
- 'dense_targets': (B, T) dense targets (optional, for dual mode)
|
||||
|
||||
Returns:
|
||||
Tuple of (total_loss, output_dict with loss components)
|
||||
"""
|
||||
observation = batch.get("observation", batch)
|
||||
|
||||
# Extract features
|
||||
video_features = observation["video_features"].to(self.device)
|
||||
text_features = observation["text_features"].to(self.device)
|
||||
state_features = observation.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(self.device)
|
||||
|
||||
batch_size = video_features.shape[0]
|
||||
seq_len = video_features.shape[1]
|
||||
|
||||
# Get lengths (default to full sequence)
|
||||
lengths = observation.get("lengths")
|
||||
if lengths is None:
|
||||
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
|
||||
else:
|
||||
lengths = lengths.to(self.device)
|
||||
|
||||
# Reshape video to (B, N, T, D) - single camera
|
||||
img_emb = video_features.unsqueeze(1)
|
||||
|
||||
# Pad state to max_state_dim
|
||||
if state_features is None:
|
||||
state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
|
||||
else:
|
||||
state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)
|
||||
|
||||
output_dict = {}
|
||||
total_loss = torch.tensor(0.0, device=self.device)
|
||||
|
||||
# Sparse training (always)
|
||||
sparse_targets = observation.get("sparse_targets")
|
||||
if sparse_targets is None:
|
||||
# Try legacy format
|
||||
sparse_targets = observation.get("targets")
|
||||
if sparse_targets is None:
|
||||
raise ValueError("sparse_targets (or targets) is required for SARM training")
|
||||
sparse_targets = sparse_targets.to(self.device)
|
||||
|
||||
sparse_result = self._train_step(
|
||||
img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
|
||||
)
|
||||
output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
|
||||
output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
|
||||
total_loss = total_loss + sparse_result["total_loss"]
|
||||
|
||||
# Dense training (if dual mode)
|
||||
if self.config.uses_dual_heads:
|
||||
dense_targets = observation.get("dense_targets")
|
||||
if dense_targets is not None:
|
||||
dense_targets = dense_targets.to(self.device)
|
||||
dense_result = self._train_step(
|
||||
img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
|
||||
)
|
||||
output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
|
||||
output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
|
||||
total_loss = total_loss + dense_result["total_loss"]
|
||||
|
||||
output_dict["total_loss"] = total_loss.item()
|
||||
return total_loss, output_dict
|
||||
|
||||
|
||||
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute cross-entropy loss for stage classification."""
|
||||
_, _, num_stages = stage_logits.shape
|
||||
stage_logits_flat = stage_logits.reshape(-1, num_stages)
|
||||
# Clamp target stage indices to valid range [0, num_stages-1]
|
||||
target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
|
||||
return F.cross_entropy(stage_logits_flat, target_stages_flat)
|
||||
@@ -1,518 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""SARM Processor for encoding images/text and generating stage+tau targets."""
|
||||
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from faker import Faker
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
from_tensor_to_numpy,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SARMConfig,
|
||||
image_key: str | None = None,
|
||||
dataset_meta=None,
|
||||
dataset_stats: dict | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.image_key = image_key or config.image_key
|
||||
self.dataset_meta = dataset_meta
|
||||
self.dataset_stats = dataset_stats
|
||||
self.annotation_mode = config.annotation_mode
|
||||
|
||||
# Helper to create temporal proportions dict
|
||||
def make_props_dict(names, props):
|
||||
return dict(zip(names, props, strict=True)) if names and props else None
|
||||
|
||||
# Sparse annotations (always needed)
|
||||
self.sparse_temporal_proportions = make_props_dict(
|
||||
config.sparse_subtask_names, config.sparse_temporal_proportions
|
||||
)
|
||||
self.sparse_subtask_names = config.sparse_subtask_names
|
||||
|
||||
# Dense annotations (only for dual mode)
|
||||
self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
|
||||
self.dense_temporal_proportions = (
|
||||
make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
|
||||
if config.uses_dual_heads
|
||||
else None
|
||||
)
|
||||
|
||||
self.device = torch.device(
|
||||
self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
|
||||
self.fake = Faker()
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
return ep_idx
|
||||
return 0
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
|
||||
"""Get episode indices for each frame index."""
|
||||
if episode_index is None:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
|
||||
|
||||
# If single episode but multiple frames, compute episode for each frame
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
return episode_indices
|
||||
|
||||
def _generate_perturbed_task(self) -> str:
|
||||
"""Generate a random perturbed task string for language perturbation."""
|
||||
num_words = random.randint(1, 5)
|
||||
verb = random.choice(self.verbs)
|
||||
phrase = " ".join([verb] + self.fake.words(nb=num_words))
|
||||
return phrase
|
||||
|
||||
def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
|
||||
"""Get global subtask names and temporal proportions for an annotation type."""
|
||||
if annotation_type == "dense":
|
||||
return self.dense_subtask_names, self.dense_temporal_proportions
|
||||
return self.sparse_subtask_names, self.sparse_temporal_proportions
|
||||
|
||||
def _load_episode_annotations(
|
||||
self,
|
||||
ep_idx: int,
|
||||
episodes_df: pd.DataFrame | None,
|
||||
annotation_type: str,
|
||||
global_names: list[str],
|
||||
) -> tuple[list | None, list | None, list | None]:
|
||||
"""Load subtask annotations for an episode from DataFrame."""
|
||||
# Single-stage mode: (linear progress 0→1)
|
||||
if episodes_df is None or len(global_names) == 1:
|
||||
return None, None, None
|
||||
|
||||
# Resolve column name with fallback
|
||||
def col(suffix):
|
||||
prefixed = f"{annotation_type}_{suffix}"
|
||||
return prefixed if prefixed in episodes_df.columns else suffix
|
||||
|
||||
col_names = col("subtask_names")
|
||||
if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
|
||||
return None, None, None
|
||||
|
||||
subtask_names = episodes_df.loc[ep_idx, col_names]
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
return None, None, None
|
||||
|
||||
return (
|
||||
subtask_names,
|
||||
episodes_df.loc[ep_idx, col("subtask_start_frames")],
|
||||
episodes_df.loc[ep_idx, col("subtask_end_frames")],
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Encode images, text, and normalize states in the transition.
|
||||
|
||||
Implements SARM training data preparation:
|
||||
- Applies language perturbation (20% probability)
|
||||
- Applies rewind augmentation (80% probability)
|
||||
- Generates stage+tau targets for all frames
|
||||
- Outputs lengths tensor for valid sequence masking
|
||||
"""
|
||||
new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
frame_index = comp_data.get("index")
|
||||
episode_index = comp_data.get("episode_index")
|
||||
|
||||
if frame_index is None:
|
||||
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
|
||||
if episode_index is None:
|
||||
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
|
||||
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
image = observation.get(self.image_key)
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
|
||||
# If 4D (T, C, H, W) from delta_timestamps, add batch dim
|
||||
# If 3D (C, H, W) single frame, add batch and time dims
|
||||
if image.ndim == 4:
|
||||
image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W)
|
||||
elif image.ndim == 3:
|
||||
image = image[np.newaxis, np.newaxis, ...] # (C, H, W) -> (1, 1, C, H, W)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
total_frames = image.shape[1] # Should be 13: 9 obs + 4 rewind placeholders
|
||||
n_obs_steps = self.config.n_obs_steps
|
||||
max_rewind_steps = self.config.max_rewind_steps
|
||||
n_obs_frames = 1 + n_obs_steps # 9 observation frames (including current)
|
||||
|
||||
# Rewind augmentation
|
||||
rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
|
||||
apply_rewind = self.training and random.random() < self.config.rewind_probability
|
||||
|
||||
if apply_rewind and self.dataset_meta is not None:
|
||||
for b_idx, (ep_idx, frame_idx) in enumerate(
|
||||
zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
|
||||
):
|
||||
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
|
||||
rewind_step, _ = apply_rewind_augmentation(
|
||||
frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
|
||||
)
|
||||
rewind_steps[b_idx] = rewind_step
|
||||
|
||||
# Compute valid lengths: n_obs_frames + rewind_steps
|
||||
lengths = n_obs_frames + rewind_steps # (B,)
|
||||
|
||||
# Apply rewind masking to images
|
||||
# For frames beyond valid length, we mask with zeros (or copy last valid frame)
|
||||
for b_idx in range(batch_size):
|
||||
valid_len = lengths[b_idx].item()
|
||||
if valid_len < total_frames:
|
||||
image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
|
||||
|
||||
# Encode images with CLIP
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation["video_features"] = video_features
|
||||
|
||||
state_key = self.config.state_key
|
||||
state_data = observation.get(state_key)
|
||||
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_tensor = state_data.float()
|
||||
else:
|
||||
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D)
|
||||
elif state_tensor.ndim == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D)
|
||||
|
||||
# Apply same rewind masking to state
|
||||
for b_idx in range(batch_size):
|
||||
valid_len = lengths[b_idx].item()
|
||||
if valid_len < state_tensor.shape[1]:
|
||||
state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
|
||||
|
||||
observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||
|
||||
task = comp_data.get("task")
|
||||
if isinstance(task, list):
|
||||
task = task[0] if task else ""
|
||||
|
||||
# Apply language perturbation during training (20% probability)
|
||||
# When perturbed, targets will be zeroed to train model to output low values for irrelevant text
|
||||
apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
|
||||
if apply_perturbation:
|
||||
task = self._generate_perturbed_task()
|
||||
|
||||
# Encode text with CLIP
|
||||
observation["text_features"] = self._encode_text_clip(task, batch_size)
|
||||
|
||||
# Store lengths for model
|
||||
observation["lengths"] = lengths
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
if apply_perturbation:
|
||||
# Zero targets when language is perturbed
|
||||
sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
else:
|
||||
sparse_targets = self._compute_batch_targets(
|
||||
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
|
||||
)
|
||||
observation["sparse_targets"] = sparse_targets
|
||||
|
||||
# Generate dense targets (for dual mode)
|
||||
if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
|
||||
if apply_perturbation:
|
||||
# Zero targets when language is perturbed
|
||||
dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
else:
|
||||
dense_targets = self._compute_batch_targets(
|
||||
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
|
||||
)
|
||||
observation["dense_targets"] = dense_targets
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def _compute_batch_targets(
|
||||
self,
|
||||
frame_indices: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
lengths: torch.Tensor,
|
||||
rewind_steps: torch.Tensor,
|
||||
episodes_df: pd.DataFrame | None,
|
||||
annotation_type: str,
|
||||
) -> torch.Tensor:
|
||||
"""Compute stage+tau targets for a batch of samples."""
|
||||
batch_size = len(frame_indices)
|
||||
n_obs_steps = self.config.n_obs_steps
|
||||
max_rewind_steps = self.config.max_rewind_steps
|
||||
total_frames = 1 + n_obs_steps + max_rewind_steps
|
||||
frame_gap = self.config.frame_gap
|
||||
|
||||
global_names, temporal_props = self._get_annotation_config(annotation_type)
|
||||
targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
|
||||
for b_idx in range(batch_size):
|
||||
ep_idx = int(episode_indices[b_idx])
|
||||
frame_idx = int(frame_indices[b_idx])
|
||||
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
|
||||
ep_idx, episodes_df, annotation_type, global_names
|
||||
)
|
||||
|
||||
# Compute observation frame indices
|
||||
obs_indices, _ = compute_absolute_indices(
|
||||
frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
|
||||
)
|
||||
obs_indices = obs_indices.tolist()
|
||||
|
||||
# Compute targets for observation frames
|
||||
for t_idx, abs_idx in enumerate(obs_indices):
|
||||
rel_frame = abs_idx - ep_start
|
||||
targets[b_idx, t_idx] = find_stage_and_tau(
|
||||
rel_frame,
|
||||
ep_length,
|
||||
subtask_names,
|
||||
subtask_start_frames,
|
||||
subtask_end_frames,
|
||||
global_names,
|
||||
temporal_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
# Compute targets for rewind frames (if any)
|
||||
rewind_step = rewind_steps[b_idx].item()
|
||||
if rewind_step > 0:
|
||||
_, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps,
|
||||
max_rewind_steps,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=rewind_step,
|
||||
)
|
||||
|
||||
for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
|
||||
rel_frame = max(0, abs_idx - ep_start)
|
||||
targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
|
||||
rel_frame,
|
||||
ep_length,
|
||||
subtask_names,
|
||||
subtask_start_frames,
|
||||
subtask_end_frames,
|
||||
global_names,
|
||||
temporal_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
return targets
|
||||
|
||||
@property
|
||||
def training(self) -> bool:
|
||||
return getattr(self, "_training_mode", True)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode for augmentation decisions."""
|
||||
self._training_mode = mode
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode (disable augmentations)."""
|
||||
return self.train(False)
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
|
||||
"""Encode a batch of images using CLIP.
|
||||
|
||||
Args:
|
||||
images: Batched images with shape: (B, T, C, H, W)
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, T, 512)
|
||||
"""
|
||||
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
|
||||
num_frames = images.shape[0]
|
||||
images_list = []
|
||||
for i in range(num_frames):
|
||||
img = images[i]
|
||||
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
|
||||
img = img.transpose(1, 2, 0)
|
||||
|
||||
# Handle single channel
|
||||
if img.shape[-1] == 1:
|
||||
img = np.repeat(img, 3, axis=-1)
|
||||
|
||||
if img.dtype != np.uint8:
|
||||
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
|
||||
|
||||
images_list.append(Image.fromarray(img))
|
||||
|
||||
all_embeddings = []
|
||||
for i in range(0, num_frames, self.config.clip_batch_size):
|
||||
batch_imgs = images_list[i : i + self.config.clip_batch_size]
|
||||
|
||||
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings
|
||||
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
embeddings = embeddings.unsqueeze(0)
|
||||
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
|
||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Task description text to encode
|
||||
batch_size: Batch size to replicate for
|
||||
|
||||
Returns:
|
||||
Encoded text features with shape (B, 512)
|
||||
"""
|
||||
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
return text_embedding
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Add encoded features to the observation features."""
|
||||
features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def make_sarm_pre_post_processors(
|
||||
config: SARMConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_meta=None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre-processor and post-processor pipelines for SARM."""
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
SARMEncodingProcessorStep(
|
||||
config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=[DeviceProcessorStep(device="cpu")],
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,295 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def find_stage_and_tau(
|
||||
current_frame: int,
|
||||
episode_length: int,
|
||||
subtask_names: list | None,
|
||||
subtask_start_frames: list | None,
|
||||
subtask_end_frames: list | None,
|
||||
global_subtask_names: list,
|
||||
temporal_proportions: dict,
|
||||
return_combined: bool = False,
|
||||
) -> tuple[int, float] | float:
|
||||
"""Find stage and within-stage progress (tau) for a frame.
|
||||
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
episode_length: Total frames in episode
|
||||
subtask_names: Subtask names for this episode (None for single_stage)
|
||||
subtask_start_frames: Subtask start frames
|
||||
subtask_end_frames: Subtask end frames
|
||||
global_subtask_names: Global list of all subtask names
|
||||
temporal_proportions: Dict of temporal proportions
|
||||
return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple
|
||||
|
||||
Returns:
|
||||
Float (stage.tau) if return_combined, else (stage_idx, tau) tuple
|
||||
"""
|
||||
stage_idx, tau = 0, 0.0
|
||||
num_stages = len(global_subtask_names)
|
||||
|
||||
# Single-stage mode: linear progress from 0 to 1
|
||||
if num_stages == 1:
|
||||
tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
|
||||
elif subtask_names is None:
|
||||
pass # stage_idx=0, tau=0.0
|
||||
elif current_frame < subtask_start_frames[0]:
|
||||
pass # Before first subtask: stage_idx=0, tau=0.0
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
stage_idx, tau = num_stages - 1, 0.999 # After last subtask
|
||||
else:
|
||||
# Find which subtask this frame belongs to
|
||||
found = False
|
||||
for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
|
||||
if start <= current_frame <= end:
|
||||
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
|
||||
tau = compute_tau(current_frame, start, end)
|
||||
found = True
|
||||
break
|
||||
# Frame between subtasks - use previous subtask's end state
|
||||
if not found:
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
|
||||
tau = 1.0
|
||||
break
|
||||
|
||||
if return_combined:
|
||||
# Clamp to avoid overflow at end
|
||||
if stage_idx >= num_stages - 1 and tau >= 1.0:
|
||||
return num_stages - 1 + 0.999
|
||||
return stage_idx + tau
|
||||
return stage_idx, tau
|
||||
|
||||
|
||||
def compute_absolute_indices(
|
||||
frame_idx: int,
|
||||
ep_start: int,
|
||||
ep_end: int,
|
||||
n_obs_steps: int,
|
||||
frame_gap: int = 30,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute absolute frame indices with clamping for bidirectional observation sequence.
|
||||
|
||||
Bidirectional sampling centered on target frame:
|
||||
- Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
|
||||
- Current: [0] (1 frame)
|
||||
- After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
|
||||
- Total: n_obs_steps + 1 frames
|
||||
|
||||
Out-of-bounds frames are clamped (duplicated from boundary).
|
||||
|
||||
Args:
|
||||
frame_idx: Target frame index (center frame of sequence)
|
||||
ep_start: Episode start index
|
||||
ep_end: Episode end index (exclusive)
|
||||
n_obs_steps: Number of observation steps (must be even for symmetric sampling)
|
||||
frame_gap: Gap between observation frames
|
||||
|
||||
Returns:
|
||||
Tuple of (indices, out_of_bounds_flags)
|
||||
"""
|
||||
half_steps = n_obs_steps // 2
|
||||
|
||||
# Bidirectional deltas: past + current + future
|
||||
past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
|
||||
delta_indices = past_deltas + [0] + future_deltas
|
||||
|
||||
frames = []
|
||||
out_of_bounds = []
|
||||
|
||||
for delta in delta_indices:
|
||||
target_idx = frame_idx + delta
|
||||
# Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
|
||||
clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
|
||||
frames.append(clamped_idx)
|
||||
# Flag as out-of-bounds if clamping occurred
|
||||
out_of_bounds.append(1 if target_idx != clamped_idx else 0)
|
||||
|
||||
return torch.tensor(frames), torch.tensor(out_of_bounds)
|
||||
|
||||
|
||||
def apply_rewind_augmentation(
|
||||
frame_idx: int,
|
||||
ep_start: int,
|
||||
n_obs_steps: int,
|
||||
max_rewind_steps: int,
|
||||
frame_gap: int = 30,
|
||||
rewind_step: int | None = None,
|
||||
) -> tuple[int, list[int]]:
|
||||
"""
|
||||
Generate rewind frame indices for temporal augmentation.
|
||||
|
||||
Rewind simulates going backwards through previously seen frames,
|
||||
starting from before the earliest observation frame (for bidirectional sampling).
|
||||
Appends reversed frames after the observation sequence.
|
||||
|
||||
Args:
|
||||
frame_idx: Target frame index (center of bidirectional observation window)
|
||||
ep_start: Episode start index
|
||||
n_obs_steps: Number of observation steps
|
||||
max_rewind_steps: Maximum rewind steps
|
||||
frame_gap: Gap between frames
|
||||
rewind_step: If provided, use this exact rewind step (for deterministic behavior).
|
||||
If None, sample randomly.
|
||||
|
||||
Returns:
|
||||
Tuple of (rewind_step, rewind_indices)
|
||||
"""
|
||||
# For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap
|
||||
half_steps = n_obs_steps // 2
|
||||
earliest_obs_frame = frame_idx - half_steps * frame_gap
|
||||
|
||||
# Required history: frames before earliest observation frame
|
||||
if earliest_obs_frame <= ep_start:
|
||||
return 0, [] # No history before observation window
|
||||
|
||||
# Max valid rewind steps based on available history before earliest obs frame
|
||||
available_history = earliest_obs_frame - ep_start
|
||||
max_valid_step = available_history // frame_gap
|
||||
max_rewind = min(max_rewind_steps, max(0, max_valid_step))
|
||||
|
||||
if max_rewind <= 0:
|
||||
return 0, []
|
||||
|
||||
# Sample rewind steps if not provided
|
||||
rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)
|
||||
|
||||
if rewind_step == 0:
|
||||
return 0, []
|
||||
|
||||
# Generate rewind indices going backwards from earliest obs frame
|
||||
# rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back
|
||||
rewind_indices = []
|
||||
for i in range(1, rewind_step + 1):
|
||||
idx = earliest_obs_frame - i * frame_gap
|
||||
idx = max(ep_start, idx) # Clamp to episode start
|
||||
rewind_indices.append(idx)
|
||||
|
||||
return rewind_step, rewind_indices
|
||||
|
||||
|
||||
def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
|
||||
"""Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
|
||||
duration = subtask_end - subtask_start
|
||||
if duration <= 0:
|
||||
return 1.0
|
||||
return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))
|
||||
|
||||
|
||||
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
|
||||
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
|
||||
current_dim = state.shape[-1]
|
||||
if current_dim >= max_state_dim:
|
||||
return state[..., :max_state_dim] # Truncate if larger
|
||||
|
||||
# Pad with zeros on the right
|
||||
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
||||
return F.pad(state, padding, mode="constant", value=0)
|
||||
|
||||
|
||||
def temporal_proportions_to_breakpoints(
|
||||
temporal_proportions: dict[str, float] | list[float] | None,
|
||||
subtask_names: list[str] | None = None,
|
||||
) -> list[float] | None:
|
||||
"""Convert temporal proportions to cumulative breakpoints for normalization."""
|
||||
if temporal_proportions is None:
|
||||
return None
|
||||
|
||||
if isinstance(temporal_proportions, dict):
|
||||
if subtask_names is not None:
|
||||
proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
|
||||
else:
|
||||
proportions = list(temporal_proportions.values())
|
||||
else:
|
||||
proportions = list(temporal_proportions)
|
||||
|
||||
total = sum(proportions)
|
||||
if total > 0 and abs(total - 1.0) > 1e-6:
|
||||
proportions = [p / total for p in proportions]
|
||||
|
||||
breakpoints = [0.0]
|
||||
cumsum = 0.0
|
||||
for prop in proportions:
|
||||
cumsum += prop
|
||||
breakpoints.append(cumsum)
|
||||
breakpoints[-1] = 1.0
|
||||
|
||||
return breakpoints
|
||||
|
||||
|
||||
def normalize_stage_tau(
|
||||
x: float | torch.Tensor,
|
||||
num_stages: int | None = None,
|
||||
breakpoints: list[float] | None = None,
|
||||
temporal_proportions: dict[str, float] | list[float] | None = None,
|
||||
subtask_names: list[str] | None = None,
|
||||
) -> float | torch.Tensor:
|
||||
"""
|
||||
Normalize stage+tau reward to [0, 1] with custom breakpoints.
|
||||
|
||||
Maps stage index + within-stage tau to normalized progress [0, 1].
|
||||
The breakpoints are designed to give appropriate weight to each stage
|
||||
based on their importance in the task (using temporal proportions).
|
||||
|
||||
Priority: breakpoints > temporal_proportions > linear fallback
|
||||
|
||||
Args:
|
||||
x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
|
||||
num_stages: Number of stages (required if breakpoints/proportions not provided)
|
||||
breakpoints: Optional custom breakpoints list of length num_stages + 1.
|
||||
temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
|
||||
subtask_names: Optional ordered list of subtask names (for dict proportions)
|
||||
|
||||
Returns:
|
||||
Normalized progress value ∈ [0, 1]
|
||||
"""
|
||||
if breakpoints is not None:
|
||||
num_stages = len(breakpoints) - 1
|
||||
elif temporal_proportions is not None:
|
||||
breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
|
||||
num_stages = len(breakpoints) - 1
|
||||
elif num_stages is not None:
|
||||
breakpoints = [i / num_stages for i in range(num_stages + 1)]
|
||||
else:
|
||||
raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")
|
||||
|
||||
if isinstance(x, torch.Tensor):
|
||||
result = torch.zeros_like(x)
|
||||
for i in range(num_stages):
|
||||
mask = (x >= i) & (x < i + 1)
|
||||
tau_in_stage = x - i
|
||||
result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
|
||||
result[x >= num_stages] = 1.0
|
||||
return result.clamp(0.0, 1.0)
|
||||
else:
|
||||
if x < 0:
|
||||
return 0.0
|
||||
if x >= num_stages:
|
||||
return 1.0
|
||||
stage = int(x)
|
||||
tau = x - stage
|
||||
return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
|
||||
@@ -231,7 +231,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -353,19 +352,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
|
||||
) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss.
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
noise: Optional noise tensor for flow matching.
|
||||
time: Optional time tensor for flow matching.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
@@ -389,16 +377,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
@@ -544,7 +527,6 @@ class VLAFlowMatching(nn.Module):
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
device=self.config.device if self.config.device is not None else "auto",
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
|
||||
@@ -65,7 +65,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -231,20 +231,11 @@ def validate_visual_features_consistency(
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Validation passes if EITHER:
|
||||
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
|
||||
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
|
||||
# Accept if either direction is a subset
|
||||
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
|
||||
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
|
||||
|
||||
if not (policy_subset_of_dataset or dataset_subset_of_policy):
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
@@ -47,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_walloss_README.md
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_wall_x import WallXConfig
|
||||
|
||||
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
|
||||
@@ -1,165 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("wall_x")
|
||||
@dataclass
|
||||
class WallXConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for Wall-X policy.
|
||||
|
||||
Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
|
||||
It supports cross-embodiment robotic control through unified action representations.
|
||||
|
||||
This config supports multi-modal learning with vision, language, and action data.
|
||||
"""
|
||||
|
||||
# ==================== Input / Output Structure ====================
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32 # action_horizon in wall-x
|
||||
n_action_steps: int = 32
|
||||
|
||||
# Action dimension - wall-x uses 20
|
||||
max_action_dim: int = 20
|
||||
max_state_dim: int = 20 # For proprioception
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# ==================== Action Prediction ====================
|
||||
# Pretrained model paths
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Tokenizer settings
|
||||
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
# Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
|
||||
# NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
|
||||
attn_implementation: str = "eager"
|
||||
|
||||
# ==================== Optimizer Presets ====================
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 1000
|
||||
scheduler_decay_steps: int = 100000
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
|
||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
# Assign use_fast_tokenizer based on prediction_mode
|
||||
if self.prediction_mode == "fast":
|
||||
self.use_fast_tokenizer = True
|
||||
elif self.prediction_mode == "diffusion":
|
||||
self.use_fast_tokenizer = False
|
||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||
else:
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Wall-X policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Wall-X Constants and Configuration Data.
|
||||
"""
|
||||
|
||||
CAMERA_NAME_MAPPING = {
|
||||
"face_view": "front view",
|
||||
"left_wrist_view": "left wrist view",
|
||||
"right_wrist_view": "right wrist view",
|
||||
"move1_view": "move view",
|
||||
"move2_view": "move view",
|
||||
"wall_view": "wall view",
|
||||
"top_view": "top view",
|
||||
}
|
||||
|
||||
RESOLUTION = 256
|
||||
|
||||
# Parameters for preprocessing
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MIN_PIXELS = 4 * 28 * 28
|
||||
IMAGE_FACTOR = 28
|
||||
PRIORITY_ORDER = None
|
||||
GENERATE_SUBTASK_RATIO = 0.0
|
||||
MODEL_TYPE = "qwen2_5"
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 768
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_wall_x_pre_post_processors(
|
||||
config: WallXConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the Wall-X policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations
|
||||
2. Adding a batch dimension
|
||||
4. Normalizing input and output features based on dataset statistics
|
||||
5. Moving all data to the specified device
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Unnormalizing the output actions to their original scale
|
||||
2. Moving data to the CPU
|
||||
|
||||
Args:
|
||||
config: The configuration object for the Wall-X policy
|
||||
dataset_stats: A dictionary of statistics for normalization
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
WallXTaskProcessor(), # Process task description
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="wall_x_task_processor")
|
||||
class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
A processor step that ensures the task description is properly formatted for Wall-X.
|
||||
|
||||
This step handles task preprocessing similar to Qwen-VL requirements.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
# Provide default task if none specified
|
||||
complementary_data["task"] = "Execute the robot action."
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: ensure proper formatting
|
||||
if not task.endswith("."):
|
||||
new_complementary_data["task"] = f"{task}."
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: format each
|
||||
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -1,248 +0,0 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_5_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
hidden_size=3584,
|
||||
hidden_act="silu",
|
||||
intermediate_size=3420,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
tokens_per_second=4,
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.tokens_per_second = tokens_per_second
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
|
||||
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 29568):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 80):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
|
||||
|
||||
>>> # Initializing a Qwen2_5_VL style configuration
|
||||
>>> configuration = Qwen2_5_VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_vl"
|
||||
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen2_5_VL`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
num_experts=4,
|
||||
experts=None,
|
||||
dof_config=None,
|
||||
noise_scheduler=None,
|
||||
dim_inputs=(1536, 1536),
|
||||
attention_moe=False,
|
||||
mlp_moe=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
self.layer_types = ["dense"] * num_hidden_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.experts = experts
|
||||
self.dof_config = dof_config
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.dim_inputs = tuple(dim_inputs)
|
||||
self.attention_moe = attention_moe
|
||||
self.mlp_moe = mlp_moe
|
||||
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
@property
|
||||
def text_config(self):
|
||||
return self
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5_VLConfig"]
|
||||
@@ -1,631 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Wall-X Utility Functions.
|
||||
|
||||
Contains data processing utilities, text formatting functions, and helper classes
|
||||
for the Wall-X cross-embodiment robotic control model.
|
||||
"""
|
||||
|
||||
import random
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature
|
||||
|
||||
from lerobot.policies.wall_x.constant import (
|
||||
CAMERA_NAME_MAPPING,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@dataclass
|
||||
class X2RDataProcessingConfig:
|
||||
"""Configuration class for X2R data processing pipeline.
|
||||
|
||||
This class contains all the necessary parameters for processing robotic data
|
||||
including camera mappings, tactile sensor configurations, action predictions,
|
||||
and various processing options.
|
||||
"""
|
||||
|
||||
# Action prediction configuration
|
||||
predict_action_keys: list[str] = field(default_factory=list)
|
||||
obs_action_keys: list[str] = field(default_factory=list)
|
||||
|
||||
# Image resolution settings for different views
|
||||
resolution: dict[str, int] = field(
|
||||
default_factory=lambda: {
|
||||
"face_view": -1,
|
||||
"left_wrist_view": 128,
|
||||
"right_wrist_view": 128,
|
||||
}
|
||||
)
|
||||
|
||||
# Dataset splitting
|
||||
train_test_split: float = 0.9
|
||||
split_seed: int = 42
|
||||
|
||||
# Instruction handling
|
||||
priority_order: dict[str, float] | None = None
|
||||
|
||||
# Vision model parameters
|
||||
model_type: str = "qwen2_5"
|
||||
max_pixels: int = 16384 * 28 * 28
|
||||
min_pixels: int = 4 * 28 * 28
|
||||
image_factor: int = 28
|
||||
|
||||
generate_subtask_ratio: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialization validation and setup."""
|
||||
# Validate train/test split
|
||||
if not 0 < self.train_test_split < 1:
|
||||
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
"""Convert configuration to dictionary format.
|
||||
|
||||
Returns:
|
||||
Dict: Configuration as dictionary
|
||||
"""
|
||||
return self.__dict__
|
||||
|
||||
def update(self, **kwargs) -> "X2RDataProcessingConfig":
|
||||
"""Update configuration parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs to update
|
||||
|
||||
Returns:
|
||||
X2RDataProcessingConfig: Updated configuration instance
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
raise ValueError(f"Unknown configuration parameter: {key}")
|
||||
return self
|
||||
|
||||
|
||||
def preprocesser_call(
|
||||
processor,
|
||||
images: list | Any | None = None,
|
||||
text: str | list[str] | None = None,
|
||||
videos: list | Any | None = None,
|
||||
padding: bool | str = False,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchFeature:
|
||||
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
||||
|
||||
Processes inputs into format suitable for multimodal transformer models, including:
|
||||
- Text tokenization and special token handling
|
||||
- Image/video processing through image processor
|
||||
- Attention mask and label generation
|
||||
- Padding and truncation handling
|
||||
|
||||
Args:
|
||||
processor: Multimodal processor containing tokenizer and image processor
|
||||
images: Input images (PIL, numpy arrays, or torch tensors)
|
||||
text: Text or list of texts to tokenize
|
||||
videos: Input videos (numpy arrays or torch tensors)
|
||||
padding: Whether to pad sequences to same length
|
||||
truncation: Whether to truncate sequences longer than max_length
|
||||
max_length: Maximum length for truncation/padding
|
||||
return_tensors: Format for returned tensors ('pt', 'np', etc.)
|
||||
|
||||
Returns:
|
||||
BatchFeature containing processed inputs with keys:
|
||||
- input_ids: Tokenized text
|
||||
- attention_mask: Attention mask for text
|
||||
- pixel_values: Processed image pixels
|
||||
- pixel_values_videos: Processed video frames
|
||||
- image_grid_thw: Image grid dimensions for LLM
|
||||
- video_grid_thw: Video grid dimensions for LLM
|
||||
- labels: Training labels with masking
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
# Ensure text input is in list format
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
# Process image placeholder tokens in text
|
||||
if image_grid_thw is not None:
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while "<|image_pad|>" in text[i]:
|
||||
# Add bounds checking to avoid index overflow
|
||||
if index >= len(image_grid_thw):
|
||||
print(
|
||||
f"Warning: Number of image placeholders ({index + 1}) "
|
||||
f"exceeds actual images ({len(image_grid_thw)}), "
|
||||
f"skipping remaining placeholder processing"
|
||||
)
|
||||
break
|
||||
# Replace image placeholder with actual token count
|
||||
token_count = image_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||
|
||||
# Process video placeholder tokens in text
|
||||
if video_grid_thw is not None:
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while "<|video_pad|>" in text[i]:
|
||||
# Replace video placeholder with actual token count
|
||||
token_count = video_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
||||
|
||||
# Tokenize complete input text
|
||||
text_inputs = processor.tokenizer(
|
||||
text,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
# Get pad token ID for label generation
|
||||
pad_token_id = processor.tokenizer.pad_token_id
|
||||
if pad_token_id is None:
|
||||
pad_token_id = processor.tokenizer.eos_token_id
|
||||
|
||||
# Generate labels for multi-turn dialogue, keeping only assistant response loss
|
||||
labels = torch.full_like(text_inputs.input_ids, -100)
|
||||
assistant_marker = "<|im_start|>assistant\n"
|
||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
|
||||
|
||||
for i in range(len(text)):
|
||||
assistant_regions = []
|
||||
parts = text[i].split(assistant_marker)
|
||||
|
||||
# Process each part to determine which tokens belong to assistant responses
|
||||
# Count left padding tokens
|
||||
num_left_pads = 0
|
||||
for token_id in text_inputs.input_ids[i]:
|
||||
if token_id == pad_token_id:
|
||||
num_left_pads += 1
|
||||
else:
|
||||
break
|
||||
current_pos = num_left_pads
|
||||
|
||||
for j, part in enumerate(parts):
|
||||
part_tokens = processor.tokenizer(part, add_special_tokens=False).input_ids
|
||||
if j == 0:
|
||||
# First part is system prompt or user question, all labels are -100
|
||||
current_pos += len(part_tokens)
|
||||
continue
|
||||
|
||||
# From second part onwards, each part starts with assistant response
|
||||
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
||||
if text_inputs.input_ids[i][k] == im_end_token_id:
|
||||
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
|
||||
break
|
||||
current_pos += len(part_tokens) + 3
|
||||
|
||||
# Set labels for assistant response regions
|
||||
for start, end in assistant_regions:
|
||||
labels[i][start:end] = text_inputs.input_ids[i][start:end]
|
||||
|
||||
# Mask special action tokens in labels
|
||||
action_token_id = processor.tokenizer.encode("<|action|>")[0]
|
||||
propri_token_id = processor.tokenizer.encode("<|propri|>")[0]
|
||||
labels[labels == action_token_id] = -100
|
||||
labels[labels == propri_token_id] = -100
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
# Set labels to None if all are invalid to skip cross entropy loss
|
||||
if (labels != -100).any().item():
|
||||
text_inputs["labels"] = labels
|
||||
else:
|
||||
text_inputs["labels"] = None
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
|
||||
def process_grounding_points(
|
||||
text: str,
|
||||
orig_height: int,
|
||||
orig_width: int,
|
||||
resized_height: int,
|
||||
resized_width: int,
|
||||
model_type: str,
|
||||
) -> str:
|
||||
"""Process grounding point coordinates in text based on image resizing.
|
||||
|
||||
Adjusts coordinate values in <point> tags to match resized image dimensions
|
||||
for different model types (qwen2, qwen2_5).
|
||||
|
||||
Args:
|
||||
text: Input text containing <point> tags with coordinates
|
||||
orig_height: Original image height
|
||||
orig_width: Original image width
|
||||
resized_height: Resized image height
|
||||
resized_width: Resized image width
|
||||
model_type: Model type for coordinate processing ('qwen2' or 'qwen2_5')
|
||||
|
||||
Returns:
|
||||
Text with adjusted coordinate values
|
||||
"""
|
||||
# Regex pattern to match <point> tags and their contents
|
||||
point_pattern = re.compile(r"<point>(.*?)</point>")
|
||||
|
||||
def process_match(match):
|
||||
"""Process a single point match and adjust coordinates."""
|
||||
coords_str = match.group(1)
|
||||
try:
|
||||
# Extract coordinates from string
|
||||
coords = list(map(int, re.findall(r"\d+", coords_str)))
|
||||
|
||||
# Calculate resize scale factors
|
||||
scale_w = resized_width / orig_width
|
||||
scale_h = resized_height / orig_height
|
||||
|
||||
if len(coords) == 2:
|
||||
x, y = coords
|
||||
if model_type == "qwen2_5":
|
||||
# Qwen2.5 uses pixel coordinates
|
||||
new_x = max(0, min(round(x * scale_w), resized_width - 1))
|
||||
new_y = max(0, min(round(y * scale_h), resized_height - 1))
|
||||
elif model_type == "qwen2":
|
||||
# Qwen2 normalizes to [0, 1000) range
|
||||
new_x = max(0, min(999.999, (x / orig_width) * 1000))
|
||||
new_y = max(0, min(999.999, (y / orig_height) * 1000))
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
coords = [new_x, new_y]
|
||||
|
||||
elif len(coords) == 4:
|
||||
x1, y1, x2, y2 = coords
|
||||
if model_type == "qwen2_5":
|
||||
new_x1 = max(0, min(round(x1 * scale_w), resized_width - 1))
|
||||
new_y1 = max(0, min(round(y1 * scale_h), resized_height - 1))
|
||||
new_x2 = max(0, min(round(x2 * scale_w), resized_width - 1))
|
||||
new_y2 = max(0, min(round(y2 * scale_h), resized_height - 1))
|
||||
elif model_type == "qwen2":
|
||||
new_x1 = max(0, min(999.999, (x1 / orig_width) * 1000))
|
||||
new_y1 = max(0, min(999.999, (y1 / orig_height) * 1000))
|
||||
new_x2 = max(0, min(999.999, (x2 / orig_width) * 1000))
|
||||
new_y2 = max(0, min(999.999, (y2 / orig_height) * 1000))
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
coords = [new_x1, new_y1, new_x2, new_y2]
|
||||
|
||||
# Return processed point tag
|
||||
return f"<point>[{', '.join(map(str, coords))}]</point>"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# Return original content if processing fails
|
||||
return match.group(0)
|
||||
|
||||
# Replace all matching point tags
|
||||
processed_text = point_pattern.sub(process_match, text)
|
||||
return processed_text
|
||||
|
||||
|
||||
def get_frame_instruction(
|
||||
instruction_info: dict[str, Any],
|
||||
frame_idx: int | None = None,
|
||||
truncate_keys: list[str] | None = None,
|
||||
) -> tuple[dict[str, Any], int | None]:
|
||||
"""Extract frame-specific instruction from instruction dictionary.
|
||||
|
||||
Args:
|
||||
instruction_info: Dictionary containing instruction components
|
||||
frame_idx: Current frame index
|
||||
truncate_keys: Keys that trigger truncation when found
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_instruction_dict, split_end_frame)
|
||||
"""
|
||||
if truncate_keys is None:
|
||||
truncate_keys = [
|
||||
"subtask_generation",
|
||||
"distribute",
|
||||
"subtask_generation_zh",
|
||||
"distribute_zh",
|
||||
]
|
||||
|
||||
instruction_for_frame = {}
|
||||
split_end = None
|
||||
|
||||
for key, value in instruction_info.items():
|
||||
if isinstance(value, dict):
|
||||
# Handle frame-range specific instructions
|
||||
for frame_range, frame_instruction in value.items():
|
||||
start_frame, end_frame = map(int, frame_range.split(" "))
|
||||
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
||||
instruction_for_frame[key] = frame_instruction
|
||||
if truncate_keys is not None and split_end is None and key in truncate_keys:
|
||||
split_end = end_frame + 1
|
||||
break
|
||||
else:
|
||||
instruction_for_frame[key] = value
|
||||
|
||||
return instruction_for_frame, split_end
|
||||
|
||||
|
||||
def get_task_instruction(
|
||||
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
|
||||
) -> str:
|
||||
"""Construct task instruction from available instruction fields using priority sampling.
|
||||
|
||||
Args:
|
||||
frame_instruction_info: Dictionary containing instruction fields
|
||||
priority_order: OrderedDict specifying sampling probability for each field
|
||||
|
||||
Returns:
|
||||
Combined instruction string with priority components
|
||||
"""
|
||||
# Default priority settings
|
||||
default_priority_order = OrderedDict(
|
||||
{
|
||||
"subtask_generation": 0.25,
|
||||
"subtask_generation_zh": 0.25,
|
||||
"distribute": 0.25,
|
||||
"distribute_zh": 0.25,
|
||||
}
|
||||
)
|
||||
|
||||
if priority_order is not None:
|
||||
priority_order = OrderedDict(priority_order)
|
||||
else:
|
||||
priority_order = default_priority_order
|
||||
|
||||
got_instruction = False
|
||||
task_instruction = ""
|
||||
|
||||
# Sample instruction components based on priority probabilities
|
||||
for key, prob in priority_order.items():
|
||||
if key in frame_instruction_info and frame_instruction_info[key] != "":
|
||||
if got_instruction:
|
||||
if random.random() >= prob:
|
||||
continue
|
||||
|
||||
task_instruction += f"\n{frame_instruction_info[key]}"
|
||||
got_instruction = True
|
||||
break
|
||||
|
||||
# Fall back to base instruction if no priority components found
|
||||
if not got_instruction:
|
||||
task_instruction = frame_instruction_info.get("instruction", "")
|
||||
|
||||
return task_instruction
|
||||
|
||||
|
||||
def get_wallx_normal_text(
|
||||
instruction_info: dict[str, Any],
|
||||
action_chunk_size: int,
|
||||
frame_idx: int,
|
||||
priority_order: OrderedDict | None = None,
|
||||
img_keys: list[str] | None = None,
|
||||
generate_subtask_ratio: float = 0.0,
|
||||
) -> tuple[str, bool]:
|
||||
"""Construct complete multimodal prompt text for Wall-X model.
|
||||
|
||||
Formats input using special tokens including:
|
||||
- System message
|
||||
- User observations (with image placeholders)
|
||||
- Task instructions
|
||||
- Proprioception prompts
|
||||
- Assistant responses (with action tokens)
|
||||
|
||||
Args:
|
||||
instruction_info: Dictionary containing instruction components
|
||||
action_chunk_size: Number of action tokens to generate
|
||||
frame_idx: Current frame index
|
||||
priority_order: Priority order for instruction sampling
|
||||
img_keys: List of image keys
|
||||
generate_subtask_ratio: Probability of generating subtask instead of actions
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_prompt_text, is_subtask_generation)
|
||||
"""
|
||||
# Special tokens for formatting
|
||||
role_start_symbol = "<|im_start|>"
|
||||
role_end_symbol = "<|im_end|>"
|
||||
vision_start_symbol = "<|vision_start|>"
|
||||
vision_end_symbol = "<|vision_end|>"
|
||||
image_pad_symbol = "<|image_pad|>"
|
||||
propri_symbol = "<|propri|>"
|
||||
action_symbol = "<|action|>"
|
||||
action_fast_symbol = "<|action_fast|>"
|
||||
|
||||
# System prologue
|
||||
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||
|
||||
# User request with observation
|
||||
user_request = f"{role_start_symbol}user\nObservation:"
|
||||
if img_keys:
|
||||
img_keys = img_key_mapping(img_keys)
|
||||
for key in img_keys:
|
||||
user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
||||
user_request += "\nInstruction:"
|
||||
|
||||
# Get frame-specific instruction
|
||||
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
|
||||
|
||||
generate_subtask = False
|
||||
priority_keys = ["subtask_generation", "distribute"]
|
||||
|
||||
# Decide whether to generate subtask or actions
|
||||
if (
|
||||
bool(set(frame_instruction_info.keys()) & set(priority_keys))
|
||||
and random.random() < generate_subtask_ratio
|
||||
):
|
||||
# Generate subtask (equivalent to VQA task)
|
||||
instruction = frame_instruction_info.get("instruction", "")
|
||||
text_prompt = "\nPredict the next action in language.\n"
|
||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||
|
||||
# Find output instruction from priority keys
|
||||
for key in priority_keys:
|
||||
if key in frame_instruction_info:
|
||||
output_instruction = frame_instruction_info[key]
|
||||
break
|
||||
|
||||
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||
generate_subtask = True
|
||||
else:
|
||||
# Generate actions
|
||||
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
|
||||
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
||||
|
||||
complete_text = prologue + user_message + assistant_output
|
||||
return complete_text, generate_subtask
|
||||
|
||||
|
||||
def img_key_mapping(img_keys: list[str]) -> list[str]:
|
||||
"""Map image keys to camera names.
|
||||
|
||||
Args:
|
||||
img_keys: List of image keys
|
||||
|
||||
Returns:
|
||||
List of camera names
|
||||
"""
|
||||
processed_img_keys = []
|
||||
for key in img_keys:
|
||||
key = key.replace(OBS_IMAGES + ".", "")
|
||||
if key in CAMERA_NAME_MAPPING:
|
||||
key = CAMERA_NAME_MAPPING[key]
|
||||
else:
|
||||
if "view" in key:
|
||||
key = key.replace("_", " ")
|
||||
else:
|
||||
key = key + " view"
|
||||
processed_img_keys.append(key)
|
||||
return processed_img_keys
|
||||
|
||||
|
||||
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
|
||||
"""Convert normalized actions to action token strings.
|
||||
|
||||
Args:
|
||||
normalized_actions: Normalized action arrays/tensors
|
||||
action_tokenizer: Tokenizer for converting actions to tokens
|
||||
|
||||
Returns:
|
||||
List of action token string lists for each sample
|
||||
"""
|
||||
if isinstance(normalized_actions, torch.Tensor):
|
||||
normalized_actions = normalized_actions.cpu().numpy()
|
||||
|
||||
all_action_tokens = []
|
||||
for i in range(len(normalized_actions)):
|
||||
if isinstance(normalized_actions[i], torch.Tensor):
|
||||
normalized_actions[i] = normalized_actions[i].cpu().numpy()
|
||||
|
||||
token_id = action_tokenizer(normalized_actions[i])
|
||||
action_tokens = [f"<|action_token_{j}|>" for j in token_id[0]]
|
||||
all_action_tokens.append(action_tokens)
|
||||
|
||||
return all_action_tokens
|
||||
|
||||
|
||||
def pad_action_token_strs(
|
||||
actions_token_lists: list[list[str]],
|
||||
pad_token: str = "<|endoftext|>", # nosec B107
|
||||
) -> list[str]:
|
||||
"""Pad action token lists to same length and join as strings.
|
||||
|
||||
Args:
|
||||
actions_token_lists: List of action token lists for each sample
|
||||
pad_token: Token used for padding
|
||||
|
||||
Returns:
|
||||
List of padded action token strings
|
||||
"""
|
||||
max_len = max(len(tokens) for tokens in actions_token_lists)
|
||||
padded_action_strs = []
|
||||
|
||||
for tokens in actions_token_lists:
|
||||
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||
padded_action_strs.append("".join(padded_tokens))
|
||||
|
||||
return padded_action_strs
|
||||
|
||||
|
||||
def replace_action_token(
|
||||
text: list[str],
|
||||
norm_action: torch.Tensor | None,
|
||||
action_tokenizer,
|
||||
dof_masks: torch.Tensor | None = None,
|
||||
) -> list[str]:
|
||||
"""Replace action placeholders in text with actual action tokens.
|
||||
|
||||
Args:
|
||||
text: List of text strings with action placeholders
|
||||
norm_action: Normalized action tensors
|
||||
action_tokenizer: Tokenizer for converting actions to tokens
|
||||
dof_masks: Masks for degrees of freedom
|
||||
|
||||
Returns:
|
||||
List of text strings with action tokens replaced
|
||||
"""
|
||||
if action_tokenizer is not None and norm_action is not None:
|
||||
# Extract actions based on chunk sizes and DOF masks
|
||||
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
|
||||
|
||||
# Convert to action tokens and pad
|
||||
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
||||
actions_fast_token_strs = pad_action_token_strs(actions_fast_tokens)
|
||||
|
||||
# Replace action placeholders with actual tokens
|
||||
actions_fast_token_idx = 0
|
||||
for i in range(len(text)):
|
||||
if "<|action_fast|>" in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
"<|action_fast|><|im_end|>\n",
|
||||
actions_fast_token_strs[actions_fast_token_idx],
|
||||
)
|
||||
actions_fast_token_idx += 1
|
||||
|
||||
# Remove remaining action placeholders
|
||||
text = [t.replace("<|action|>", "") for t in text]
|
||||
else:
|
||||
# Remove action placeholders when no tokenizer available
|
||||
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
||||
|
||||
return text
|
||||
@@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
config_class = XVLAConfig
|
||||
name = "xvla"
|
||||
|
||||
def __init__(self, config: XVLAConfig, **kwargs):
|
||||
def __init__(self, config: XVLAConfig):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
florence_config = config.get_florence_config()
|
||||
|
||||
@@ -170,9 +170,8 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -27,14 +27,13 @@ from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry, ProcessorStep
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
@@ -269,328 +268,3 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="pi0fast_tokenizer_processor")
|
||||
class PI0FASTTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to tokenize state, language, and actions for PI0FAST models.
|
||||
|
||||
This step handles the complete tokenization pipeline for PI0FAST:
|
||||
1. Discretizes state observations
|
||||
2. Formats task descriptions with state
|
||||
3. Tokenizes actions using the FAST tokenizer
|
||||
4. Combines everything into the proper format with masks
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoProcessor
|
||||
from lerobot.processor.tokenizer_processor import PI0FASTTokenizerProcessorStep
|
||||
|
||||
# Initialize tokenizers
|
||||
paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
paligemma_processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
fast_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Create processor step
|
||||
processor = PI0FASTTokenizerProcessorStep(
|
||||
paligemma_tokenizer=paligemma_tokenizer,
|
||||
fast_tokenizer=fast_tokenizer,
|
||||
paligemma_processor=paligemma_processor,
|
||||
max_action_dim=7,
|
||||
fast_skip_tokens=2,
|
||||
max_input_seq_len=180,
|
||||
task_key="task",
|
||||
state_key="observation.state"
|
||||
)
|
||||
|
||||
# Apply to a transition
|
||||
tokenized_transition = processor(transition)
|
||||
|
||||
# Access tokenized data from observation
|
||||
input_ids = tokenized_transition["observation"]["pi0fast_input_ids"]
|
||||
attention_mask = tokenized_transition["observation"]["pi0fast_attention_mask"]
|
||||
loss_mask = tokenized_transition["observation"]["pi0fast_loss_mask"]
|
||||
token_type_ids = tokenized_transition["observation"]["pi0fast_token_type_ids"]
|
||||
```
|
||||
|
||||
Attributes:
|
||||
paligemma_tokenizer: The PaliGemma tokenizer for text
|
||||
fast_tokenizer: The FAST tokenizer for actions
|
||||
paligemma_processor: The PaliGemma processor
|
||||
max_action_dim: Maximum dimension for actions (default: 7)
|
||||
fast_skip_tokens: Number of tokens to skip in FAST tokenizer mapping (default: 2)
|
||||
max_input_seq_len: Maximum input sequence length (default: 180)
|
||||
padding_side: The side to pad on ('left' or 'right', default: 'right')
|
||||
task_key: The key in complementary_data where the task string is stored (default: 'task')
|
||||
state_key: The key in observation where the state is stored (default: 'observation.state')
|
||||
"""
|
||||
|
||||
paligemma_tokenizer: Any = None
|
||||
fast_tokenizer: Any = None
|
||||
paligemma_processor: Any = None
|
||||
max_action_dim: int = 7
|
||||
fast_skip_tokens: int = 2
|
||||
max_input_seq_len: int = 180
|
||||
padding_side: str = "right"
|
||||
task_key: str = "task"
|
||||
state_key: str = OBS_STATE
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizers."""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use PI0FASTTokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.paligemma_tokenizer is None or self.fast_tokenizer is None or self.paligemma_processor is None:
|
||||
raise ValueError(
|
||||
"paligemma_tokenizer, fast_tokenizer, and paligemma_processor must all be provided. "
|
||||
"These should be initialized tokenizer/processor objects."
|
||||
)
|
||||
|
||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize actions to [-1, 1] range per batch element."""
|
||||
mins = actions.amin(dim=(1, 2), keepdim=True)
|
||||
maxs = actions.amax(dim=(1, 2), keepdim=True)
|
||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert FAST tokens to PaliGemma vocabulary space."""
|
||||
vocab_size = getattr(self.paligemma_tokenizer, "vocab_size", 257152)
|
||||
return vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def fast_tokenizer_wrapper(self, actions_norm):
|
||||
"""Wrapper for FAST tokenizer that ensures batch processing and returns PyTorch tensors."""
|
||||
batch_tokens = self.fast_tokenizer(actions_norm)
|
||||
fast_out = self.paligemma_processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
|
||||
return fast_out
|
||||
|
||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: torch.Tensor) -> torch.Tensor:
|
||||
"""Create token type IDs to distinguish prefix from action tokens."""
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
return suffix_mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Process the transition and add tokenized inputs.
|
||||
|
||||
Args:
|
||||
transition: The environment transition to process
|
||||
|
||||
Returns:
|
||||
The transition with added tokenized data
|
||||
"""
|
||||
self.transition = transition
|
||||
|
||||
# Extract components from transition
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
if observation is None:
|
||||
raise ValueError("Observation is None in transition")
|
||||
|
||||
# Get state and language
|
||||
state = observation.get(self.state_key)
|
||||
if state is None:
|
||||
raise ValueError(f"State key '{self.state_key}' not found in observation")
|
||||
|
||||
# Get task description
|
||||
if complementary_data is None:
|
||||
raise ValueError("Complementary data is None, cannot extract task")
|
||||
|
||||
task_data = complementary_data.get(self.task_key)
|
||||
if task_data is None:
|
||||
raise ValueError(f"Task key '{self.task_key}' not found in complementary data")
|
||||
|
||||
# Standardize task to list of strings
|
||||
if isinstance(task_data, str):
|
||||
lang_text = [task_data]
|
||||
elif isinstance(task_data, list) and all(isinstance(t, str) for t in task_data):
|
||||
lang_text = task_data
|
||||
else:
|
||||
raise ValueError(f"Task must be string or list of strings, got {type(task_data)}")
|
||||
|
||||
# Create tokenized inputs
|
||||
tokenized_data = self.create_input_tokens(state, lang_text, action)
|
||||
|
||||
# Add tokenized data to observation
|
||||
new_observation = dict(observation)
|
||||
new_observation["pi0fast_input_ids"] = tokenized_data["input_ids"]
|
||||
new_observation["pi0fast_attention_mask"] = tokenized_data["attention_mask"]
|
||||
new_observation["pi0fast_padded_mask"] = tokenized_data["padded_mask"]
|
||||
new_observation["pi0fast_loss_mask"] = tokenized_data["loss_mask"]
|
||||
new_observation["pi0fast_token_type_ids"] = tokenized_data["token_type_ids"]
|
||||
|
||||
# Create new transition with updated observation
|
||||
new_transition = dict(transition)
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
|
||||
return new_transition
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
|
||||
"""
|
||||
Create tokenized input from state, language, and actions.
|
||||
|
||||
This method follows the same logic as the original PI0FAST create_input_tokens method.
|
||||
|
||||
Args:
|
||||
state: State tensor [batch_size, state_dim]
|
||||
lang_text: List of task description strings
|
||||
actions: Optional action tensor [batch_size, horizon, action_dim]
|
||||
|
||||
Returns:
|
||||
Dictionary containing input_ids, attention_mask, padded_mask, loss_mask, and token_type_ids
|
||||
"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
# Discretize state
|
||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
||||
discretized = torch.bucketize(state, bins) - 1
|
||||
discretized = discretized[:, :32]
|
||||
|
||||
# Create prefix texts with task and state
|
||||
prefix_texts = []
|
||||
for txt, disc in zip(lang_text, discretized, strict=False):
|
||||
cleaned = txt.lower().strip().replace("_", " ")
|
||||
state_str = " ".join(str(val.item()) for val in disc)
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
|
||||
# Tokenize prefix
|
||||
prefix_out = self.paligemma_tokenizer(
|
||||
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
|
||||
)
|
||||
prefix_ids = prefix_out["input_ids"].to(device)
|
||||
prefix_mask = prefix_out["attention_mask"].to(device)
|
||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
||||
|
||||
# Get pad token ID
|
||||
pad_token_id = (
|
||||
self.paligemma_tokenizer.pad_token_id
|
||||
if hasattr(self.paligemma_tokenizer, "pad_token_id")
|
||||
else self.paligemma_tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
if actions is not None:
|
||||
# pad actions
|
||||
actions_pad = F.pad(
|
||||
actions, (0, max(0, self.max_action_dim - actions.shape[2])), value=0
|
||||
)[:, :, : self.max_action_dim]
|
||||
|
||||
# Tokenize actions with FAST tokenizer
|
||||
fast_out = self.fast_tokenizer_wrapper(actions_pad.cpu())
|
||||
act_ids = fast_out["input_ids"]
|
||||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
# Convert FAST tokens to PaliGemma token space
|
||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||
|
||||
# Replace padding tokens
|
||||
vocab_size = getattr(self.paligemma_tokenizer, "vocab_size", 257152)
|
||||
act_ids = torch.where(
|
||||
act_ids == vocab_size - 1 - self.fast_skip_tokens,
|
||||
pad_token_id,
|
||||
act_ids,
|
||||
)
|
||||
|
||||
# Add BOS ("Action: ") and EOS tokens
|
||||
eos_token = torch.tensor(
|
||||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
||||
).expand(bsize, -1)
|
||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
||||
|
||||
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
|
||||
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
|
||||
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
|
||||
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
act_mask = act_mask.to(device)
|
||||
else:
|
||||
# No actions provided
|
||||
act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
|
||||
# Concatenate prefix and action tokens
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
||||
|
||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
||||
|
||||
# Pad to max length
|
||||
padded_output = self.paligemma_tokenizer.pad(
|
||||
batch_inputs, padding="longest", max_length=self.max_input_seq_len, return_tensors="pt"
|
||||
)
|
||||
padded_mask = padded_output["attention_mask"]
|
||||
|
||||
# Create attention mask (excludes prefix)
|
||||
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
|
||||
|
||||
# Create token type IDs
|
||||
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
|
||||
|
||||
# Return all masks
|
||||
return {
|
||||
"input_ids": padded_output["input_ids"],
|
||||
"attention_mask": att_mask,
|
||||
"padded_mask": padded_mask,
|
||||
"loss_mask": att_mask & padded_mask, # loss is computed not on prefix, and not on padding
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Returns the serializable configuration of the processor."""
|
||||
return {
|
||||
"max_action_dim": self.max_action_dim,
|
||||
"fast_skip_tokens": self.fast_skip_tokens,
|
||||
"max_input_seq_len": self.max_input_seq_len,
|
||||
"padding_side": self.padding_side,
|
||||
"task_key": self.task_key,
|
||||
"state_key": self.state_key,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Adds feature definitions for the tokenized PI0FAST inputs.
|
||||
|
||||
Args:
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Add features for tokenized inputs
|
||||
if "pi0fast_input_ids" not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION]["pi0fast_input_ids"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,)
|
||||
)
|
||||
|
||||
if "pi0fast_attention_mask" not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION]["pi0fast_attention_mask"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,)
|
||||
)
|
||||
|
||||
if "pi0fast_padded_mask" not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION]["pi0fast_padded_mask"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,)
|
||||
)
|
||||
|
||||
if "pi0fast_loss_mask" not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION]["pi0fast_loss_mask"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,)
|
||||
)
|
||||
|
||||
if "pi0fast_token_type_ids" not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION]["pi0fast_token_type_ids"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_input_seq_len,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||
@@ -52,7 +54,10 @@ class UnitreeG1Config(RobotConfig):
|
||||
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||
|
||||
# launch mujoco simulation
|
||||
is_simulation: bool = True
|
||||
is_simulation: bool = False
|
||||
|
||||
# socket config for ZMQ bridge
|
||||
robot_ip: str = "192.168.123.164"
|
||||
robot_ip: str = "172.18.129.215"
|
||||
|
||||
# cameras (optional)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Standalone keyboard control script for Unitree G1 robot.
|
||||
|
||||
This script provides keyboard-based velocity control for the G1 robot's
|
||||
locomotion system. It can be run alongside the main robot control to
|
||||
provide manual movement commands.
|
||||
|
||||
Usage:
|
||||
python keyboard_control.py [--robot-ip IP] [--simulation]
|
||||
|
||||
Controls:
|
||||
W/S: Forward/Backward
|
||||
A/D: Strafe Left/Right
|
||||
Q/E: Rotate Left/Right
|
||||
R/F: Raise/Lower Height (GR00T policies only)
|
||||
Z: Stop (zero all velocity commands)
|
||||
ESC/Ctrl+C: Exit
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import select
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
# Terminal handling for non-blocking keyboard input
|
||||
try:
|
||||
import termios
|
||||
import tty
|
||||
HAS_TERMIOS = True
|
||||
except ImportError:
|
||||
HAS_TERMIOS = False
|
||||
print("Warning: termios not available. Keyboard controls require Linux/macOS.")
|
||||
|
||||
|
||||
class KeyboardController:
|
||||
"""Handles keyboard input and converts to locomotion commands."""
|
||||
|
||||
def __init__(self, callback=None):
|
||||
"""
|
||||
Initialize keyboard controller.
|
||||
|
||||
Args:
|
||||
callback: Optional function called when commands change.
|
||||
Signature: callback(vx, vy, yaw, height)
|
||||
"""
|
||||
self.callback = callback
|
||||
self.running = False
|
||||
|
||||
# Locomotion commands
|
||||
self.vx = 0.0 # Forward/backward velocity
|
||||
self.vy = 0.0 # Left/right velocity (strafe)
|
||||
self.yaw = 0.0 # Rotation rate
|
||||
self.height = 0.74 # Base height (for GR00T policies)
|
||||
|
||||
# Command limits
|
||||
self.vx_limit = (-0.8, 0.8)
|
||||
self.vy_limit = (-0.5, 0.5)
|
||||
self.yaw_limit = (-1.0, 1.0)
|
||||
self.height_limit = (0.50, 1.00)
|
||||
|
||||
# Increments per keypress
|
||||
self.vx_increment = 0.4
|
||||
self.vy_increment = 0.25
|
||||
self.yaw_increment = 0.5
|
||||
self.height_increment = 0.05
|
||||
|
||||
self._old_terminal_settings = None
|
||||
|
||||
def get_commands(self) -> tuple[float, float, float, float]:
|
||||
"""Get current command values as tuple (vx, vy, yaw, height)."""
|
||||
return (self.vx, self.vy, self.yaw, self.height)
|
||||
|
||||
def get_commands_array(self) -> np.ndarray:
|
||||
"""Get velocity commands as numpy array [vx, vy, yaw]."""
|
||||
return np.array([self.vx, self.vy, self.yaw], dtype=np.float32)
|
||||
|
||||
def reset_commands(self):
|
||||
"""Reset all commands to zero (stop)."""
|
||||
self.vx = 0.0
|
||||
self.vy = 0.0
|
||||
self.yaw = 0.0
|
||||
self._notify_callback()
|
||||
|
||||
def _clamp(self, value: float, limits: tuple[float, float]) -> float:
|
||||
"""Clamp value to limits."""
|
||||
return max(limits[0], min(limits[1], value))
|
||||
|
||||
def _notify_callback(self):
|
||||
"""Call callback with current commands if set."""
|
||||
if self.callback:
|
||||
self.callback(self.vx, self.vy, self.yaw, self.height)
|
||||
|
||||
def process_key(self, key: str) -> bool:
|
||||
"""
|
||||
Process a single key press and update commands.
|
||||
|
||||
Args:
|
||||
key: Single character key that was pressed.
|
||||
|
||||
Returns:
|
||||
True if key was handled, False otherwise.
|
||||
"""
|
||||
key = key.lower()
|
||||
handled = True
|
||||
|
||||
if key == 'w':
|
||||
self.vx = self._clamp(self.vx + self.vx_increment, self.vx_limit)
|
||||
elif key == 's':
|
||||
self.vx = self._clamp(self.vx - self.vx_increment, self.vx_limit)
|
||||
elif key == 'a':
|
||||
self.vy = self._clamp(self.vy + self.vy_increment, self.vy_limit)
|
||||
elif key == 'd':
|
||||
self.vy = self._clamp(self.vy - self.vy_increment, self.vy_limit)
|
||||
elif key == 'q':
|
||||
self.yaw = self._clamp(self.yaw + self.yaw_increment, self.yaw_limit)
|
||||
elif key == 'e':
|
||||
self.yaw = self._clamp(self.yaw - self.yaw_increment, self.yaw_limit)
|
||||
elif key == 'r':
|
||||
self.height = self._clamp(self.height + self.height_increment, self.height_limit)
|
||||
elif key == 'f':
|
||||
self.height = self._clamp(self.height - self.height_increment, self.height_limit)
|
||||
elif key == 'z':
|
||||
self.reset_commands()
|
||||
return True # Already notified in reset_commands
|
||||
else:
|
||||
handled = False
|
||||
|
||||
if handled:
|
||||
self._notify_callback()
|
||||
|
||||
return handled
|
||||
|
||||
def _setup_terminal(self):
|
||||
"""Set terminal to raw mode for single character input."""
|
||||
if HAS_TERMIOS:
|
||||
self._old_terminal_settings = termios.tcgetattr(sys.stdin)
|
||||
tty.setcbreak(sys.stdin.fileno())
|
||||
|
||||
def _restore_terminal(self):
|
||||
"""Restore terminal to original settings."""
|
||||
if HAS_TERMIOS and self._old_terminal_settings is not None:
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._old_terminal_settings)
|
||||
self._old_terminal_settings = None
|
||||
|
||||
def run(self):
|
||||
"""Run the keyboard listener loop (blocking)."""
|
||||
if not HAS_TERMIOS:
|
||||
print("Error: Keyboard controls require termios (Linux/macOS)")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._print_controls()
|
||||
|
||||
try:
|
||||
self._setup_terminal()
|
||||
|
||||
while self.running:
|
||||
# Check for keyboard input with timeout
|
||||
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||
key = sys.stdin.read(1)
|
||||
|
||||
# Handle escape sequences (arrow keys, etc.)
|
||||
if key == '\x1b': # ESC
|
||||
self.running = False
|
||||
break
|
||||
|
||||
if self.process_key(key):
|
||||
self._print_status()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
finally:
|
||||
self._restore_terminal()
|
||||
print("\nKeyboard controls stopped")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
self.running = False
|
||||
|
||||
def _print_controls(self):
|
||||
"""Print control instructions."""
|
||||
print("\n" + "=" * 60)
|
||||
print("KEYBOARD CONTROLS ACTIVE")
|
||||
print("=" * 60)
|
||||
print(" W/S: Forward/Backward")
|
||||
print(" A/D: Strafe Left/Right")
|
||||
print(" Q/E: Rotate Left/Right")
|
||||
print(" R/F: Raise/Lower Height (±5cm)")
|
||||
print(" Z: Stop (zero all commands)")
|
||||
print(" ESC: Exit")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
def _print_status(self):
|
||||
"""Print current command status."""
|
||||
print(f"[CMD] vx={self.vx:+.2f}, vy={self.vy:+.2f}, yaw={self.yaw:+.2f} | height={self.height:.3f}m")
|
||||
|
||||
|
||||
class RobotKeyboardController(KeyboardController):
|
||||
"""Keyboard controller that directly updates a robot's locomotion commands."""
|
||||
|
||||
def __init__(self, robot):
|
||||
"""
|
||||
Initialize with a UnitreeG1 robot instance.
|
||||
|
||||
Args:
|
||||
robot: UnitreeG1 robot instance with locomotion_cmd attribute.
|
||||
"""
|
||||
super().__init__()
|
||||
self.robot = robot
|
||||
|
||||
# Initialize from robot's current state if available
|
||||
if hasattr(robot, 'locomotion_cmd'):
|
||||
self.vx = robot.locomotion_cmd[0]
|
||||
self.vy = robot.locomotion_cmd[1]
|
||||
self.yaw = robot.locomotion_cmd[2]
|
||||
|
||||
if hasattr(robot, 'groot_height_cmd'):
|
||||
self.height = robot.groot_height_cmd
|
||||
|
||||
def _notify_callback(self):
|
||||
"""Update robot's locomotion commands directly."""
|
||||
if hasattr(self.robot, 'locomotion_cmd'):
|
||||
self.robot.locomotion_cmd[0] = self.vx
|
||||
self.robot.locomotion_cmd[1] = self.vy
|
||||
self.robot.locomotion_cmd[2] = self.yaw
|
||||
|
||||
if hasattr(self.robot, 'groot_height_cmd'):
|
||||
self.robot.groot_height_cmd = self.height
|
||||
|
||||
|
||||
def start_keyboard_control_thread(robot) -> tuple:
|
||||
"""
|
||||
Start keyboard controls for a robot in a background thread.
|
||||
|
||||
Args:
|
||||
robot: UnitreeG1 robot instance.
|
||||
|
||||
Returns:
|
||||
Tuple of (controller, thread) for later stopping.
|
||||
"""
|
||||
import threading
|
||||
|
||||
controller = RobotKeyboardController(robot)
|
||||
thread = threading.Thread(target=controller.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
return controller, thread
|
||||
|
||||
|
||||
def stop_keyboard_control_thread(controller, thread, timeout: float = 2.0):
|
||||
"""
|
||||
Stop the keyboard control thread.
|
||||
|
||||
Args:
|
||||
controller: KeyboardController instance.
|
||||
thread: Thread running the controller.
|
||||
timeout: Max time to wait for thread to stop.
|
||||
"""
|
||||
controller.stop()
|
||||
thread.join(timeout=timeout)
|
||||
|
||||
|
||||
def main():
|
||||
"""Standalone keyboard control with optional robot connection."""
|
||||
parser = argparse.ArgumentParser(description="Keyboard control for Unitree G1")
|
||||
parser.add_argument("--standalone", action="store_true",
|
||||
help="Run in standalone mode (just print commands, no robot)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.standalone:
|
||||
# Standalone mode - just demonstrate keyboard input
|
||||
def print_callback(vx, vy, yaw, height):
|
||||
print(f" → Would send: vx={vx:+.2f}, vy={vy:+.2f}, yaw={yaw:+.2f}, height={height:.3f}")
|
||||
|
||||
controller = KeyboardController(callback=print_callback)
|
||||
print("Running in STANDALONE mode (no robot connection)")
|
||||
controller.run()
|
||||
else:
|
||||
print("To use with a robot, import and use RobotKeyboardController:")
|
||||
print("")
|
||||
print(" from lerobot.robots.unitree_g1.keyboard_control import (")
|
||||
print(" RobotKeyboardController,")
|
||||
print(" start_keyboard_control_thread,")
|
||||
print(" stop_keyboard_control_thread")
|
||||
print(" )")
|
||||
print("")
|
||||
print(" # Start keyboard controls")
|
||||
print(" controller, thread = start_keyboard_control_thread(robot)")
|
||||
print("")
|
||||
print(" # ... robot runs ...")
|
||||
print("")
|
||||
print(" # Stop keyboard controls")
|
||||
print(" stop_keyboard_control_thread(controller, thread)")
|
||||
print("")
|
||||
print("Or run with --standalone to test keyboard input without a robot.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -27,25 +27,6 @@ lerobot-info
|
||||
|
||||
import importlib
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from importlib.metadata import PackageNotFoundError, distribution
|
||||
|
||||
PACKAGE_NAME = "lerobot"
|
||||
|
||||
|
||||
def get_ffmpeg_version() -> str:
|
||||
"""Get the ffmpeg version if installed, otherwise return 'N/A'."""
|
||||
command_path = shutil.which("ffmpeg")
|
||||
if command_path is None:
|
||||
return "N/A"
|
||||
try:
|
||||
result = subprocess.run([command_path, "-version"], capture_output=True, text=True, check=True)
|
||||
first_line = result.stdout.splitlines()[0]
|
||||
version_info = first_line.split(" ")[2]
|
||||
return version_info
|
||||
except (subprocess.SubprocessError, IndexError):
|
||||
return "Installed (version parsing failed)"
|
||||
|
||||
|
||||
def get_package_version(package_name: str) -> str:
|
||||
@@ -57,17 +38,16 @@ def get_package_version(package_name: str) -> str:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def get_sys_info() -> dict[str, str]:
|
||||
def get_sys_info() -> dict:
|
||||
"""Run this to get basic system info to help for tracking issues & bugs."""
|
||||
# General package versions
|
||||
info = {
|
||||
"LeRobot version": get_package_version(PACKAGE_NAME),
|
||||
"lerobot version": get_package_version("lerobot"),
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"Huggingface Hub version": get_package_version("huggingface_hub"),
|
||||
"Datasets version": get_package_version("datasets"),
|
||||
"Numpy version": get_package_version("numpy"),
|
||||
"FFmpeg version": get_ffmpeg_version(),
|
||||
}
|
||||
|
||||
# PyTorch and GPU specific information
|
||||
@@ -78,10 +58,10 @@ def get_sys_info() -> dict[str, str]:
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__)
|
||||
torch_version = torch.__version__
|
||||
torch_cuda_available = torch.cuda.is_available()
|
||||
if torch_cuda_available:
|
||||
cuda_version = str(torch.version.cuda)
|
||||
cuda_version = torch.version.cuda
|
||||
# Gets the name of the first available GPU
|
||||
gpu_model = torch.cuda.get_device_name(0)
|
||||
except ImportError:
|
||||
@@ -91,34 +71,24 @@ def get_sys_info() -> dict[str, str]:
|
||||
info.update(
|
||||
{
|
||||
"PyTorch version": torch_version,
|
||||
"Is PyTorch built with CUDA support?": str(torch_cuda_available),
|
||||
"Is PyTorch built with CUDA support?": torch_cuda_available,
|
||||
"Cuda version": cuda_version,
|
||||
"GPU model": gpu_model,
|
||||
"Using GPU in script?": "<fill in>",
|
||||
}
|
||||
)
|
||||
scripts = "N/A"
|
||||
try:
|
||||
dist = distribution(PACKAGE_NAME)
|
||||
scripts = [ep.name for ep in dist.entry_points if ep.group == "console_scripts"]
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
info.update({f"{PACKAGE_NAME} scripts": str(scripts)})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def format_dict_for_markdown(d: dict[str, str]) -> str:
|
||||
def format_dict_for_markdown(d: dict) -> str:
|
||||
"""Formats a dictionary into a markdown-friendly bulleted list."""
|
||||
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()])
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to print system info in markdown format.
|
||||
"""
|
||||
system_info = get_sys_info()
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
|
||||
print(format_dict_for_markdown(system_info))
|
||||
|
||||
|
||||
|
||||
@@ -101,6 +101,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.unitree_g1 import config_unitree_g1 # noqa: F401
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
@@ -197,9 +198,8 @@ class RecordConfig:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
if self.teleop is None and self.policy is None:
|
||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||
# Note: teleop and policy can both be None for robots with built-in control (e.g. unitree_g1)
|
||||
# This is validated in record() after the robot is instantiated
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
@@ -340,6 +340,13 @@ def record_loop(
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
elif policy is None and teleop is None and dataset is not None:
|
||||
# Observation-only recording (robot controls itself, e.g. unitree_g1)
|
||||
# Record observations, extract action-relevant values (positions) from obs
|
||||
# Filter obs_processed to only include keys that match action_features
|
||||
action_keys = set(robot.action_features.keys())
|
||||
action_values = {k: v for k, v in obs_processed.items() if k in action_keys}
|
||||
robot_action_to_send = None
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
@@ -352,15 +359,17 @@ def record_loop(
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
elif teleop is not None:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
# else: observation-only mode, action_values already set above
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
# Send action to robot (skip if observation-only mode)
|
||||
if robot_action_to_send is not None:
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
@@ -404,63 +413,82 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
teleop.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
teleop.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -469,61 +497,34 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
robot.disconnect()
|
||||
if teleop is not None:
|
||||
teleop.disconnect()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop and teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
if not is_headless() and listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
log_say("Exiting", cfg.play_sounds)
|
||||
log_say("Exiting", cfg.play_sounds)
|
||||
return dataset
|
||||
|
||||
|
||||
|
||||
@@ -62,7 +62,6 @@ def update_policy(
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
@@ -79,7 +78,6 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -89,30 +87,9 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
@@ -164,6 +141,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
cfg.validate()
|
||||
|
||||
# Create Accelerator if not provided
|
||||
# It will automatically detect if running in distributed mode or single-process mode
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
@@ -180,8 +159,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
cfg.validate()
|
||||
|
||||
# Only log on main process
|
||||
if is_main_process:
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
@@ -240,10 +217,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
@@ -275,29 +248,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
@@ -377,9 +327,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
@@ -395,7 +343,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -412,16 +359,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import pkgutil
|
||||
from typing import Any
|
||||
|
||||
from draccus.choice_types import ChoiceRegistry
|
||||
@@ -132,30 +132,24 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any:
|
||||
|
||||
def register_third_party_plugins() -> None:
|
||||
"""
|
||||
Discover and import third-party LeRobot plugins so they can register themselves.
|
||||
Discover and import third-party lerobot_* plugins so they can register themselves.
|
||||
|
||||
This function uses `importlib.metadata` to find packages installed in the environment
|
||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
||||
Scans top-level modules on sys.path for packages starting with
|
||||
'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
def attempt_import(module_name: str):
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
imported.append(module_name)
|
||||
logging.info("Imported third-party plugin: %s", module_name)
|
||||
except Exception:
|
||||
logging.exception("Could not import third-party plugin: %s", module_name)
|
||||
failed.append(module_name)
|
||||
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata.get("Name")
|
||||
if not dist_name:
|
||||
continue
|
||||
if dist_name.startswith(prefixes):
|
||||
attempt_import(dist_name)
|
||||
for module_info in pkgutil.iter_modules():
|
||||
name = module_info.name
|
||||
if name.startswith(prefixes):
|
||||
try:
|
||||
importlib.import_module(name)
|
||||
imported.append(name)
|
||||
logging.info("Imported third-party plugin: %s", name)
|
||||
except Exception:
|
||||
logging.exception("Could not import third-party plugin: %s", name)
|
||||
failed.append(name)
|
||||
|
||||
logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed)
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
During training, computes:
|
||||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||
|
||||
Args:
|
||||
progress_path: Path to parquet file with precomputed progress values
|
||||
chunk_size: Number of frames ahead for computing progress delta
|
||||
head_mode: Which SARM head to use ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||||
device: Device to return tensors on
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
progress_path: str | Path,
|
||||
chunk_size: int = 50,
|
||||
head_mode: str = "sparse",
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
fallback_weight: float = 1.0,
|
||||
device: torch.device = None,
|
||||
):
|
||||
self.progress_path = Path(progress_path)
|
||||
self.chunk_size = chunk_size
|
||||
self.head_mode = head_mode
|
||||
self.kappa = kappa
|
||||
self.epsilon = epsilon
|
||||
self.fallback_weight = fallback_weight
|
||||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Determine progress column name
|
||||
self.progress_column = f"progress_{head_mode}"
|
||||
|
||||
# Load progress values
|
||||
logging.info(f"Loading SARM progress values from {self.progress_path}")
|
||||
self.df = pd.read_parquet(self.progress_path)
|
||||
|
||||
# Check if the requested head mode column exists
|
||||
if self.progress_column not in self.df.columns:
|
||||
available = [c for c in self.df.columns if c.startswith("progress")]
|
||||
raise ValueError(
|
||||
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
|
||||
)
|
||||
|
||||
logging.info(f"Using progress column: {self.progress_column}")
|
||||
|
||||
self.progress_lookup = {}
|
||||
self.episode_lookup = {}
|
||||
|
||||
for _, row in self.df.iterrows():
|
||||
global_idx = int(row["index"])
|
||||
progress = row[self.progress_column]
|
||||
episode_idx = int(row["episode_index"])
|
||||
|
||||
if not np.isnan(progress):
|
||||
self.progress_lookup[global_idx] = float(progress)
|
||||
self.episode_lookup[global_idx] = episode_idx
|
||||
|
||||
# Build episode boundaries for delta computation
|
||||
self.episode_boundaries = {}
|
||||
for episode_idx in self.df["episode_index"].unique():
|
||||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||
self.episode_boundaries[int(episode_idx)] = {
|
||||
"start": int(ep_df["index"].min()),
|
||||
"end": int(ep_df["index"].max()) + 1,
|
||||
}
|
||||
|
||||
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
|
||||
logging.info(f"Chunk size for delta computation: {chunk_size}")
|
||||
|
||||
# Compute global statistics for weight computation
|
||||
self._compute_global_stats()
|
||||
|
||||
def _compute_global_stats(self):
|
||||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||
all_deltas = []
|
||||
|
||||
for global_idx, progress in self.progress_lookup.items():
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
continue
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
continue
|
||||
|
||||
future_idx = global_idx + self.chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is not None:
|
||||
delta = future_progress - progress
|
||||
all_deltas.append(delta)
|
||||
|
||||
if all_deltas:
|
||||
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||||
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||
else:
|
||||
self.delta_mean = 0.0
|
||||
self.delta_std = self.epsilon
|
||||
logging.warning("No valid progress deltas found, using default stats")
|
||||
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute RA-BC weights for a batch.
|
||||
|
||||
For each sample:
|
||||
1. Get progress at current frame
|
||||
2. Get progress at frame + chunk_size (within same episode)
|
||||
3. Compute delta = future_progress - current_progress
|
||||
4. Compute weight using paper Eq. 8-9
|
||||
|
||||
Args:
|
||||
batch: Training batch containing "index" key with global frame indices
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||||
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||||
"""
|
||||
indices = batch.get("index")
|
||||
if indices is None:
|
||||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||
batch_size = self._get_batch_size(batch)
|
||||
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||||
|
||||
# Convert to list of ints
|
||||
if isinstance(indices, torch.Tensor):
|
||||
indices = indices.cpu().numpy().tolist()
|
||||
elif isinstance(indices, np.ndarray):
|
||||
indices = indices.tolist()
|
||||
|
||||
# Compute deltas and weights for each sample
|
||||
deltas = []
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
delta = self._compute_delta(idx)
|
||||
deltas.append(delta)
|
||||
|
||||
deltas = np.array(deltas, dtype=np.float32)
|
||||
|
||||
# Compute weights from deltas
|
||||
weights = self._compute_weights(deltas)
|
||||
|
||||
# Compute stats before normalization for logging
|
||||
raw_mean_weight = float(np.nanmean(weights))
|
||||
num_zero_weight = int(np.sum(weights == 0))
|
||||
num_full_weight = int(np.sum(weights == 1.0))
|
||||
batch_stats = {
|
||||
"raw_mean_weight": raw_mean_weight,
|
||||
"num_zero_weight": num_zero_weight,
|
||||
"num_full_weight": num_full_weight,
|
||||
}
|
||||
|
||||
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Normalize to sum to batch_size
|
||||
batch_size = len(weights)
|
||||
weight_sum = weights.sum() + self.epsilon
|
||||
weights = weights * batch_size / weight_sum
|
||||
|
||||
return weights, batch_stats
|
||||
|
||||
def _compute_delta(self, global_idx: int) -> float:
|
||||
"""Compute progress delta for a single frame."""
|
||||
current_progress = self.progress_lookup.get(global_idx)
|
||||
if current_progress is None:
|
||||
return np.nan
|
||||
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
return np.nan
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
return np.nan
|
||||
|
||||
future_idx = global_idx + self.chunk_size # Δ = chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress instead
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is None:
|
||||
return np.nan
|
||||
|
||||
return future_progress - current_progress
|
||||
|
||||
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute RA-BC weights from progress deltas.
|
||||
|
||||
Following paper Eq. 8-9:
|
||||
- Soft weight: ˜wi = clip((ri − (µ − 2σ)) / (4σ + ε), 0, 1)
|
||||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||||
|
||||
Returns:
|
||||
Array of weights
|
||||
"""
|
||||
valid_mask = ~np.isnan(deltas)
|
||||
|
||||
# Compute soft weights using global statistics
|
||||
lower_bound = self.delta_mean - 2 * self.delta_std
|
||||
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
|
||||
soft_weights = np.clip(soft_weights, 0.0, 1.0)
|
||||
|
||||
# Apply paper's Eq. 9
|
||||
weights = np.zeros_like(deltas, dtype=np.float32)
|
||||
|
||||
# High quality: ri > kappa → weight = 1
|
||||
high_quality_mask = deltas > self.kappa
|
||||
weights[high_quality_mask] = 1.0
|
||||
|
||||
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
|
||||
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
|
||||
weights[moderate_mask] = soft_weights[moderate_mask]
|
||||
|
||||
# Negative progress: ri < 0 → weight = 0 (already 0)
|
||||
# Invalid (NaN): use fallback weight
|
||||
weights[~valid_mask] = self.fallback_weight
|
||||
|
||||
return weights
|
||||
|
||||
def _get_batch_size(self, batch: dict) -> int:
|
||||
"""Determine batch size from batch."""
|
||||
for key in ["action", "index"]:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
return val.shape[0]
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics."""
|
||||
return {
|
||||
"num_frames": len(self.progress_lookup),
|
||||
"chunk_size": self.chunk_size,
|
||||
"head_mode": self.head_mode,
|
||||
"delta_mean": self.delta_mean,
|
||||
"delta_std": self.delta_std,
|
||||
"kappa": self.kappa,
|
||||
}
|
||||
@@ -1,694 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("faker")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.core import TransitionKey
|
||||
|
||||
|
||||
class MockDatasetMeta:
|
||||
"""Mock dataset metadata for testing processor."""
|
||||
|
||||
def __init__(self, episodes: list[dict]):
|
||||
self._episodes = episodes
|
||||
|
||||
@property
|
||||
def episodes(self):
|
||||
"""Return episodes as a mock object with to_pandas() method."""
|
||||
mock = MagicMock()
|
||||
mock.__len__ = lambda s: len(self._episodes)
|
||||
mock.__getitem__ = lambda s, idx: self._episodes[idx]
|
||||
mock.to_pandas = lambda: pd.DataFrame(self._episodes)
|
||||
return mock
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock SARMConfig for testing processor methods."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_obs_steps: int = 8,
|
||||
max_rewind_steps: int = 4,
|
||||
frame_gap: int = 30,
|
||||
sparse_subtask_names: list = None,
|
||||
sparse_temporal_proportions: list = None,
|
||||
dense_subtask_names: list = None,
|
||||
dense_temporal_proportions: list = None,
|
||||
image_key: str = "observation.images.top",
|
||||
state_key: str = "observation.state",
|
||||
max_state_dim: int = 32,
|
||||
device: str = None,
|
||||
rewind_probability: float = 0.8,
|
||||
language_perturbation_probability: float = 0.2,
|
||||
annotation_mode: str = "dual",
|
||||
clip_batch_size: int = 64,
|
||||
text_dim: int = 512,
|
||||
):
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.max_rewind_steps = max_rewind_steps
|
||||
self.frame_gap = frame_gap
|
||||
self.sparse_subtask_names = sparse_subtask_names or ["task"]
|
||||
self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
|
||||
self.dense_subtask_names = dense_subtask_names
|
||||
self.dense_temporal_proportions = dense_temporal_proportions
|
||||
self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
|
||||
self.image_key = image_key
|
||||
self.state_key = state_key
|
||||
self.max_state_dim = max_state_dim
|
||||
self.device = device
|
||||
self.rewind_probability = rewind_probability
|
||||
self.language_perturbation_probability = language_perturbation_probability
|
||||
self.annotation_mode = annotation_mode
|
||||
self.clip_batch_size = clip_batch_size
|
||||
self.text_dim = text_dim
|
||||
|
||||
# Compute observation delta indices (same as config: bidirectional)
|
||||
half_steps = self.n_obs_steps // 2
|
||||
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
|
||||
obs_deltas = past_deltas + [0] + future_deltas
|
||||
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
|
||||
self.observation_delta_indices = obs_deltas + rewind_deltas
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return 1 + self.n_obs_steps + self.max_rewind_steps
|
||||
|
||||
|
||||
class TestSARMEncodingProcessorStepEndToEnd:
|
||||
"""End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clip_model(self):
|
||||
"""Mock CLIP model to avoid loading real weights."""
|
||||
with (
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
|
||||
):
|
||||
# Mock the CLIP model - return embeddings based on input batch size
|
||||
mock_model = MagicMock()
|
||||
|
||||
def get_image_features_side_effect(**kwargs):
|
||||
pixel_values = kwargs.get("pixel_values")
|
||||
batch_size = pixel_values.shape[0] if pixel_values is not None else 1
|
||||
return torch.randn(batch_size, 512)
|
||||
|
||||
mock_model.get_image_features.side_effect = get_image_features_side_effect
|
||||
mock_model.get_text_features.return_value = torch.randn(1, 512)
|
||||
mock_model.to.return_value = mock_model
|
||||
mock_model_cls.from_pretrained.return_value = mock_model
|
||||
|
||||
# Mock the CLIP processor - return tensors based on input images
|
||||
mock_processor = MagicMock()
|
||||
|
||||
def processor_side_effect(images=None, **kwargs):
|
||||
num_images = len(images) if images is not None else 1
|
||||
return {
|
||||
"pixel_values": torch.randn(num_images, 3, 224, 224),
|
||||
}
|
||||
|
||||
mock_processor.side_effect = processor_side_effect
|
||||
# Mock tokenizer for text encoding
|
||||
mock_processor.tokenizer.return_value = {
|
||||
"input_ids": torch.ones(1, 77, dtype=torch.long),
|
||||
"attention_mask": torch.ones(1, 77, dtype=torch.long),
|
||||
}
|
||||
mock_processor_cls.from_pretrained.return_value = mock_processor
|
||||
|
||||
yield mock_model, mock_processor
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_clip_model):
|
||||
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Dual mode config with both sparse and dense annotations
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0, # Disable for deterministic test
|
||||
language_perturbation_probability=0.0, # Disable for deterministic test
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp", "lift"],
|
||||
sparse_temporal_proportions=[0.3, 0.4, 0.3],
|
||||
dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Create mock dataset metadata with one episode of 300 frames
|
||||
# Include annotation columns for dual mode
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "pick up the cube",
|
||||
"sparse_subtask_names": ["reach", "grasp", "lift"],
|
||||
"sparse_subtask_start_frames": [0, 90, 210],
|
||||
"sparse_subtask_end_frames": [90, 210, 300],
|
||||
"dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
)
|
||||
processor.train(True) # Use train() method, not direct assignment
|
||||
|
||||
return processor, config
|
||||
|
||||
def test_call_with_single_frame_batch(self, processor_with_mocks):
|
||||
"""Test processor __call__ with a single-frame batch."""
|
||||
processor, config = processor_with_mocks
|
||||
|
||||
# Create dummy input transition
|
||||
batch_size = 1
|
||||
num_frames = config.num_frames # 13 frames (9 obs + 4 rewind)
|
||||
|
||||
# Image: (T, C, H, W) format as expected by processor
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
|
||||
# State: (T, D) format
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150, # Middle of episode
|
||||
"episode_index": 0,
|
||||
"task": "pick up the cube",
|
||||
},
|
||||
}
|
||||
|
||||
# Run processor
|
||||
result = processor(transition)
|
||||
|
||||
# Verify output structure
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check video features exist and have correct shape
|
||||
assert "video_features" in obs
|
||||
video_features = obs["video_features"]
|
||||
assert video_features.shape[0] == batch_size
|
||||
assert video_features.shape[1] == num_frames
|
||||
assert video_features.shape[2] == 512 # CLIP embedding dim
|
||||
|
||||
# Check state features exist and have correct shape
|
||||
assert "state_features" in obs
|
||||
state_features = obs["state_features"]
|
||||
assert state_features.shape[0] == batch_size
|
||||
assert state_features.shape[1] == num_frames
|
||||
assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim
|
||||
|
||||
# Check text features exist and have correct shape
|
||||
assert "text_features" in obs
|
||||
text_features = obs["text_features"]
|
||||
assert text_features.shape[0] == batch_size
|
||||
assert text_features.shape[1] == 512 # CLIP embedding dim
|
||||
|
||||
# Check lengths tensor
|
||||
assert "lengths" in obs
|
||||
lengths = obs["lengths"]
|
||||
assert lengths.shape[0] == batch_size
|
||||
assert lengths.dtype == torch.int32
|
||||
|
||||
# Check sparse_targets exist
|
||||
assert "sparse_targets" in obs
|
||||
sparse_targets = obs["sparse_targets"]
|
||||
assert sparse_targets.shape == (batch_size, num_frames)
|
||||
# All targets should be in [0, max_stages] range (stage.tau format)
|
||||
assert (sparse_targets >= 0).all()
|
||||
|
||||
# Check dense_targets exist (for dual mode)
|
||||
assert "dense_targets" in obs
|
||||
dense_targets = obs["dense_targets"]
|
||||
assert dense_targets.shape == (batch_size, num_frames)
|
||||
assert (dense_targets >= 0).all()
|
||||
|
||||
def test_call_with_batched_input(self, mock_clip_model):
|
||||
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["step1", "step2", "step3"],
|
||||
dense_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
)
|
||||
|
||||
# Two episodes with different lengths, each with sparse+dense annotations
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "task A",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [0, 66, 133],
|
||||
"dense_subtask_end_frames": [66, 133, 200],
|
||||
},
|
||||
{
|
||||
"dataset_from_index": 200,
|
||||
"dataset_to_index": 500,
|
||||
"task": "task B",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [200, 350],
|
||||
"sparse_subtask_end_frames": [350, 500],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [200, 300, 400],
|
||||
"dense_subtask_end_frames": [300, 400, 500],
|
||||
},
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
batch_size = 2
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Image: (B, T, C, H, W) format
|
||||
dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": np.array([100, 350]), # One frame from each episode
|
||||
"episode_index": np.array([0, 1]),
|
||||
"task": ["task A", "task B"],
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Verify batch dimension is preserved for all outputs
|
||||
assert obs["video_features"].shape[0] == batch_size
|
||||
assert obs["state_features"].shape[0] == batch_size
|
||||
assert obs["lengths"].shape[0] == batch_size
|
||||
assert obs["sparse_targets"].shape[0] == batch_size
|
||||
assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets
|
||||
|
||||
def test_targets_increase_with_progress(self, mock_clip_model):
|
||||
"""Test that both sparse and dense targets increase as frame index progresses."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["phase1", "phase2"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["a", "b", "c", "d"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "test task",
|
||||
"sparse_subtask_names": ["phase1", "phase2"],
|
||||
"sparse_subtask_start_frames": [0, 150],
|
||||
"sparse_subtask_end_frames": [150, 300],
|
||||
"dense_subtask_names": ["a", "b", "c", "d"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at early, middle, and late points in episode
|
||||
frame_indices = [30, 150, 270]
|
||||
sparse_center_targets = []
|
||||
dense_center_targets = []
|
||||
|
||||
for frame_idx in frame_indices:
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": frame_idx,
|
||||
"episode_index": 0,
|
||||
"task": "test task",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
# Get target at center frame (index 4 in 9-frame observation window)
|
||||
sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
|
||||
dense_center_targets.append(obs["dense_targets"][0, 4].item())
|
||||
|
||||
# Both sparse and dense targets should increase with frame index
|
||||
assert sparse_center_targets[0] < sparse_center_targets[2], (
|
||||
f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
|
||||
)
|
||||
assert dense_center_targets[0] < dense_center_targets[2], (
|
||||
f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
|
||||
)
|
||||
|
||||
def test_progress_labels_exact_values(self, mock_clip_model):
|
||||
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10, # Smaller gap for easier calculation
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2", "d3", "d4"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Episode: frames 0-99, sparse stages at [0-49], [50-99]
|
||||
# Dense stages at [0-24], [25-49], [50-74], [75-99]
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 50],
|
||||
"sparse_subtask_end_frames": [50, 100],
|
||||
"dense_subtask_names": ["d1", "d2", "d3", "d4"],
|
||||
"dense_subtask_start_frames": [0, 25, 50, 75],
|
||||
"dense_subtask_end_frames": [25, 50, 75, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (center of episode)
|
||||
# With frame_gap=10, n_obs_steps=8:
|
||||
# obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0] # (13,)
|
||||
dense_targets = obs["dense_targets"][0] # (13,)
|
||||
|
||||
# First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
|
||||
# Check that obs frames have non-zero targets
|
||||
obs_sparse = sparse_targets[:9]
|
||||
obs_dense = dense_targets[:9]
|
||||
|
||||
# Verify targets are monotonically increasing for observation frames
|
||||
for i in range(1, 9):
|
||||
assert obs_sparse[i] >= obs_sparse[i - 1], (
|
||||
f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
|
||||
)
|
||||
assert obs_dense[i] >= obs_dense[i - 1], (
|
||||
f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
|
||||
)
|
||||
|
||||
# Rewind slots should be zero when rewind is disabled
|
||||
rewind_targets = sparse_targets[9:]
|
||||
assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"
|
||||
|
||||
# Check stage transitions: frame 50 is at boundary of sparse stage A->B
|
||||
# Center frame (index 4) corresponds to actual frame 50
|
||||
center_sparse = obs_sparse[4].item()
|
||||
# At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
|
||||
assert 0.9 <= center_sparse <= 1.1, (
|
||||
f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
|
||||
)
|
||||
|
||||
def test_rewind_augmentation_applied(self, mock_clip_model):
|
||||
"""Test that rewind augmentation correctly extends sequence and generates targets."""
|
||||
import random
|
||||
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=1.0, # Always apply rewind
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 100],
|
||||
"dense_subtask_end_frames": [100, 200],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames # 13
|
||||
|
||||
# Test at frame 150 (center of bidirectional window)
|
||||
# With n_obs_steps=8, half_steps=4, frame_gap=10:
|
||||
# - Earliest obs frame = 150 - 4*10 = 110
|
||||
# - Rewind can go back from 110 to frames like 100, 90, 80, 70
|
||||
# - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
# Seed random for reproducibility
|
||||
random.seed(42)
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
lengths = obs["lengths"][0].item()
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
|
||||
# With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
|
||||
assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
|
||||
assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"
|
||||
|
||||
# Rewind targets should be non-zero for frames within valid length
|
||||
n_obs_frames = 9
|
||||
rewind_count = lengths - n_obs_frames
|
||||
|
||||
if rewind_count > 0:
|
||||
# Check that rewind frames have targets
|
||||
rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
|
||||
# Rewind frames are from BEFORE the earliest obs frame (110)
|
||||
# These frames (100, 90, 80, 70) are earlier in the episode
|
||||
earliest_obs_target = sparse_targets[0].item() # Frame 110
|
||||
|
||||
# Rewind targets should be less than earliest obs (they're from earlier frames)
|
||||
for i, rt in enumerate(rewind_targets):
|
||||
assert rt.item() < earliest_obs_target, (
|
||||
f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
|
||||
)
|
||||
|
||||
# Rewind targets should be decreasing (going further back in time)
|
||||
for i in range(1, len(rewind_targets)):
|
||||
assert rewind_targets[i] <= rewind_targets[i - 1], (
|
||||
f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
|
||||
)
|
||||
|
||||
def test_full_sequence_target_consistency(self, mock_clip_model):
|
||||
"""Test that the full sequence of targets is consistent with frame positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["s1", "s2", "s3"],
|
||||
sparse_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
# 3 sparse stages: [0-33), [33-66), [66-99]
|
||||
# 2 dense stages: [0-50), [50-100)
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["s1", "s2", "s3"],
|
||||
"sparse_subtask_start_frames": [0, 33, 66],
|
||||
"sparse_subtask_end_frames": [33, 66, 100],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 50],
|
||||
"dense_subtask_end_frames": [50, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (middle of episode)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
dense_targets = obs["dense_targets"][0]
|
||||
|
||||
# Manually compute expected targets for observation frames
|
||||
# With frame_gap=10, n_obs_steps=8, center at 50:
|
||||
# obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
|
||||
sparse_names = ["s1", "s2", "s3"]
|
||||
sparse_starts = [0, 33, 66]
|
||||
sparse_ends = [33, 66, 100]
|
||||
sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}
|
||||
|
||||
dense_names = ["d1", "d2"]
|
||||
dense_starts = [0, 50]
|
||||
dense_ends = [50, 100]
|
||||
dense_props = {"d1": 0.5, "d2": 0.5}
|
||||
|
||||
for i, frame in enumerate(expected_obs_frames):
|
||||
expected_sparse = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
sparse_names,
|
||||
sparse_starts,
|
||||
sparse_ends,
|
||||
sparse_names,
|
||||
sparse_props,
|
||||
return_combined=True,
|
||||
)
|
||||
expected_dense = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
dense_names,
|
||||
dense_starts,
|
||||
dense_ends,
|
||||
dense_names,
|
||||
dense_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
actual_sparse = sparse_targets[i].item()
|
||||
actual_dense = dense_targets[i].item()
|
||||
|
||||
assert abs(actual_sparse - expected_sparse) < 0.01, (
|
||||
f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
|
||||
)
|
||||
assert abs(actual_dense - expected_dense) < 0.01, (
|
||||
f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
|
||||
)
|
||||