Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9c74cbe599 | |||
| fa3919a0ff | |||
| e38346316b | |||
| 2a2b648891 | |||
| cf36f4b873 | |||
| e1ae51b02a |
@@ -12,83 +12,57 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "🚀 Issue / Bug / Request"
|
||||
description: Report a bug, suggest an improvement, or ask a technical question.
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### Thanks for contributing to LeRobot! 🙌
|
||||
Please choose the most relevant sections below. If this is a general "how-to" question, consider our [Discord](https://discord.gg/s3KuuzsPFb) for faster community support.
|
||||
|
||||
- type: dropdown
|
||||
id: issue-type
|
||||
attributes:
|
||||
label: Ticket Type
|
||||
description: What kind of ticket are you opening?
|
||||
options:
|
||||
- "🐛 Bug Report (Something isn't working)"
|
||||
- "💡 Feature Request / Improvement"
|
||||
- "❓ Technical Question"
|
||||
- "🧹 Maintenance / Documentation"
|
||||
validations:
|
||||
required: true
|
||||
Thanks for taking the time to submit a bug report! 🐛
|
||||
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: Environment & System Info
|
||||
description: |
|
||||
For bugs or technical questions, please run `lerobot-info` and paste the output.
|
||||
(Optional for feature requests).
|
||||
label: System Info
|
||||
description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below.
|
||||
render: Shell
|
||||
placeholder: lerobot version, OS, python version, etc.
|
||||
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "One of the scripts in the examples/ folder of LeRobot"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Description
|
||||
label: Reproduction
|
||||
description: |
|
||||
Provide a clear summary of the issue or your proposal.
|
||||
- **Bugs:** What is happening?
|
||||
- **Features:** What is the goal/use case?
|
||||
- **Questions:** What are you trying to achieve?
|
||||
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
Sharing error messages or stack traces could be useful as well!
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
|
||||
|
||||
placeholder: |
|
||||
A clear and concise description of the issue or suggestion.
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: context-repro
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Context & Reproduction
|
||||
description: |
|
||||
Provide a code snippet, steps to reproduce a bug, or technical details about your proposal.
|
||||
Please use code blocks for scripts and CLI commands.
|
||||
placeholder: |
|
||||
Steps to reproduce / Usage example:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant logs or stack trace
|
||||
description: If applicable, paste relevant error logs here.
|
||||
render: Shell
|
||||
|
||||
- type: checkboxes
|
||||
id: extras
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I have searched existing tickets to ensure this isn't a duplicate.
|
||||
- label: I am using the latest version of the `main` branch.
|
||||
- label: I have verified this is not an environment-specific problem.
|
||||
|
||||
- type: textarea
|
||||
id: workaround
|
||||
attributes:
|
||||
label: Additional Info / Workarounds
|
||||
description: Anything else we should know? If you have a workaround, please share it!
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
|
||||
@@ -1,54 +1,41 @@
|
||||
## Title
|
||||
## What this does
|
||||
|
||||
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
|
||||
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
|
||||
|
||||
## Type / Scope
|
||||
Examples:
|
||||
| Title | Label |
|
||||
|----------------------|-----------------|
|
||||
| Fixes #[issue] | (🐛 Bug) |
|
||||
| Adds new dataset | (🗃️ Dataset) |
|
||||
| Optimizes something | (⚡️ Performance) |
|
||||
|
||||
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
|
||||
- **Scope**: (optional — name of module or package affected)
|
||||
## How it was tested
|
||||
|
||||
## Summary / Motivation
|
||||
Explain/show how you tested your changes.
|
||||
|
||||
- One-paragraph description of what changes and why.
|
||||
- Why this change is needed and any trade-offs or design notes.
|
||||
Examples:
|
||||
|
||||
## Related issues
|
||||
- Added `test_something` in `tests/test_stuff.py`.
|
||||
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
|
||||
- Optimized `some_function`, it now runs X times faster than previously.
|
||||
|
||||
- Fixes / Closes: # (if any)
|
||||
- Related: # (if any)
|
||||
## How to checkout & try? (for the reviewer)
|
||||
|
||||
## What changed
|
||||
Provide a simple way for the reviewer to try out your changes.
|
||||
|
||||
- Short, concrete bullets of the modifications (files/behaviour).
|
||||
- Short note if this introduces breaking changes and migration steps.
|
||||
Examples:
|
||||
|
||||
## How was this tested
|
||||
```bash
|
||||
pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
|
||||
- Tests added: list new tests or test files.
|
||||
- Manual checks / dataset runs performed.
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## How to run locally (reviewer)
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
|
||||
- Run the relevant tests:
|
||||
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
||||
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
|
||||
|
||||
```bash
|
||||
pytest -q tests/ -k <keyword>
|
||||
```
|
||||
|
||||
- Run a quick example or CLI (if applicable):
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## Checklist (required before merge)
|
||||
|
||||
- [ ] Linting/formatting run (`pre-commit run -a`)
|
||||
- [ ] All tests pass locally (`pytest`)
|
||||
- [ ] Documentation updated
|
||||
- [ ] CI is green
|
||||
|
||||
## Reviewer notes
|
||||
|
||||
- Anything the reviewer should focus on (performance, edge-cases, specific files) or general notes.
|
||||
- Anyone in the community is free to review the PR.
|
||||
**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
CI:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '.github/**'
|
||||
- 'docker/**'
|
||||
|
||||
github_actions:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '.github/**'
|
||||
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*.md'
|
||||
- '**/*.mdx'
|
||||
- 'docs/**'
|
||||
|
||||
examples:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'examples/**'
|
||||
|
||||
tests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
sensors:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/cameras/**'
|
||||
|
||||
configuration:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/configs/**'
|
||||
|
||||
dataset:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/datasets/**'
|
||||
|
||||
evaluation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/envs/**'
|
||||
|
||||
robots:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- 'src/lerobot/teleoperators/**'
|
||||
- 'src/lerobot/robots/**'
|
||||
- 'src/lerobot/motors/**'
|
||||
|
||||
policies:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/policies/**'
|
||||
|
||||
processor:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/processor/**'
|
||||
@@ -31,8 +31,7 @@ jobs:
|
||||
name: Upload Preview and Comment
|
||||
if: >
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||
with:
|
||||
package_name: lerobot
|
||||
|
||||
@@ -42,9 +42,7 @@ jobs:
|
||||
# This job builds and deploys the official documentation.
|
||||
build_main_docs:
|
||||
name: Build Main Docs
|
||||
if: >
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
@@ -60,7 +58,7 @@ jobs:
|
||||
# The result of this job triggers the 'Upload PR Documentation' workflow.
|
||||
build_pr_docs:
|
||||
name: Build PR Docs
|
||||
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
|
||||
if: github.event_name == 'pull_request'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
@@ -45,6 +45,7 @@ permissions:
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow automatically labels issues based on their content.
|
||||
name: Issue Labeler
|
||||
on:
|
||||
# Trigger on new issues and edits to existing issues
|
||||
issues:
|
||||
types: [opened, edited]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
label-issue:
|
||||
name: Auto Label Issue
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
steps:
|
||||
- uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
// Setup Input Text
|
||||
const body = (context.payload.issue.body || '');
|
||||
const title = (context.payload.issue.title || '');
|
||||
const cleanBody = body.replace(/```[\s\S]*?```/g, '');
|
||||
const text = `${title}\n${cleanBody}`.toLowerCase();
|
||||
const labelsToAdd = new Set();
|
||||
const matches = (re) => re.test(text);
|
||||
|
||||
// Keyword Heuristics
|
||||
|
||||
// Domain Specific
|
||||
if (matches(/\b(bug|error|issue|fault|crash|exception)\b/i)) labelsToAdd.add('bug');
|
||||
if (matches(/\b(feature|enhancement|improvement|support|implement|proposal)\b/i)) labelsToAdd.add('enhancement');
|
||||
if (matches(/\b(question|help|how to||clarify|explain|unclear)\b/i)) labelsToAdd.add('question');
|
||||
if (matches(/\b(maintenance|documentation|docs|readme|tutorial|guide|wiki)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(example|script|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
|
||||
if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset');
|
||||
if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation');
|
||||
if (matches(/\b(train|training|loss|optimizer|backward|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
|
||||
if (matches(/\b(rerun|plot|video|render|visualiz|gif)/i)) labelsToAdd.add('visualization');
|
||||
if (matches(/\b(camera|realsense|lidar|depth|sensor|imu|microphone|rgbd)\b/i)) labelsToAdd.add('sensors');
|
||||
if (matches(/\b(aloha|koch|so-100|so100|mobile|teleop|manipulator|robots?)\b/i)) labelsToAdd.add('robots');
|
||||
if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators');
|
||||
if (matches(/\b(policy|policies|p0licy)\b/i)) labelsToAdd.add('policies');
|
||||
if (matches(/\b(processors?|pipeline)\b/i)) labelsToAdd.add('processor');
|
||||
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmark)\b/i)) labelsToAdd.add('evaluation');
|
||||
|
||||
// Infrastructure & Code Quality
|
||||
if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests');
|
||||
if (matches(/\b(ci|github actions|workflow|gha|actions?|pipeline)\b/i)) {
|
||||
labelsToAdd.add('CI');
|
||||
labelsToAdd.add('github_actions');
|
||||
}
|
||||
if (matches(/\b(perf|latency|throughput|fps|speed|performance)\b/i)) labelsToAdd.add('performance');
|
||||
if (matches(/\b(dependency|requirements|pip|conda|install error|importerror|package not found)\b/i)) labelsToAdd.add('dependencies');
|
||||
if (matches(/\b(python|pyproject|requirements(\.txt)?|pip install|typing error)\b/i)) labelsToAdd.add('python');
|
||||
|
||||
// Documentation & Meta
|
||||
if (matches(/\b(doc|documentation|docs|readme|typo|how to)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(refactor|cleanup|restructure|rename|modernize code)\b/i)) labelsToAdd.add('refactor');
|
||||
if (matches(/\b(release|changelog|version bump|cut a release|tag v)\b/i)) labelsToAdd.add('release');
|
||||
if (matches(/\b(breaking change|major change)\b/i)) labelsToAdd.add('breaking change');
|
||||
|
||||
// Apply Labels
|
||||
const labels = Array.from(labelsToAdd).filter(Boolean);
|
||||
|
||||
if (labels.length > 0) {
|
||||
console.log(`Adding labels: ${labels.join(', ')}`);
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels,
|
||||
});
|
||||
}
|
||||
@@ -43,7 +43,6 @@ jobs:
|
||||
name: Build CPU Docker for Nightly
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
|
||||
steps:
|
||||
@@ -78,7 +77,6 @@ jobs:
|
||||
name: Build GPU Docker for Nightly
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
|
||||
steps:
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow labels pull requests based on the files that were changed.
|
||||
name: Pull Request Labeler
|
||||
|
||||
on:
|
||||
# Allows labeling pull requests when they are opened or updated
|
||||
# zizmor: ignore[dangerous-triggers] Needed to label PRs from forks
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
name: Label PR
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot' && !github.event.pull_request.draft
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
sync-labels: true # Removes labels if files are removed from the PR
|
||||
@@ -29,7 +29,6 @@ jobs:
|
||||
build-and-publish:
|
||||
name: Build and publish Python distributions
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
version: ${{ steps.extract_info.outputs.tag_version }}
|
||||
permissions:
|
||||
|
||||
@@ -45,7 +45,6 @@ jobs:
|
||||
stale:
|
||||
name: Close Stale Issues and PRs
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
actions: write
|
||||
contents: write # only for delete-branch option
|
||||
|
||||
@@ -43,7 +43,6 @@ jobs:
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
@@ -78,7 +77,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -87,7 +87,7 @@ repos:
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.1
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
|
||||
@@ -52,7 +52,7 @@ decisions when appropriate.
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
Examples of representing our community include using an official email address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
@@ -60,7 +60,7 @@ representative at an online or offline event.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
feedback@huggingface.co.
|
||||
[feedback@huggingface.co](mailto:feedback@huggingface.co).
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
<p align="center">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="./media/readme/lerobot-logo-thumbnail.png" width="100%">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lerobot-logo-thumbnail.png" width="100%">
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -10,130 +12,323 @@
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
<!-- [](https://codecov.io/gh/huggingface/lerobot) -->
|
||||
|
||||
</div>
|
||||
|
||||
**LeRobot** aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry so that everyone can contribute to and benefit from shared datasets and pretrained models.
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
Build Your Own HopeJR Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
🤗 A hardware-agnostic, Python-native interface that standardizes control across diverse platforms, from low-cost arms (SO-100) to humanoids.
|
||||
<div align="center">
|
||||
<img
|
||||
src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/hope_jr/hopejr.png"
|
||||
alt="HopeJR robot"
|
||||
title="HopeJR robot"
|
||||
width="60%"
|
||||
/>
|
||||
|
||||
🤗 A standardized, scalable LeRobotDataset format (Parquet + MP4 or images) hosted on the Hugging Face Hub, enabling efficient storage, streaming and visualization of massive robotic datasets.
|
||||
<p><strong>Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!</strong></p>
|
||||
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
|
||||
<p>Perfect for advanced manipulation tasks! 🤖</p>
|
||||
|
||||
🤗 State-of-the-art policies that have been shown to transfer to the real-world ready for training and deployment.
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
See the full HopeJR tutorial here.</a></p>
|
||||
</div>
|
||||
|
||||
🤗 Comprehensive support for the open-source ecosystem to democratize physical AI.
|
||||
<br/>
|
||||
|
||||
## Quick Start
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
Build Your Own SO-101 Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
LeRobot can be installed directly from PyPI.
|
||||
<div align="center">
|
||||
<table>
|
||||
<tr>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101.webp" alt="SO-101 follower arm" title="SO-101 follower arm" width="90%"/></td>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101-leader.webp" alt="SO-101 leader arm" title="SO-101 leader arm" width="90%"/></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<p><strong>Meet the updated SO100, the SO-101 – Just €114 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
See the full SO-101 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://huggingface.co/docs/lerobot/lekiwi">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lekiwi/kiwi.webp" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
|
||||
<h3 align="center">
|
||||
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
#### Examples of pretrained models on simulation environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Installation
|
||||
|
||||
LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
>
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
>
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
>
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
### Install LeRobot 🤗
|
||||
|
||||
#### From Source
|
||||
|
||||
First, clone the repository and navigate into the directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
|
||||
```bash
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
lerobot-info
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> For detailed installation guide, please see the [Installation Documentation](https://huggingface.co/docs/lerobot/installation).
|
||||
_This installs only the default dependencies._
|
||||
|
||||
## Robots & Control
|
||||
|
||||
<div align="center">
|
||||
<img src="./media/readme/robots_control_video.webp" width="640px" alt="Reachy 2 Demo">
|
||||
</div>
|
||||
|
||||
LeRobot provides a unified `Robot` class interface that decouples control logic from hardware specifics. It supports a wide range of robots and teleoperation devices.
|
||||
|
||||
```python
|
||||
from lerobot.robots.myrobot import MyRobot
|
||||
|
||||
# Connect to a robot
|
||||
robot = MyRobot(config=...)
|
||||
robot.connect()
|
||||
|
||||
# Read observation and send action
|
||||
obs = robot.get_observation()
|
||||
action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
For detailed hardware setup guides, see the [Hardware Documentation](https://huggingface.co/docs/lerobot/integrate_hardware).
|
||||
|
||||
## LeRobot Dataset
|
||||
|
||||
To solve the data fragmentation problem in robotics, we utilize the **LeRobotDataset** format.
|
||||
|
||||
- **Structure:** Synchronized MP4 videos (or images) for vision and Parquet files for state/action data.
|
||||
- **HF Hub Integration:** Explore thousands of robotics datasets on the [Hugging Face Hub](https://huggingface.co/lerobot).
|
||||
- **Tools:** Seamlessly delete episodes, split by indices/fractions, add/remove features, and merge multiple datasets.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset from the Hub
|
||||
dataset = LeRobotDataset("lerobot/aloha_mobile_cabinet")
|
||||
|
||||
# Access data (automatically handles video decoding)
|
||||
episode_index=0
|
||||
print(f"{dataset[episode_index]['action'].shape=}\n")
|
||||
```
|
||||
|
||||
Learn more about it in the [LeRobotDataset Documentation](https://huggingface.co/docs/lerobot/lerobot-dataset-v3)
|
||||
|
||||
## SoTA Models
|
||||
|
||||
LeRobot implements state-of-the-art policies in pure PyTorch, covering Imitation Learning, Reinforcement Learning, and Vision-Language-Action (VLA) models, with more coming soon. It also provides you with the tools to instrument and inspect your training process.
|
||||
|
||||
<p align="center">
|
||||
<img alt="Gr00t Architecture" src="./media/readme/VLA_architecture.jpg" width="640px">
|
||||
</p>
|
||||
|
||||
Training a policy is as simple as running a script configuration:
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy=act \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
pip install 'lerobot[all]' # All available features
|
||||
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
|
||||
pip install 'lerobot[feetech]' # Feetech motor support
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
_Replace `[...]` with your desired features._
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
**Available Tags:**
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Inference & Evaluation
|
||||
### Weights & Biases
|
||||
|
||||
Evaluate your policies in simulation or on real hardware using the unified evaluation script. LeRobot supports standard benchmarks like **LIBERO**, **MetaWorld** and more to come.
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
|
||||
```bash
|
||||
# Evaluate a policy on the LIBERO benchmark
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.n_episodes=10
|
||||
wandb login
|
||||
```
|
||||
|
||||
Learn how to implement your own simulation environment or benchmark and distribute it from the HF Hub by following the [EnvHub Documentation](https://huggingface.co/docs/lerobot/envhub)
|
||||
(note: you will also need to enable WandB in the configuration. See below.)
|
||||
|
||||
## Resources
|
||||
### Visualize datasets
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
|
||||
|
||||
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
|
||||
|
||||
Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
|
||||
|
||||
### The `LeRobotDataset` format
|
||||
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
|
||||
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
|
||||
|
||||
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
|
||||
|
||||
```
|
||||
dataset attributes:
|
||||
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
|
||||
│ ├ observation.images.cam_high (VideoFrame):
|
||||
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
|
||||
│ ├ observation.state (list of float32): position of an arm joints (for instance)
|
||||
│ ... (more observations)
|
||||
│ ├ action (list of float32): goal position of an arm joints (for instance)
|
||||
│ ├ episode_index (int64): index of the episode for this sample
|
||||
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp (float32): timestamp in the episode
|
||||
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
|
||||
│ └ index (int64): general index in the whole dataset
|
||||
├ meta: a LeRobotDatasetMetadata object containing:
|
||||
│ ├ info: a dictionary of metadata on the dataset
|
||||
│ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
|
||||
│ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
|
||||
│ │ ├ features (dict): all features contained in the dataset with their shapes and types
|
||||
│ │ ├ total_episodes (int): total number of episodes in the dataset
|
||||
│ │ ├ total_frames (int): total number of frames in the dataset
|
||||
│ │ ├ robot_type (str): robot type used for recording
|
||||
│ │ ├ data_path (str): formattable string for the parquet files
|
||||
│ │ └ video_path (str): formattable string for the video files (if using videos)
|
||||
│ ├ episodes: a DataFrame containing episode metadata with columns:
|
||||
│ │ ├ episode_index (int): index of the episode
|
||||
│ │ ├ tasks (list): list of tasks for this episode
|
||||
│ │ ├ length (int): number of frames in this episode
|
||||
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
|
||||
│ │ └ dataset_to_index (int): end index of this episode in the dataset
|
||||
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
|
||||
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
|
||||
│ │ └ ...
|
||||
│ └ tasks: a DataFrame containing task information with task names as index and task_index as values
|
||||
├ root (Path): local directory where the dataset is stored
|
||||
├ image_transforms (Callable): optional image transformations to apply to visual modalities
|
||||
└ delta_timestamps (dict): optional delta timestamps for temporal queries
|
||||
```
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
|
||||
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
## Citation
|
||||
|
||||
If you use LeRobot in your research, please cite:
|
||||
If you want, you can cite this work with:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -144,14 +339,6 @@ If you use LeRobot in your research, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
## Contribute
|
||||
## Star History
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
<p align="center">
|
||||
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<sub>Built by the <a href="https://huggingface.co/lerobot">LeRobot</a> team at <a href="https://huggingface.co">Hugging Face</a> with ❤️</sub>
|
||||
</div>
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -19,8 +17,6 @@
|
||||
title: Train RL in Simulation
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
@@ -41,13 +37,7 @@
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
@@ -91,17 +81,11 @@
|
||||
title: Reachy 2
|
||||
- local: unitree_g1
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`--debug-visualize-queue-size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
# Bring Your Own Policies
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
|
||||
### Package Structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
|
||||
### Package Configuration
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.11"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
horizon: Action prediction horizon
|
||||
n_action_steps: Number of action steps to execute
|
||||
hidden_dim: Hidden dimension for the policy network
|
||||
# Add your policy-specific parameters here
|
||||
"""
|
||||
# ...PreTrainedConfig fields...
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Add any validation logic here
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
# Implement validation logic for your policy's requirements
|
||||
pass
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
|
||||
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
...
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
|
||||
Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Dict, Any
|
||||
import torch
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
config,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessing and postprocessing functions for your policy."""
|
||||
pass # Define your preprocessing and postprocessing logic here
|
||||
|
||||
```
|
||||
|
||||
## Step 5: Package Initialization
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
"""Custom policy package for LeRobot."""
|
||||
|
||||
try:
|
||||
import lerobot # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
@@ -1,206 +0,0 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.10 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
|
||||
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
|
||||
|
||||
1. Download and install the SDK:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Frodobots/earth-rovers-sdk.git
|
||||
cd earth-rovers-sdk
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Start the SDK:
|
||||
|
||||
```bash
|
||||
hypercorn main:app --reload
|
||||
```
|
||||
|
||||
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||
|
||||
The SDK gives you:
|
||||
|
||||
- Live video from front and rear cameras
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The SDK must be running before you can use the robot.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Follow our [Installation Guide](./installation) to install LeRobot.
|
||||
|
||||
In addition to the base installation, install the EarthRover Mini dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The robot uses the internet to communicate:
|
||||
|
||||
- **Movement commands**: Sent through the SDK
|
||||
- **Camera video**: Received from the SDK
|
||||
- **Robot info**: Battery, location, speed from the SDK
|
||||
|
||||
You don't need to plug anything in - it all works through the SDK.
|
||||
|
||||
## Calibration
|
||||
|
||||
No calibration needed! The robot is ready to use as soon as the SDK is running.
|
||||
|
||||
## Controlling the Robot
|
||||
|
||||
You control the robot using your keyboard - just like playing a video game with WASD keys.
|
||||
|
||||
### Keyboard Controls
|
||||
|
||||
| Key | Action |
|
||||
| --- | -------------------------------- |
|
||||
| W | Move forward |
|
||||
| S | Move backward |
|
||||
| A | Turn left (with forward motion) |
|
||||
| D | Turn right (with forward motion) |
|
||||
| Q | Rotate left in place |
|
||||
| E | Rotate right in place |
|
||||
| X | Stop all movement |
|
||||
| +/= | Increase speed |
|
||||
| - | Decrease speed |
|
||||
| ESC | Disconnect |
|
||||
|
||||
### Speed Settings
|
||||
|
||||
You can adjust how fast the robot moves:
|
||||
|
||||
- **Forward/backward speed**: Default is full speed (1.0)
|
||||
- **Turning speed**: Default is full speed (1.0)
|
||||
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
|
||||
|
||||
### Try It Out
|
||||
|
||||
Test driving the robot before recording data:
|
||||
|
||||
```python
|
||||
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
|
||||
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
|
||||
|
||||
# Initialize robot
|
||||
robot_config = EarthRoverMiniPlusConfig()
|
||||
robot = EarthRoverMiniPlus(robot_config)
|
||||
|
||||
# Initialize teleoperator
|
||||
teleop_config = KeyboardRoverTeleopConfig(
|
||||
linear_speed=1.0,
|
||||
angular_speed=1.0,
|
||||
speed_increment=0.1
|
||||
)
|
||||
teleop = KeyboardRoverTeleop(teleop_config)
|
||||
|
||||
# Connect
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Teleoperate (use keyboard controls)
|
||||
try:
|
||||
while True:
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||
|
||||
## Recording Data
|
||||
|
||||
Once you can drive the robot well, you can start recording data to train AI models. The system records:
|
||||
|
||||
- **What you do**: How you move the robot (forward, backward, turning)
|
||||
- **What the robot sees**:
|
||||
- Videos from both cameras
|
||||
- Robot speed and direction
|
||||
- Battery level and location
|
||||
- GPS position and signal
|
||||
- Other sensor data
|
||||
- **When it happened**: Timestamps for everything
|
||||
|
||||
### Setting Up Hugging Face
|
||||
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
### Start Recording
|
||||
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.fps=10 \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
|
||||
|
||||
### What Gets Saved
|
||||
|
||||
Your dataset includes:
|
||||
|
||||
**Your Actions (2 things)**:
|
||||
|
||||
- How much you moved forward/backward
|
||||
- How much you turned left/right
|
||||
|
||||
**Robot Observations (12 things)**:
|
||||
|
||||
- Front camera video
|
||||
- Rear camera video
|
||||
- Current speed
|
||||
- Battery level
|
||||
- Which way the robot is facing
|
||||
- GPS location (latitude, longitude, signal strength)
|
||||
- Network signal strength
|
||||
- Vibration level
|
||||
- Lamp status (on/off)
|
||||
|
||||
### Where Your Data Goes
|
||||
|
||||
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
|
||||
|
||||
After recording, your data automatically uploads to your Hugging Face page:
|
||||
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
|
||||
```
|
||||
|
||||
Your dataset will be tagged with `LeRobot` for community discovery.
|
||||
@@ -201,8 +201,7 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -210,19 +209,12 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
)
|
||||
teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
@@ -251,9 +243,6 @@ init_rerun(session_name="recording")
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
@@ -262,9 +251,6 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
@@ -279,9 +265,6 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
@@ -445,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
@@ -502,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
|
||||
@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
|
||||
To install these for linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
|
||||
```
|
||||
|
||||
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
@@ -62,11 +62,6 @@ lerobot-eval \
|
||||
|
||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||
|
||||
### Control Mode
|
||||
|
||||
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
|
||||
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
|
||||
|
||||
### Policy inputs and outputs
|
||||
|
||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# Parameter efficient fine-tuning with 🤗 PEFT
|
||||
|
||||
[🤗 PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting
|
||||
large pretrained models such as pre-trained policies (e.g., SmolVLA, π₀, ...) to new tasks without training all
|
||||
of the model's parameters while yielding comparable performance.
|
||||
|
||||
Install the `lerobot[peft]` optional package to enable PEFT support.
|
||||
|
||||
To read about all the possible methods of adaption, please refer to the [🤗 PEFT docs](https://huggingface.co/docs/peft/index).
|
||||
|
||||
## Training SmolVLA
|
||||
|
||||
In this section we'll show you how to train a pre-trained SmolVLA policy with PEFT on the libero dataset.
|
||||
For brevity we're only training on the `libero_spatial` subset. We will use `lerobot/smolvla_base` as the model
|
||||
to parameter efficiently fine-tune:
|
||||
|
||||
```
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--policy.repo_id=your_hub_name/my_libero_smolvla \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--policy.output_features=null \
|
||||
--policy.input_features=null \
|
||||
--policy.optimizer_lr=1e-3 \
|
||||
--policy.scheduler_decay_lr=1e-4 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
if you want to see a specific PEFT method supported.
|
||||
|
||||
By default, PEFT will target the `q_proj` and `v_proj` layers of the LM expert in SmolVLA. It will also target the
|
||||
state and action projection matrices as they are most likely task-dependent. If you need to target different layers
|
||||
you can use `--peft.target_modules` to specify which layers to target. You can refer to the respective PEFT method's
|
||||
documentation to see what inputs are supported, (e.g., [LoRA's target_modules documentation](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules)).
|
||||
Usually a list of suffixes or a regex are supported. For example, to target the MLPs of the `lm_expert` instead of
|
||||
the `q` and `v` projections, use:
|
||||
|
||||
```
|
||||
--peft.target_modules='(model\.vlm_with_expert\.lm_expert\..*\.(down|gate|up)_proj|.*\.(state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out))'
|
||||
```
|
||||
|
||||
In case you need to fully fine-tune a layer instead of just adapting it, you can supply a list of layer suffixes
|
||||
to the `--peft.full_training_modules` parameter:
|
||||
|
||||
```
|
||||
--peft.full_training_modules=["state_proj"]
|
||||
```
|
||||
|
||||
The learning rate and the scheduled target learning rate can usually be scaled by a factor of 10 compared to the
|
||||
learning rate used for full fine-tuning (e.g., 1e-4 normal, so 1e-3 using LoRA).
|
||||
@@ -1,586 +0,0 @@
|
||||
# SARM: Stage-Aware Reward Modeling
|
||||
|
||||
SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).
|
||||
|
||||
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
|
||||
|
||||
## Why Reward Models?
|
||||
|
||||
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
|
||||
|
||||
## Overview
|
||||
|
||||
SARM has following features:
|
||||
|
||||
1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
|
||||
2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
|
||||
3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
|
||||
|
||||
SARM trains on a compact **stage+tau** target for each frame:
|
||||
|
||||
- **stage**: integer stage index `k ∈ {0, ..., K-1}`
|
||||
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
|
||||
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)
|
||||
|
||||
At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).
|
||||
|
||||
This matches **Formula (2)** from the paper:
|
||||
|
||||
```
|
||||
progress_t = P_{k-1} + α̅_k × τ_t
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time
|
||||
- `P_{k-1}` is cumulative prior (sum of previous subtask proportions)
|
||||
- `α̅_k` is the temporal proportion for subtask k
|
||||
|
||||
This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
|
||||
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)
|
||||
|
||||
At minimum, each training sample needs:
|
||||
|
||||
- `task` (string): task description
|
||||
- `policy.image_key` images and `policy.state_key` states from the dataset
|
||||
|
||||
---
|
||||
|
||||
## Annotation Modes
|
||||
|
||||
You can choose from **3 annotation modes** that determine how progress labels are computed:
|
||||
|
||||
| Mode | Annotations Required | Heads | Use Case |
|
||||
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
|
||||
| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
|
||||
| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
|
||||
| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |
|
||||
|
||||
### Mode Details
|
||||
|
||||
<hfoptions id="mode_explanation">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress
|
||||
- **Dense head**: Not used
|
||||
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available
|
||||
|
||||
## Set Up Your Environment
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install SARM dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[sarm]"
|
||||
```
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
|
||||
- **Dense head**: Multiple fine-grained stages from VLM annotations
|
||||
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.
|
||||
|
||||
- **Sparse head**: High-level stages from VLM annotations
|
||||
- **Dense head**: Fine-grained stages from VLM annotations
|
||||
- **Best for**: Complex multi-stage tasks where both granularities are useful
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Subtask Annotation
|
||||
|
||||
<hfoptions id="annotation_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--dense-only \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Annotation Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | ------------------------------------------------------------------------------- |
|
||||
| `--repo-id` | HuggingFace dataset repository ID |
|
||||
| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
|
||||
| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
|
||||
| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
|
||||
| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
|
||||
| `--num-workers` | Number of parallel GPU workers (default: 1) |
|
||||
| `--episodes` | Specific episode indices to annotate (default: all) |
|
||||
| `--skip-existing` | Skip episodes that already have annotations |
|
||||
| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
|
||||
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |
|
||||
|
||||
> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Verify Annotations
|
||||
|
||||
<hfoptions id="verify_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No verification needed!** Skip this step.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type dense \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type both \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This generates visualizations showing video frames with subtask boundaries overlaid and timeline of subtasks.
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | -------------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize existing annotations (no generation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |
|
||||
|
||||
**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Train SARM
|
||||
|
||||
<hfoptions id="train_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_single \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dense \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dual \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Multi-GPU Training
|
||||
|
||||
Add `accelerate launch --multi_gpu --num_processes=4` to use multiple GPUs for training.
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description | Default |
|
||||
| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
|
||||
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
|
||||
| `--policy.image_key` | Camera key for images | `observation.images.top` |
|
||||
| `--policy.state_key` | Key for joint states | `observation.state` |
|
||||
| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
|
||||
| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Visualize Predictions
|
||||
|
||||
Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.
|
||||
|
||||
<hfoptions id="viz_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode sparse \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode dense \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode both \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The visualization shows:
|
||||
|
||||
- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`)
|
||||
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
|
||||
- **Sample frames**: Key frames from the episode with progress/stage labels
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | --------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize predictions (no RABC computation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
|
||||
| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
|
||||
|
||||
---
|
||||
|
||||
## Step 5 (Optional): Train Policy with RA-BC
|
||||
|
||||
Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:
|
||||
|
||||
1. **Precompute progress values** for all frames using the trained SARM model
|
||||
2. **Train policy** with RA-BC weighting using the precomputed values
|
||||
|
||||
### How RA-BC Works
|
||||
|
||||
For each training sample, RA-BC computes the progress delta:
|
||||
|
||||
```
|
||||
r_i = φ(o_{t+Δ}) - φ(o_t)
|
||||
```
|
||||
|
||||
Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.
|
||||
|
||||
The weighting follows **Equations 8-9** from the paper:
|
||||
|
||||
- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
|
||||
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
|
||||
|
||||
### Step 5a: Compute SARM Progress Values
|
||||
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
--num-visualizations 5 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
This script:
|
||||
|
||||
- Processes all frames and computes progress values
|
||||
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
|
||||
- Generates visualizations of the first N episodes (default: 5)
|
||||
|
||||
**Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | -------------------------------------------------------------- | ---------- |
|
||||
| `--reward-model-path` | Path to trained SARM model | (required) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
|
||||
| `--device` | Device for inference | `cuda` |
|
||||
| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` |
|
||||
|
||||
**Output format** (`sarm_progress.parquet`):
|
||||
|
||||
| Column | Description |
|
||||
| ----------------- | ---------------------------------------------- |
|
||||
| `index` | Global frame index in dataset |
|
||||
| `episode_index` | Episode number |
|
||||
| `frame_index` | Local frame index within episode |
|
||||
| `progress_sparse` | Sparse head progress value [0, 1] |
|
||||
| `progress_dense` | Dense head progress value [0, 1] (if computed) |
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
The training script automatically:
|
||||
|
||||
- Loads the precomputed progress values from the parquet file
|
||||
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
|
||||
- Computes sample weights based on progress improvement
|
||||
- Applies weighted loss during training
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.
|
||||
|
||||
**How the weighting works:**
|
||||
|
||||
| Condition | Weight |
|
||||
| ------------------- | ----------------------- |
|
||||
| `delta > kappa` | 1.0 (hard threshold) |
|
||||
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
|
||||
| `delta < 0` | 0.0 (negative progress) |
|
||||
|
||||
**Diagnosing kappa issues:**
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
|
||||
If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
|
||||
|
||||
### Multi-GPU Training with RA-BC
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tips & Best Practices
|
||||
|
||||
### Choosing a Mode
|
||||
|
||||
- **Start with `single_stage`** for quick experiments - no annotation overhead
|
||||
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
|
||||
- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful
|
||||
|
||||
### Annotation Quality
|
||||
|
||||
1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
|
||||
2. **Verify with visualization**: Always check a few episodes before training
|
||||
3. **Consistent naming**: Use the same subtask names across all episodes
|
||||
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -30,6 +30,131 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
|
||||
| Wrist Roll | 5 | 1 / 147 |
|
||||
| Gripper | 6 | 1 / 147 |
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated with each arm
|
||||
@@ -215,131 +340,6 @@ leader.setup_motors()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# PyTorch accelerators
|
||||
|
||||
LeRobot supports multiple hardware acceleration options for both training and inference.
|
||||
|
||||
These options include:
|
||||
|
||||
- **CPU**: CPU executes all computations, no dedicated accelerator is used
|
||||
- **CUDA**: acceleration with NVIDIA & AMD GPUs
|
||||
- **MPS**: acceleration with Apple Silicon GPUs
|
||||
- **XPU**: acceleration with Intel integrated and discrete GPUs
|
||||
|
||||
## Getting Started
|
||||
|
||||
To use particular accelerator, a suitable version of PyTorch should be installed.
|
||||
|
||||
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
|
||||
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
|
||||
|
||||
### Verifying the installation
|
||||
|
||||
After installation, accelerator availability can be verified by running
|
||||
|
||||
```python
|
||||
import torch
|
||||
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
|
||||
```
|
||||
|
||||
## How to run training or evaluation
|
||||
|
||||
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
|
||||
|
||||
```bash
|
||||
lerobot-train
|
||||
--policy.device=mps ...
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.device=mps ...
|
||||
```
|
||||
|
||||
However, in most cases, presence of an accelerator is detected automatically and `policy.device` parameter can be omitted from CLI commands.
|
||||
@@ -4,12 +4,11 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
|
||||
|
||||
## About the Unitree G1
|
||||
|
||||
We offer support for both 29 and 23 DOF G1. We introduce:
|
||||
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level communication with the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
- **MuJoCo simulation mode** for testing policies without the physical robot
|
||||
|
||||
---
|
||||
|
||||
@@ -192,10 +191,6 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Extra: Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
|
||||
@@ -11,14 +11,13 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
@@ -87,71 +86,9 @@ lerobot-edit-dataset \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
#### Convert to Video
|
||||
|
||||
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
|
||||
|
||||
```bash
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
@@ -159,45 +96,7 @@ lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub true
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
|
||||
# Dataset Visualization
|
||||
|
||||
## Online Visualization
|
||||
|
||||
When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at:
|
||||
https://huggingface.co/spaces/lerobot/visualize_dataset
|
||||
|
||||
## Local Visualization
|
||||
|
||||
You can also visualize episodes from a dataset locally using our command-line tool.
|
||||
|
||||
**From the Hugging Face Hub:**
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
**From a local folder:**
|
||||
Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
|
||||
|
||||
For advanced usage—including visualizing datasets stored on a remote server—run:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz --help
|
||||
```
|
||||
|
||||
@@ -1,528 +0,0 @@
|
||||
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
|
||||
|
||||
## Overview
|
||||
|
||||
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
|
||||
|
||||
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
|
||||
|
||||
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
|
||||
|
||||
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||
alt="XVLA Architecture"
|
||||
style="max-width: 100%; height: auto; width: 800px;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
||||
alt="XVLA Architecture 2"
|
||||
style="width: 60%; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
|
||||
alt="XVLA fold visualization"
|
||||
style="width: 95%; max-width: 1100px; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
|
||||
|
||||
## Installation
|
||||
|
||||
After installing LeRobot, install the X-VLA dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .[xvla]
|
||||
```
|
||||
|
||||
After the new release, you'll be able to do:
|
||||
|
||||
```bash
|
||||
pip install lerobot[xvla]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
To use X-VLA in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```bash
|
||||
policy.type=xvla
|
||||
```
|
||||
|
||||
### Evaluating Pre-trained Checkpoints
|
||||
|
||||
Example evaluation with LIBERO:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="lerobot/xvla-libero" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_goal,libero_10 \
|
||||
--env.control_mode=absolute \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--env.episode_length=800 \
|
||||
--seed=142
|
||||
```
|
||||
|
||||
## Available Checkpoints
|
||||
|
||||
### 🎯 Base Model
|
||||
|
||||
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
|
||||
|
||||
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
|
||||
|
||||
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
|
||||
|
||||
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
|
||||
|
||||
### Simulation Checkpoints
|
||||
|
||||
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
|
||||
|
||||
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
|
||||
|
||||
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
|
||||
|
||||
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
|
||||
|
||||
### 🤖 Real-World Checkpoints
|
||||
|
||||
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
|
||||
|
||||
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
|
||||
|
||||
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
|
||||
|
||||
Optimized for AgileX robot dexterous manipulation tasks.
|
||||
|
||||
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
|
||||
|
||||
Adapted for Google Robot platforms.
|
||||
|
||||
## Training X-VLA
|
||||
|
||||
### Recommended Training Configuration
|
||||
|
||||
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
```
|
||||
|
||||
### Training Parameters Explained
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------- | ------- | ---------------------------------------------- |
|
||||
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
|
||||
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
|
||||
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
|
||||
| `train_soft_prompts` | `true` | Allow soft prompts to train |
|
||||
|
||||
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
|
||||
|
||||
### Example: Training on Bimanual Robot
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=so101_bimanual \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true
|
||||
```
|
||||
|
||||
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||
|
||||
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||
|
||||
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
|
||||
❕Note
|
||||
|
||||
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||
We encourage implementing this in your customized training pipeline for optimal results.
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Action Modes
|
||||
|
||||
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
|
||||
|
||||
#### Available Action Modes
|
||||
|
||||
| Action Mode | Action Dim | Description | Use Case |
|
||||
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
|
||||
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
|
||||
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
|
||||
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
|
||||
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
|
||||
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
|
||||
|
||||
#### Why Action Modes Matter
|
||||
|
||||
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
|
||||
|
||||
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
|
||||
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
|
||||
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
|
||||
|
||||
#### Example: BimanualSO101 Action Space
|
||||
|
||||
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
|
||||
|
||||
```python
|
||||
# Model outputs 20 dimensions for compatibility
|
||||
dim_action = 20
|
||||
|
||||
# Real robot only needs 12 dimensions
|
||||
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
|
||||
REAL_DIM = 12
|
||||
|
||||
# Preprocessing: Pad 12D actions to 20D for training
|
||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||
```
|
||||
|
||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
|
||||
#### Auto Action Mode (Recommended)
|
||||
|
||||
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.action_mode=auto \
|
||||
--policy.max_action_dim=20 \
|
||||
...
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
|
||||
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
|
||||
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
|
||||
- Postprocess trims output back to `real_dim` for robot control
|
||||
|
||||
This eliminates the need to create custom action modes for most robots.
|
||||
|
||||
### 2. Domain IDs
|
||||
|
||||
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
|
||||
|
||||
- Different robots (Robot 1 vs Robot 2)
|
||||
- Different camera configurations (cam1 vs cam2)
|
||||
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
|
||||
|
||||
#### Setting Domain IDs
|
||||
|
||||
**During Training**: By default, domain_id is set to 0 for general training.
|
||||
|
||||
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
|
||||
|
||||
```python
|
||||
# Example: LIBERO checkpoint uses domain_id=3
|
||||
domain_id = 3
|
||||
```
|
||||
|
||||
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||
|
||||
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
|
||||
|
||||
#### Fine-tuning Datasets
|
||||
|
||||
| Dataset Name | Domain ID |
|
||||
| ---------------- | --------- |
|
||||
| Bridge | 0 |
|
||||
| RT1 | 1 |
|
||||
| Calvin | 2 |
|
||||
| libero | 3 |
|
||||
| widowx-air | 4 |
|
||||
| AIR-AGILEX-HQ | 5 |
|
||||
| robotwin2_abs_ee | 6 |
|
||||
| robotwin2_clean | 6 |
|
||||
| robocasa-human | 7 |
|
||||
| VLABench | 8 |
|
||||
| AGIBOT-challenge | 9 |
|
||||
| AIR-AGILEX | 10 |
|
||||
| AIRBOT | 18 |
|
||||
|
||||
### 3. Processor Steps
|
||||
|
||||
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||
|
||||
#### Required Preprocessing Steps
|
||||
|
||||
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
|
||||
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
|
||||
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
|
||||
|
||||
#### Example Custom Processor
|
||||
|
||||
For LIBERO environments, a custom processor handles the specific observation format:
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
|
||||
|
||||
processor = LiberoProcessorStep()
|
||||
# Handles robot_state dictionary, converts rotation matrices to 6D representation
|
||||
# Applies 180° image rotation for camera convention
|
||||
```
|
||||
|
||||
### 4. Configuration Parameters
|
||||
|
||||
Key configuration parameters for X-VLA:
|
||||
|
||||
```python
|
||||
# Observation and action
|
||||
n_obs_steps: int = 1 # Number of observation timesteps
|
||||
chunk_size: int = 32 # Action sequence length
|
||||
n_action_steps: int = 32 # Number of action steps to execute
|
||||
|
||||
# Model architecture
|
||||
hidden_size: int = 1024 # Transformer hidden dimension
|
||||
depth: int = 24 # Number of transformer layers
|
||||
num_heads: int = 16 # Number of attention heads
|
||||
num_domains: int = 30 # Maximum number of domain IDs
|
||||
len_soft_prompts: int = 32 # Length of soft prompt embeddings
|
||||
|
||||
# Action space
|
||||
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
|
||||
use_proprio: bool = True # Use proprioceptive state
|
||||
max_state_dim: int = 32 # Maximum state dimension
|
||||
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
|
||||
|
||||
# Vision
|
||||
num_image_views: int | None # Number of camera views
|
||||
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
|
||||
|
||||
# Training
|
||||
num_denoising_steps: int = 10 # Flow matching denoising steps
|
||||
```
|
||||
|
||||
## Creating Custom Action Modes
|
||||
|
||||
If your robot has a unique action space, you can create a custom action mode:
|
||||
|
||||
### Step 1: Define Your Action Space
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
|
||||
import torch.nn as nn
|
||||
|
||||
@register_action("my_custom_robot")
|
||||
class MyCustomActionSpace(BaseActionSpace):
|
||||
"""Custom action space for my robot."""
|
||||
|
||||
dim_action = 15 # Your robot's action dimension
|
||||
gripper_idx = (7, 14) # Gripper channel indices
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""Define your loss computation."""
|
||||
# Example: MSE for joints, BCE for grippers
|
||||
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
|
||||
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
|
||||
target[:, :, self.gripper_idx])
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Preprocess actions before training."""
|
||||
# Example: Zero out grippers in proprioception
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone() if action is not None else None
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
if action_m is not None:
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action):
|
||||
"""Post-process predictions for deployment."""
|
||||
# Example: Apply sigmoid to gripper logits
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
```
|
||||
|
||||
### Step 2: Use Your Custom Action Mode
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.action_mode=my_custom_robot \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
...
|
||||
```
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Multi-Camera Support
|
||||
|
||||
X-VLA supports multiple camera views through the `num_image_views` parameter:
|
||||
|
||||
```python
|
||||
# Configure for 3 camera views
|
||||
policy.num_image_views=3
|
||||
|
||||
# Add empty cameras if you have fewer physical cameras
|
||||
policy.empty_cameras=1 # Adds 1 zero-padded camera view
|
||||
```
|
||||
|
||||
### Custom Preprocessing Pipeline
|
||||
|
||||
Create a custom preprocessing pipeline for your environment:
|
||||
|
||||
```python
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAImageToFloatProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
)
|
||||
|
||||
# Build custom pipeline
|
||||
preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
YourCustomProcessorStep(), # Your custom processing
|
||||
XVLAImageToFloatProcessorStep(), # Required: convert to float
|
||||
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
|
||||
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Handling Different Action Dimensions
|
||||
|
||||
When your dataset has fewer action dimensions than the pretrained model:
|
||||
|
||||
**Option 1 (Recommended)**: Use `auto` action mode
|
||||
|
||||
```bash
|
||||
# Automatically detects your dataset's action dimension
|
||||
# Works with any robot without custom code
|
||||
policy.action_mode=auto
|
||||
policy.max_action_dim=20 # Match pretrained model
|
||||
```
|
||||
|
||||
**Option 2**: Use a predefined action mode with built-in padding
|
||||
|
||||
```python
|
||||
# Model expects 20D, dataset has 12D
|
||||
# Action mode handles padding internally
|
||||
action_mode = "so101_bimanual" # Pads 12 → 20
|
||||
```
|
||||
|
||||
**Option 2**: Create a custom action mode that maps dimensions explicitly
|
||||
|
||||
```python
|
||||
@register_action("my_mapped_action")
|
||||
class MappedActionSpace(BaseActionSpace):
|
||||
dim_action = 20
|
||||
REAL_DIM = 12
|
||||
|
||||
def _pad_to_model_dim(self, x):
|
||||
# Custom padding logic
|
||||
...
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Issue**: "Action dimension mismatch"
|
||||
|
||||
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
|
||||
|
||||
**Issue**: "Image values outside [0, 1] range"
|
||||
|
||||
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
|
||||
|
||||
**Issue**: "Domain ID not found"
|
||||
|
||||
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
|
||||
|
||||
**Issue**: "Low success rate on new embodiment"
|
||||
|
||||
- **Solution**:
|
||||
1. Verify your action_mode is correct
|
||||
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
|
||||
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
|
||||
4. Consider increasing training steps
|
||||
|
||||
**Issue**: "Out of memory during training"
|
||||
|
||||
- **Solution**:
|
||||
1. Reduce `chunk_size` (e.g., from 32 to 16)
|
||||
2. Enable gradient checkpointing
|
||||
3. Reduce batch size
|
||||
4. Freeze more components
|
||||
|
||||
## Citation
|
||||
|
||||
If you use X-VLA in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zheng2025x,
|
||||
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
|
||||
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
|
||||
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
|
||||
journal = {arXiv preprint arXiv:2510.10274},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
|
||||
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate EgoDex shards into a single dataset.
|
||||
|
||||
After distributed processing creates multiple shards, this script combines
|
||||
them into a single unified dataset.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class AggregateEgoDexDatasets(PipelineStep):
|
||||
"""Datatrove pipeline step for aggregating EgoDex shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
aggregated_repo_id: str,
|
||||
local_dir: Path | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
hf_repo_id: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else None
|
||||
self.push_to_hub = push_to_hub
|
||||
self.hf_repo_id = hf_repo_id if hf_repo_id else aggregated_repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Only worker 0 performs aggregation (aggregate_datasets handles parallelism internally)
|
||||
if rank == 0:
|
||||
logging.info(f"Starting aggregation of {len(self.repo_ids)} shards into {self.aggr_repo_id}")
|
||||
|
||||
# Build roots list if local_dir is specified
|
||||
roots = None
|
||||
if self.local_dir:
|
||||
roots = [self.local_dir / repo_id for repo_id in self.repo_ids]
|
||||
# Filter to only existing directories
|
||||
existing_roots = [r for r in roots if r.exists()]
|
||||
if len(existing_roots) != len(self.repo_ids):
|
||||
logging.warning(
|
||||
f"Only {len(existing_roots)} of {len(self.repo_ids)} shard directories found. "
|
||||
"Missing shards will be skipped."
|
||||
)
|
||||
# Update repo_ids to match existing roots
|
||||
existing_repo_ids = [
|
||||
repo_id for repo_id, r in zip(self.repo_ids, roots, strict=False) if r.exists()
|
||||
]
|
||||
roots = existing_roots
|
||||
self.repo_ids = existing_repo_ids
|
||||
|
||||
if len(self.repo_ids) == 0:
|
||||
logging.error("No shard directories found. Nothing to aggregate.")
|
||||
return
|
||||
|
||||
aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=self.repo_ids,
|
||||
aggr_repo_id=self.aggr_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
logging.info("Aggregation complete!")
|
||||
|
||||
# Push to Hugging Face Hub if requested
|
||||
if self.push_to_hub:
|
||||
logging.info(f"Pushing to Hugging Face Hub as {self.hf_repo_id}...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=self.aggr_repo_id,
|
||||
root=aggr_root,
|
||||
)
|
||||
# Update repo_id for pushing to different HF account if specified
|
||||
dataset.repo_id = self.hf_repo_id
|
||||
dataset.push_to_hub(
|
||||
tags=["egodex", "hand", "dexterous", "lerobot"],
|
||||
license="cc-by-nc-nd-4.0",
|
||||
)
|
||||
logging.info("Push to hub complete!")
|
||||
else:
|
||||
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
num_shards,
|
||||
job_name,
|
||||
logs_dir,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
push_to_hub,
|
||||
hf_repo_id,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for aggregating EgoDex shards."""
|
||||
# Generate repo IDs for all shards
|
||||
repo_ids = [f"{repo_id}_world_{num_shards}_rank_{rank}" for rank in range(num_shards)]
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub, hf_repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1, # Only need 1 task for aggregation
|
||||
"workers": 1, # Only need 1 worker
|
||||
"time": "24:00:00", # 24 hours for aggregation
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate EgoDex dataset shards into a single unified dataset."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (base name without shard suffix, e.g., pepijn/egodex-test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of shards to aggregate (must match --workers from slurm_port_egodex.py)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggr_egodex",
|
||||
help="Job name used in SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch locally (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition (ideally CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of CPUs for aggregation task",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Memory per CPU for aggregation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory where shards are stored. If not specified, uses default HF cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push aggregated dataset to Hugging Face Hub after aggregation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repo ID for upload (e.g., username/dataset-name). Defaults to --repo-id.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
aggregate_executor = make_aggregate_executor(**kwargs)
|
||||
aggregate_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Download EgoDex dataset
|
||||
# Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
#
|
||||
# Usage: ./download_egodex.sh [output_dir] [parts...]
|
||||
#
|
||||
# Examples:
|
||||
# ./download_egodex.sh ./data test # Download test set only (16 GB)
|
||||
# ./download_egodex.sh ./data part1 part2 # Download training parts 1 and 2
|
||||
# ./download_egodex.sh ./data all # Download everything (~1.7 TB)
|
||||
#
|
||||
# Available parts:
|
||||
# test - Test set (16 GB)
|
||||
# part1 - Training set part 1 (300 GB)
|
||||
# part2 - Training set part 2 (300 GB)
|
||||
# part3 - Training set part 3 (300 GB)
|
||||
# part4 - Training set part 4 (300 GB)
|
||||
# part5 - Training set part 5 (300 GB)
|
||||
# extra - Additional data (200 GB)
|
||||
# all - Download all parts (~1.7 TB total)
|
||||
|
||||
set -e
|
||||
|
||||
BASE_URL="https://ml-site.cdn-apple.com/datasets/egodex"
|
||||
|
||||
# Map part names to filenames
|
||||
declare -A PART_FILES=(
|
||||
["test"]="test.zip"
|
||||
["part1"]="part1.zip"
|
||||
["part2"]="part2.zip"
|
||||
["part3"]="part3.zip"
|
||||
["part4"]="part4.zip"
|
||||
["part5"]="part5.zip"
|
||||
["extra"]="extra.zip"
|
||||
)
|
||||
|
||||
ALL_PARTS=("test" "part1" "part2" "part3" "part4" "part5" "extra")
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 <output_dir> <parts...>"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 ./data test # Download test set only (16 GB)"
|
||||
echo " $0 ./data part1 part2 # Download training parts 1 and 2"
|
||||
echo " $0 ./data all # Download everything (~1.7 TB)"
|
||||
echo ""
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
}
|
||||
|
||||
download_part() {
|
||||
local output_dir="$1"
|
||||
local part="$2"
|
||||
local filename="${PART_FILES[$part]}"
|
||||
local url="${BASE_URL}/${filename}"
|
||||
local output_file="${output_dir}/${filename}"
|
||||
|
||||
echo "----------------------------------------"
|
||||
echo "Downloading: ${part} (${filename})"
|
||||
echo "URL: ${url}"
|
||||
echo "Output: ${output_file}"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Download with curl, showing progress
|
||||
curl -L --progress-bar "${url}" -o "${output_file}"
|
||||
|
||||
# Unzip
|
||||
echo "Extracting ${filename}..."
|
||||
unzip -q "${output_file}" -d "${output_dir}"
|
||||
|
||||
# Optionally remove zip file to save space
|
||||
# Uncomment the next line if you want to delete zips after extraction
|
||||
# rm "${output_file}"
|
||||
|
||||
echo "Done: ${part}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# Check arguments
|
||||
if [ $# -lt 2 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
OUTPUT_DIR="$1"
|
||||
shift
|
||||
|
||||
# Create output directory
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
# Determine which parts to download
|
||||
PARTS_TO_DOWNLOAD=()
|
||||
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "all" ]; then
|
||||
PARTS_TO_DOWNLOAD=("${ALL_PARTS[@]}")
|
||||
break
|
||||
elif [ -n "${PART_FILES[$arg]}" ]; then
|
||||
PARTS_TO_DOWNLOAD+=("$arg")
|
||||
else
|
||||
echo "Error: Unknown part '${arg}'"
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ ${#PARTS_TO_DOWNLOAD[@]} -eq 0 ]; then
|
||||
echo "Error: No valid parts specified"
|
||||
usage
|
||||
fi
|
||||
|
||||
echo "========================================"
|
||||
echo "EgoDex Dataset Download"
|
||||
echo "========================================"
|
||||
echo "Output directory: ${OUTPUT_DIR}"
|
||||
echo "Parts to download: ${PARTS_TO_DOWNLOAD[*]}"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# Download each part
|
||||
for part in "${PARTS_TO_DOWNLOAD[@]}"; do
|
||||
download_part "${OUTPUT_DIR}" "${part}"
|
||||
done
|
||||
|
||||
echo "========================================"
|
||||
echo "Download complete!"
|
||||
echo "Data saved to: ${OUTPUT_DIR}"
|
||||
echo "========================================"
|
||||
|
||||
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Distributed EgoDex dataset porting using SLURM and datatrove.
|
||||
|
||||
EgoDex is a large-scale dataset for egocentric dexterous manipulation collected
|
||||
with ARKit on Apple Vision Pro. This script converts EgoDex data to LeRobot format.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Image dimensions
|
||||
DEFAULT_IMAGE_HEIGHT = 1080
|
||||
DEFAULT_IMAGE_WIDTH = 1920
|
||||
|
||||
class PortEgoDexShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str,
|
||||
local_dir: Path | str = None,
|
||||
percentage: float = 100.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else Path("data/local_datasets")
|
||||
self.percentage = percentage
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
def _get_state_for_single_frame(transforms_group, frame_idx):
|
||||
"""
|
||||
Construct 48D hand state representation from EgoDex.
|
||||
|
||||
State vector composition (per hand = 24D, total = 48D):
|
||||
- Wrist 3D position (3)
|
||||
- Wrist orientation in 6D representation (6)
|
||||
- 5 fingertip 3D positions (15)
|
||||
"""
|
||||
state_vector = []
|
||||
fingertip_joints = {
|
||||
"left": [
|
||||
"leftThumbTip",
|
||||
"leftIndexFingerTip",
|
||||
"leftMiddleFingerTip",
|
||||
"leftRingFingerTip",
|
||||
"leftLittleFingerTip",
|
||||
],
|
||||
"right": [
|
||||
"rightThumbTip",
|
||||
"rightIndexFingerTip",
|
||||
"rightMiddleFingerTip",
|
||||
"rightRingFingerTip",
|
||||
"rightLittleFingerTip",
|
||||
],
|
||||
}
|
||||
|
||||
for hand_side in ["left", "right"]:
|
||||
hand_key = f"{hand_side}Hand"
|
||||
hand_transform = transforms_group[hand_key][frame_idx]
|
||||
|
||||
# 1. Wrist 3D position
|
||||
hand_position = hand_transform[:3, 3]
|
||||
state_vector.extend(hand_position)
|
||||
|
||||
# 2. Wrist orientation in compact 6D representation
|
||||
rotation_matrix = hand_transform[:3, :3]
|
||||
rotation_6d = np.concatenate([rotation_matrix[:, 0], rotation_matrix[:, 1]])
|
||||
state_vector.extend(rotation_6d)
|
||||
|
||||
# 3. 3D positions of 5 fingertips
|
||||
for fingertip in fingertip_joints[hand_side]:
|
||||
fingertip_transform = transforms_group[fingertip][frame_idx]
|
||||
fingertip_pos = fingertip_transform[:3, 3]
|
||||
state_vector.extend(fingertip_pos)
|
||||
|
||||
# Also return camera extrinsics for optional coordinate frame transformations
|
||||
return np.array(state_vector, dtype=np.float32), transforms_group["camera"][frame_idx]
|
||||
|
||||
def get_state_and_action_from_egodex_annotations(demo):
|
||||
"""
|
||||
Convert EgoDex demo annotations into states and actions.
|
||||
|
||||
The "action" is the state at time t+1 (next-pose prediction).
|
||||
"""
|
||||
transforms_group = demo["transforms"]
|
||||
total_frames = list(transforms_group.values())[0].shape[0]
|
||||
|
||||
states_list, extrinsics_list = [], []
|
||||
for frame_idx in range(total_frames):
|
||||
state_vector, extrinsics = _get_state_for_single_frame(transforms_group, frame_idx)
|
||||
states_list.append(state_vector)
|
||||
extrinsics_list.append(extrinsics.flatten()) # Flatten 4x4 to 16D
|
||||
|
||||
state = np.array(states_list, dtype=np.float32)
|
||||
extrinsics = np.array(extrinsics_list, dtype=np.float32)
|
||||
|
||||
# Shift by 1 timestep to convert state to action
|
||||
action = np.roll(state, -1, axis=0)
|
||||
|
||||
return state, action, extrinsics
|
||||
|
||||
def process_demo(hdf5_file_path, video_path):
|
||||
"""Process a single EgoDex demo and return frames for LeRobot."""
|
||||
video = mpy.read_video(str(video_path))
|
||||
video = np.asarray(video)
|
||||
num_frames = video.shape[0]
|
||||
frames = []
|
||||
|
||||
with h5py.File(hdf5_file_path, "r") as demo:
|
||||
state, action, extrinsics = get_state_and_action_from_egodex_annotations(demo)
|
||||
|
||||
# Get natural language task description
|
||||
if demo.attrs.get("llm_type") == "reversible":
|
||||
direction = demo.attrs.get("which_llm_description", "1")
|
||||
lang_instruction = demo.attrs.get(
|
||||
"llm_description" if direction == "1" else "llm_description2",
|
||||
"manipulation task",
|
||||
)
|
||||
else:
|
||||
lang_instruction = demo.attrs.get("llm_description", "manipulation task")
|
||||
|
||||
for step_idx in range(num_frames):
|
||||
# Resize image to default dimensions
|
||||
image_resized = cv2.resize(
|
||||
video[step_idx],
|
||||
(DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT),
|
||||
interpolation=cv2.INTER_AREA,
|
||||
)
|
||||
frame = {
|
||||
"task": lang_instruction,
|
||||
"observation.image": image_resized,
|
||||
"observation.state": state[step_idx],
|
||||
"observation.extrinsics": extrinsics[step_idx],
|
||||
"action": action[step_idx],
|
||||
}
|
||||
frames.append(frame)
|
||||
|
||||
return frames
|
||||
|
||||
init_logging()
|
||||
|
||||
# Define EgoDex features
|
||||
EGODEX_FEATURES = {
|
||||
"observation.image": {
|
||||
"dtype": "video",
|
||||
"shape": (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [
|
||||
# Left hand wrist position (3)
|
||||
"left_wrist_x",
|
||||
"left_wrist_y",
|
||||
"left_wrist_z",
|
||||
# Left hand wrist rotation 6D (6)
|
||||
"left_rot_0",
|
||||
"left_rot_1",
|
||||
"left_rot_2",
|
||||
"left_rot_3",
|
||||
"left_rot_4",
|
||||
"left_rot_5",
|
||||
# Left fingertips (15)
|
||||
"left_thumb_x",
|
||||
"left_thumb_y",
|
||||
"left_thumb_z",
|
||||
"left_index_x",
|
||||
"left_index_y",
|
||||
"left_index_z",
|
||||
"left_middle_x",
|
||||
"left_middle_y",
|
||||
"left_middle_z",
|
||||
"left_ring_x",
|
||||
"left_ring_y",
|
||||
"left_ring_z",
|
||||
"left_little_x",
|
||||
"left_little_y",
|
||||
"left_little_z",
|
||||
# Right hand wrist position (3)
|
||||
"right_wrist_x",
|
||||
"right_wrist_y",
|
||||
"right_wrist_z",
|
||||
# Right hand wrist rotation 6D (6)
|
||||
"right_rot_0",
|
||||
"right_rot_1",
|
||||
"right_rot_2",
|
||||
"right_rot_3",
|
||||
"right_rot_4",
|
||||
"right_rot_5",
|
||||
# Right fingertips (15)
|
||||
"right_thumb_x",
|
||||
"right_thumb_y",
|
||||
"right_thumb_z",
|
||||
"right_index_x",
|
||||
"right_index_y",
|
||||
"right_index_z",
|
||||
"right_middle_x",
|
||||
"right_middle_y",
|
||||
"right_middle_z",
|
||||
"right_ring_x",
|
||||
"right_ring_y",
|
||||
"right_ring_z",
|
||||
"right_little_x",
|
||||
"right_little_y",
|
||||
"right_little_z",
|
||||
],
|
||||
},
|
||||
"observation.extrinsics": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [f"extrinsic_{i}" for i in range(16)],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [f"action_{i}" for i in range(48)],
|
||||
},
|
||||
}
|
||||
|
||||
# 1. Discover all HDF5 files
|
||||
files = sorted(list(self.raw_dir.rglob("*.hdf5")))
|
||||
if not files:
|
||||
print(f"No HDF5 files found in {self.raw_dir}")
|
||||
return
|
||||
|
||||
# 2. Apply percentage filter
|
||||
if self.percentage < 100:
|
||||
num_files = max(1, int(len(files) * self.percentage / 100))
|
||||
files = files[:num_files]
|
||||
print(f"Processing {self.percentage}% of dataset: {num_files} files")
|
||||
|
||||
# 3. Assign files to this worker
|
||||
my_files = files[rank::world_size]
|
||||
if not my_files:
|
||||
print(f"Rank {rank} has no files to process.")
|
||||
return
|
||||
|
||||
print(f"Rank {rank} processing {len(my_files)} files out of {len(files)} total.")
|
||||
|
||||
# 4. Create a LeRobot dataset for this shard
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
shard_root = self.local_dir / shard_repo_id if self.local_dir else None
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=shard_repo_id,
|
||||
fps=30,
|
||||
robot_type="hand",
|
||||
features=EGODEX_FEATURES,
|
||||
root=shard_root,
|
||||
)
|
||||
|
||||
# 5. Process each file
|
||||
for input_h5 in my_files:
|
||||
try:
|
||||
# Derive corresponding video path
|
||||
video_path = input_h5.with_suffix(".mp4")
|
||||
if not video_path.exists():
|
||||
print(f"Warning: Video file not found for {input_h5}, skipping.")
|
||||
continue
|
||||
|
||||
# Process demo and add frames
|
||||
frames = process_demo(input_h5, video_path)
|
||||
for frame in frames:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up to avoid OOM
|
||||
del frames
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {input_h5}: {e}")
|
||||
continue
|
||||
|
||||
# 6. Finalize the dataset
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir,
|
||||
repo_id,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
percentage,
|
||||
slurm=True,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortEgoDexShards(raw_dir, repo_id, local_dir, percentage),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "10:00:00", # EgoDex is large, allow more time
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": workers,
|
||||
"workers": 1, # Run locally sequentially for debugging
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert EgoDex dataset to LeRobot format using SLURM."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input EgoDex data (HDF5 + MP4 files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (e.g., user/egodex-lerobot).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_egodex",
|
||||
help="Job name used in SLURM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch sequentially (useful for debugging).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of SLURM workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentage",
|
||||
type=float,
|
||||
default=100.0,
|
||||
help="Percentage of dataset to process (e.g., 1.0 for 1%%). Useful for testing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to save the LeRobot dataset. Defaults to data/local_datasets.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -455,18 +455,7 @@ def demo_cli(cfg: RTCDemoConfig):
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_pretrained_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
|
After Width: | Height: | Size: 2.9 MiB |
|
After Width: | Height: | Size: 185 KiB |
|
After Width: | Height: | Size: 464 KiB |
|
After Width: | Height: | Size: 72 KiB |
|
After Width: | Height: | Size: 219 KiB |
|
After Width: | Height: | Size: 199 KiB |
|
Before Width: | Height: | Size: 160 KiB After Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 774 KiB |
|
Before Width: | Height: | Size: 2.3 MiB |
|
Before Width: | Height: | Size: 481 KiB |
|
After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 151 KiB |
|
After Width: | Height: | Size: 130 KiB |
|
After Width: | Height: | Size: 407 KiB |
@@ -96,7 +96,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
@@ -120,13 +120,6 @@ intelrealsense = [
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
@@ -140,16 +133,13 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
@@ -168,11 +158,9 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
# "lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -183,8 +171,6 @@ all = [
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[peft]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -239,7 +225,6 @@ ignore = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["F401", "F403"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
@@ -276,7 +261,6 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -331,9 +315,9 @@ disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.optim.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
@@ -376,47 +360,10 @@ ignore_errors = false
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
||||
|
||||
@@ -54,7 +54,6 @@ from lerobot.robots import ( # noqa: F401
|
||||
bi_so100_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
|
||||
@@ -67,31 +67,3 @@ class EvalConfig:
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeftConfig:
|
||||
# PEFT offers many fine-tuning methods, layer adapters being the most common and currently also the most
|
||||
# effective methods so we'll focus on those in this high-level config interface.
|
||||
|
||||
# Either a string (module name suffix or 'all-linear'), a list of module name suffixes or a regular expression
|
||||
# describing module names to target with the configured PEFT method. Some policies have a default value for this
|
||||
# so that you don't *have* to choose which layers to adapt but it might still be worthwhile depending on your case.
|
||||
target_modules: list[str] | str | None = None
|
||||
|
||||
# Names/suffixes of modules to fully fine-tune and store alongside adapter weights. Useful for layers that are
|
||||
# not part of a pre-trained model (e.g., action state projections). Depending on the policy this defaults to layers
|
||||
# that are newly created in pre-trained policies. If you're fine-tuning an already trained policy you might want
|
||||
# to set this to `[]`. Corresponds to PEFT's `modules_to_save`.
|
||||
full_training_modules: list[str] | None = None
|
||||
|
||||
# The PEFT (adapter) method to apply to the policy. Needs to be a valid PEFT type.
|
||||
method_type: str = "LORA"
|
||||
|
||||
# Adapter initialization method. Look at the specific PEFT adapter documentation for defaults.
|
||||
init_type: str | None = None
|
||||
|
||||
# We expect that all PEFT adapters are in some way doing rank-decomposition therefore this parameter specifies
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
@@ -55,18 +55,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
# `input_features` can be set to None/null in order to infer those values from the dataset.
|
||||
input_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
# Whether the policy employed PEFT for training.
|
||||
use_peft: bool = False
|
||||
|
||||
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -129,8 +125,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
if self.input_features is None:
|
||||
return None
|
||||
for ft_name, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||
return ft
|
||||
@@ -138,8 +132,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def env_state_feature(self) -> PolicyFeature | None:
|
||||
if self.input_features is None:
|
||||
return None
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.ENV:
|
||||
return ft
|
||||
@@ -147,14 +139,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
if self.input_features is None:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if self.output_features is None:
|
||||
return None
|
||||
for ft_name, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION and ft_name == ACTION:
|
||||
return ft
|
||||
|
||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -56,7 +56,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
tolerance_s: float = 1e-4
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
@@ -65,18 +64,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
@@ -140,14 +130,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -136,40 +136,21 @@ def update_meta_data(
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Get mappings for this video key
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
src_to_dst = video_idx.get("src_to_dst", {})
|
||||
|
||||
# Apply per-source-file mappings
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
df.at[idx, orig_chunk_col] = dst_chunk
|
||||
df.at[idx, orig_file_col] = dst_file
|
||||
|
||||
# Apply timestamp offset
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
elif src_to_offset:
|
||||
# Fallback: use same destination for all, but apply per-file offsets
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
@@ -287,12 +268,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
# Track destination (chunk, file) for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_dst"] = {}
|
||||
# Initialize dst_file_durations if not present
|
||||
# dst_file_durations tracks duration of each destination file
|
||||
if "dst_file_durations" not in videos_idx[key]:
|
||||
videos_idx[key]["dst_file_durations"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
@@ -307,13 +282,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -327,17 +298,14 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
# New destination file: offset is 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Track duration of this destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
@@ -345,11 +313,10 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
@@ -357,20 +324,16 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Track duration of this new destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
else:
|
||||
# Append to existing destination file
|
||||
# Offset is the current duration of this destination file
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
current_offset += src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
|
||||
@@ -98,7 +98,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
@@ -109,7 +108,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
fps: int = 30
|
||||
episode_length: int | None = None
|
||||
episode_length: int = 520
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||
@@ -272,7 +272,6 @@ class LiberoEnv(EnvConfig):
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
)
|
||||
control_mode: str = "relative" # or "absolute"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
|
||||
@@ -19,10 +19,8 @@ from typing import Any
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
@@ -41,7 +39,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
policy_cfg: PreTrainedConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -64,10 +61,6 @@ def make_env_pre_post_processors(
|
||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||
preprocessor_steps: list[ProcessorStep] = []
|
||||
postprocessor_steps: list[ProcessorStep] = []
|
||||
if isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
@@ -143,8 +136,6 @@ def make_env(
|
||||
init_states=cfg.init_states,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
@@ -80,7 +80,10 @@ def get_libero_dummy_action():
|
||||
return [0, 0, 0, 0, 0, 0, -1]
|
||||
|
||||
|
||||
OBS_STATE_DIM = 8
|
||||
ACTION_DIM = 7
|
||||
AGENT_POS_LOW = -1000.0
|
||||
AGENT_POS_HIGH = 1000.0
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
@@ -100,7 +103,6 @@ class LiberoEnv(gym.Env):
|
||||
task_suite: Any,
|
||||
task_id: int,
|
||||
task_suite_name: str,
|
||||
episode_length: int | None = None,
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
obs_type: str = "pixels",
|
||||
render_mode: str = "rgb_array",
|
||||
@@ -112,7 +114,6 @@ class LiberoEnv(gym.Env):
|
||||
episode_index: int = 0,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
@@ -140,19 +141,14 @@ class LiberoEnv(gym.Env):
|
||||
self.camera_name_mapping = camera_name_mapping
|
||||
self.num_steps_wait = num_steps_wait
|
||||
self.episode_index = episode_index
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
self._max_episode_steps = (
|
||||
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
if self.episode_length is None
|
||||
else self.episode_length
|
||||
)
|
||||
self.control_mode = control_mode
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
images = {}
|
||||
for cam in self.camera_name:
|
||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||
@@ -300,15 +296,6 @@ class LiberoEnv(gym.Env):
|
||||
# Increasing this value can improve determinism and reproducibility across resets.
|
||||
for _ in range(self.num_steps_wait):
|
||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||
|
||||
if self.control_mode == "absolute":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = False
|
||||
elif self.control_mode == "relative":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = True
|
||||
else:
|
||||
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
@@ -354,10 +341,8 @@ def _make_env_fns(
|
||||
task_id: int,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
episode_length: int | None,
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -369,9 +354,7 @@ def _make_env_fns(
|
||||
task_suite_name=suite_name,
|
||||
camera_name=camera_names,
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -391,8 +374,6 @@ def create_libero_envs(
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
control_mode: str = "relative",
|
||||
episode_length: int | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -434,14 +415,12 @@ def create_libero_envs(
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
suite=suite,
|
||||
episode_length=episode_length,
|
||||
suite_name=suite_name,
|
||||
task_id=tid,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
@@ -35,8 +35,6 @@ def make_optimizer_and_scheduler(
|
||||
tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
if cfg.optimizer is None:
|
||||
raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig")
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -30,17 +29,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
# Type alias for parameters accepted by optimizer build() methods.
|
||||
# This matches PyTorch's optimizer signature while also supporting:
|
||||
# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
|
||||
# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
|
||||
OptimizerParams = (
|
||||
Iterable[torch.nn.Parameter] # From model.parameters()
|
||||
| Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides
|
||||
| dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR
|
||||
| dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@@ -57,24 +45,13 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Args:
|
||||
params: Parameters to optimize. Accepts multiple formats depending on the optimizer:
|
||||
- Iterable[Parameter]: From model.parameters() - standard PyTorch usage
|
||||
- Iterable[dict]: List of param groups with 'params' key and optional
|
||||
'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
|
||||
- dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
|
||||
that apply differential learning rates by parameter name (e.g., XVLA)
|
||||
- dict[str, Iterable]: For multi-optimizer configs where each key maps to
|
||||
a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
@@ -90,7 +67,7 @@ class AdamConfig(OptimizerConfig):
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.Adam(params, **kwargs)
|
||||
@@ -105,7 +82,7 @@ class AdamWConfig(OptimizerConfig):
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
@@ -121,111 +98,12 @@ class SGDConfig(OptimizerConfig):
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("xvla-adamw")
|
||||
@dataclass
|
||||
class XVLAAdamWConfig(OptimizerConfig):
|
||||
"""Custom AdamW optimizer for XVLA with differential learning rates.
|
||||
|
||||
The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
|
||||
for stable optimization, while all other components use the full LR.
|
||||
|
||||
This LR ratio is crucial for achieving strong and stable finetuning performance.
|
||||
|
||||
Soft-prompts can optionally use a separate learning rate with warm-up support.
|
||||
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
|
||||
at a lower LR. Combine with a warmup scheduler for optimal results.
|
||||
|
||||
Note:
|
||||
Completely matching official reported performance may require an additional
|
||||
warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
|
||||
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
|
||||
|
||||
Parameter Groups:
|
||||
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
|
||||
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
|
||||
- Group 2 (other): All other parameters at full lr
|
||||
"""
|
||||
|
||||
lr: float = 1e-4
|
||||
betas: tuple[float, float] = (0.9, 0.99)
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
# Soft-prompt specific settings
|
||||
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
|
||||
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
"""
|
||||
Build AdamW optimizer with differential learning rates.
|
||||
|
||||
Args:
|
||||
params: Must be a dict[str, Parameter] from dict(model.named_parameters())
|
||||
or equivalent.
|
||||
|
||||
Returns:
|
||||
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
|
||||
|
||||
Raises:
|
||||
AssertionError: If params is not a dict (e.g., from model.parameters())
|
||||
"""
|
||||
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
|
||||
|
||||
vlm_group, soft_prompt_group, other_group = [], [], []
|
||||
for name, p in params.items():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if "vlm" in name.lower():
|
||||
vlm_group.append(p)
|
||||
elif "soft_prompt" in name.lower():
|
||||
soft_prompt_group.append(p)
|
||||
else:
|
||||
other_group.append(p)
|
||||
|
||||
# Determine soft-prompt LR
|
||||
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
|
||||
if self.soft_prompt_warmup_lr_scale is not None:
|
||||
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
|
||||
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
|
||||
|
||||
param_groups: list[dict[str, Any]] = [
|
||||
{
|
||||
"params": vlm_group,
|
||||
"lr": self.lr * 0.1,
|
||||
"weight_decay": self.weight_decay * 0.1,
|
||||
"name": "vlm",
|
||||
},
|
||||
{
|
||||
"params": soft_prompt_group,
|
||||
"lr": soft_prompt_lr,
|
||||
"weight_decay": self.weight_decay,
|
||||
"name": "soft_prompts",
|
||||
},
|
||||
{
|
||||
"params": other_group,
|
||||
"lr": self.lr,
|
||||
"weight_decay": self.weight_decay,
|
||||
"name": "other",
|
||||
},
|
||||
]
|
||||
|
||||
# Filter out empty groups
|
||||
param_groups = [g for g in param_groups if len(g["params"]) > 0]
|
||||
|
||||
return torch.optim.AdamW(
|
||||
param_groups,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
@@ -245,25 +123,19 @@ class MultiAdamConfig(OptimizerConfig):
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
|
||||
to iterables of parameters. The keys should match the keys in optimizer_groups.
|
||||
Typically from policies that need separate optimizers (e.g., SAC with
|
||||
actor/critic/temperature).
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
|
||||
Raises:
|
||||
AssertionError: If params is not a dict
|
||||
"""
|
||||
assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."
|
||||
optimizers = {}
|
||||
|
||||
for name, group_params in params.items():
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
@@ -275,7 +147,7 @@ class MultiAdamConfig(OptimizerConfig):
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs)
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
num_warmup_steps: int | None
|
||||
num_warmup_steps: int
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
||||
@@ -21,8 +21,6 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
@@ -30,10 +28,7 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"SmolVLAConfig",
|
||||
"SARMConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
"XVLAConfig",
|
||||
"WallXConfig",
|
||||
]
|
||||
|
||||
@@ -50,7 +50,6 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -56,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
@@ -37,13 +36,10 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
@@ -63,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -107,27 +103,12 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
return GrootPolicy
|
||||
elif name == "xvla":
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
|
||||
return XVLAPolicy
|
||||
elif name == "wall_x":
|
||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Policy type '{name}' is not available.") from e
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
@@ -140,7 +121,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
"reward_classifier".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -169,16 +150,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
return config_cls(**kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
@@ -349,14 +322,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -365,32 +330,8 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
make_xvla_pre_post_processors,
|
||||
)
|
||||
|
||||
processors = make_xvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, WallXConfig):
|
||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
|
||||
|
||||
processors = make_wall_x_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
return processors
|
||||
|
||||
@@ -459,52 +400,17 @@ def make_policy(
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.output_features:
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
if ds_meta is not None and hasattr(ds_meta, "stats"):
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
|
||||
if ds_meta is not None:
|
||||
kwargs["dataset_meta"] = ds_meta
|
||||
|
||||
if not cfg.pretrained_path and cfg.use_peft:
|
||||
raise ValueError(
|
||||
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
|
||||
"the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that."
|
||||
)
|
||||
|
||||
if cfg.pretrained_path and not cfg.use_peft:
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
# of the adapter and the adapter's config contains the path to the base policy. So we need the
|
||||
# adapter config first, then load the correct policy and then apply PEFT.
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
if not kwargs["pretrained_name_or_path"]:
|
||||
# This means that there's a bug or we trained a policy from scratch using PEFT.
|
||||
# It is more likely that this is a bug so we'll raise an error.
|
||||
raise ValueError(
|
||||
"No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
|
||||
"the adapter was trained."
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
@@ -519,65 +425,3 @@ def make_policy(
|
||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
|
||||
"""Get policy class from its registered name using dynamic imports.
|
||||
|
||||
This is used as a helper function to import policies from 3rd party lerobot plugins.
|
||||
|
||||
Args:
|
||||
name: The name of the policy.
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
"""
|
||||
if name not in PreTrainedConfig.get_known_choices():
|
||||
raise ValueError(
|
||||
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
|
||||
)
|
||||
|
||||
config_cls = PreTrainedConfig.get_choice_class(name)
|
||||
config_cls_name = config_cls.__name__
|
||||
|
||||
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
|
||||
if model_name == config_cls_name:
|
||||
raise ValueError(
|
||||
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
|
||||
f"Make sure it ends with 'Config'!"
|
||||
)
|
||||
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
|
||||
module_path = config_cls.__module__.replace(
|
||||
"configuration_", "modeling_"
|
||||
) # e.g., configuration_diffusion -> modeling_diffusion
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
policy_cls = getattr(module, cls_name)
|
||||
return policy_cls
|
||||
|
||||
|
||||
def _make_processors_from_policy_config(
|
||||
config: PreTrainedConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Create pre- and post-processors from a policy configuration using dynamic imports.
|
||||
|
||||
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
|
||||
|
||||
Args:
|
||||
config: The policy configuration object.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
"""
|
||||
|
||||
policy_type = config.type
|
||||
function_name = f"make_{policy_type}_pre_post_processors"
|
||||
module_path = config.__class__.__module__.replace(
|
||||
"configuration_", "processor_"
|
||||
) # e.g., configuration_diffusion -> processor_diffusion
|
||||
logging.debug(
|
||||
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
function = getattr(module, function_name)
|
||||
return function(config, dataset_stats=dataset_stats)
|
||||
|
||||
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
name = "groot"
|
||||
config_class = GrootConfig
|
||||
|
||||
def __init__(self, config: GrootConfig, **kwargs):
|
||||
def __init__(self, config: GrootConfig):
|
||||
"""Initialize Groot policy wrapper."""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
@@ -23,8 +23,6 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
@@ -53,10 +51,7 @@ class PI0Config(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
@@ -41,7 +41,7 @@ else:
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
@@ -337,7 +337,6 @@ class PaliGemmaWithExpertModel(
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
@@ -357,7 +356,6 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
@@ -521,17 +519,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
|
||||
if config.image_resolution[0] != config.image_resolution[1]:
|
||||
raise ValueError(
|
||||
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||
)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, False],
|
||||
precision=config.dtype,
|
||||
image_size=config.image_resolution[0],
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||
@@ -820,13 +812,16 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
@@ -851,11 +846,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
@@ -907,7 +906,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -1236,15 +1234,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
@@ -1258,17 +1250,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
}
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -22,8 +22,6 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
@@ -52,10 +50,7 @@ class PI05Config(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
@@ -41,7 +41,7 @@ else:
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
@@ -336,7 +336,6 @@ class PaliGemmaWithExpertModel(
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
@@ -356,7 +355,6 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
@@ -520,17 +518,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
|
||||
if config.image_resolution[0] != config.image_resolution[1]:
|
||||
raise ValueError(
|
||||
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||
)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True],
|
||||
precision=config.dtype,
|
||||
image_size=config.image_resolution[0],
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||
@@ -795,13 +787,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
@@ -825,11 +820,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
@@ -880,7 +879,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI05Config,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -1210,15 +1208,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training."""
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
@@ -1232,17 +1224,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
}
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -206,7 +206,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
def push_model_to_hub(
|
||||
self,
|
||||
cfg: TrainPipelineConfig,
|
||||
peft_model=None,
|
||||
):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
@@ -217,14 +216,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
|
||||
if peft_model is not None:
|
||||
# Since PEFT just forwards calls to `push_model_to_hub`, `self` is not the PeftModel wrapper
|
||||
# but the actual policy which is why we need the PEFT model passed to us to save the adapter.
|
||||
# That also means that we need to store the policy config ourselves since PEFT can't.
|
||||
peft_model.save_pretrained(saved_path)
|
||||
self.config.save_pretrained(saved_path)
|
||||
else:
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
## Paper
|
||||
|
||||
https://arxiv.org/abs/2509.25358
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -1,870 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting.
|
||||
|
||||
This script processes all frames in a dataset with SARM to compute progress values [0, 1].
|
||||
The results are saved as a parquet file that can be loaded during training for RA-BC weighting.
|
||||
|
||||
Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only
|
||||
need ~num_frames/30 queries instead of one per frame (~30x speedup).
|
||||
|
||||
Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.gridspec as gridspec
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.policies.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read reward_model_path from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def load_sarm_resources(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
device: str = "cuda",
|
||||
) -> tuple[LeRobotDataset, SARMRewardModel, any]:
|
||||
"""
|
||||
Load SARM model, dataset, and preprocessor.
|
||||
|
||||
Returns:
|
||||
Tuple of (dataset, reward_model, preprocessor)
|
||||
"""
|
||||
logging.info(f"Loading model: {reward_model_path}")
|
||||
reward_model = SARMRewardModel.from_pretrained(reward_model_path)
|
||||
reward_model.config.device = device
|
||||
reward_model.to(device).eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
delta_indices = reward_model.config.observation_delta_indices
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
fps = temp_dataset.fps
|
||||
|
||||
delta_timestamps = {
|
||||
image_key: [idx / fps for idx in delta_indices],
|
||||
state_key: [idx / fps for idx in delta_indices],
|
||||
}
|
||||
dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
preprocess, _ = make_sarm_pre_post_processors(
|
||||
config=reward_model.config,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_meta=dataset.meta,
|
||||
)
|
||||
|
||||
return dataset, reward_model, preprocess
|
||||
|
||||
|
||||
def to_numpy_image(img) -> np.ndarray:
|
||||
"""Convert image tensor to numpy uint8 (H, W, C)."""
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
if img.ndim == 4:
|
||||
# Take center frame for bidirectional sampling
|
||||
img = img[img.shape[0] // 2]
|
||||
if img.shape[0] in [1, 3]:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
if img.dtype != np.uint8:
|
||||
# Handle normalized images (may have negative values or values > 1)
|
||||
img = img.astype(np.float32)
|
||||
img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1]
|
||||
img = (img * 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def visualize_episode(
|
||||
frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
|
||||
):
|
||||
"""Create visualization with progress plot, stage probabilities, and sample frames.
|
||||
|
||||
Same as sarm_inference_visualization.py
|
||||
"""
|
||||
num_stages = stage_preds.shape[1]
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
|
||||
frame_indices = np.arange(len(progress_preds))
|
||||
|
||||
fig = plt.figure(figsize=(14, 12))
|
||||
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
|
||||
ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])
|
||||
|
||||
# Progress plot
|
||||
ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
|
||||
ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
|
||||
if gt_progress is not None:
|
||||
ax_progress.plot(
|
||||
frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
|
||||
)
|
||||
ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
|
||||
ax_progress.set_ylabel("Progress")
|
||||
ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
|
||||
ax_progress.set_ylim(-0.05, 1.1)
|
||||
ax_progress.legend(loc="upper left")
|
||||
ax_progress.grid(True, alpha=0.3)
|
||||
|
||||
# Stage predictions
|
||||
ax_stages.stackplot(
|
||||
frame_indices,
|
||||
*[stage_preds[:, i] for i in range(num_stages)],
|
||||
colors=colors,
|
||||
alpha=0.8,
|
||||
labels=stage_labels,
|
||||
)
|
||||
if gt_stages is not None:
|
||||
for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
|
||||
ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
|
||||
ax_stages.set_xlabel("Frame")
|
||||
ax_stages.set_ylabel("Stage Probability")
|
||||
ax_stages.set_ylim(0, 1)
|
||||
ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
|
||||
ax_stages.grid(True, alpha=0.3)
|
||||
|
||||
# Sample frames
|
||||
ax_frames.axis("off")
|
||||
num_sample = 8
|
||||
sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
|
||||
h, w = frames[0].shape[:2]
|
||||
combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
|
||||
for i, idx in enumerate(sample_indices):
|
||||
frame = frames[idx]
|
||||
if frame.shape[-1] == 1:
|
||||
frame = np.repeat(frame, 3, axis=-1)
|
||||
combined[:, i * w : (i + 1) * w] = frame
|
||||
stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
|
||||
ax_frames.text(
|
||||
i * w + w / 2,
|
||||
-10,
|
||||
f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=7,
|
||||
)
|
||||
ax_frames.imshow(combined)
|
||||
ax_frames.set_title("Sample Frames", pad=20)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def visualize_sarm_predictions(
|
||||
dataset: LeRobotDataset,
|
||||
reward_model: SARMRewardModel,
|
||||
preprocess,
|
||||
episode_indices: list[int],
|
||||
head_mode: str,
|
||||
output_dir: Path,
|
||||
num_display_frames: int = 5,
|
||||
stride: int = 1,
|
||||
):
|
||||
"""
|
||||
Visualize SARM predictions for multiple episodes.
|
||||
|
||||
Computes predictions for every frame by default. With stride > 1, computes predictions
|
||||
every N frames and interpolates (progress + stage probabilities) for visualization.
|
||||
|
||||
Args:
|
||||
dataset: LeRobotDataset with delta_timestamps configured
|
||||
reward_model: Loaded SARM model
|
||||
preprocess: Preprocessor from make_sarm_pre_post_processors
|
||||
episode_indices: List of episode indices to visualize
|
||||
head_mode: "sparse", "dense", or "both"
|
||||
output_dir: Directory to save visualizations
|
||||
num_display_frames: Number of frames to display in thumbnail strip (default: 5)
|
||||
stride: Compute predictions every N frames, interpolate the rest (default: 1)
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
device = reward_model.device
|
||||
|
||||
# Center frame index for bidirectional sampling
|
||||
target_idx = reward_model.config.n_obs_steps // 2
|
||||
|
||||
# Determine which heads to visualize
|
||||
schemes_to_viz = []
|
||||
if head_mode in ("sparse", "both") or not dual_mode:
|
||||
schemes_to_viz.append("sparse")
|
||||
if head_mode in ("dense", "both") and dual_mode:
|
||||
schemes_to_viz.append("dense")
|
||||
|
||||
# Set preprocessor to eval mode to disable augmentations
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
for episode_idx in episode_indices:
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
num_frames = ep_end - ep_start
|
||||
|
||||
# Select frames for display thumbnails (evenly sampled from begin to end)
|
||||
display_indices = set(
|
||||
[
|
||||
ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
|
||||
for i in range(num_display_frames)
|
||||
]
|
||||
if num_frames >= num_display_frames
|
||||
else list(range(ep_start, ep_end))
|
||||
)
|
||||
viz_frames = {}
|
||||
|
||||
# Load display frames up-front (stride mode might skip them otherwise).
|
||||
for frame_idx in display_indices:
|
||||
sample = dataset[frame_idx]
|
||||
viz_frames[frame_idx] = to_numpy_image(sample[image_key])
|
||||
|
||||
# Initialize storage for each scheme
|
||||
scheme_data = {}
|
||||
for scheme in schemes_to_viz:
|
||||
num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
|
||||
scheme_data[scheme] = {
|
||||
"viz_progress": np.full(num_frames, np.nan),
|
||||
"viz_stages": np.full((num_frames, num_stages), np.nan),
|
||||
"viz_gt_progress": np.full(num_frames, np.nan),
|
||||
"viz_gt_stages": np.full(num_frames, np.nan),
|
||||
"target_key": f"{scheme}_targets",
|
||||
"num_stages": num_stages,
|
||||
"temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
|
||||
"subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
|
||||
}
|
||||
|
||||
if stride > 1:
|
||||
logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating")
|
||||
|
||||
# Process frames one at a time to avoid memory buildup
|
||||
frame_indices = list(range(ep_start, ep_end, stride))
|
||||
if (ep_end - 1) not in frame_indices:
|
||||
frame_indices.append(ep_end - 1)
|
||||
frame_indices = sorted(set(frame_indices))
|
||||
|
||||
for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
|
||||
local_idx = frame_idx - ep_start
|
||||
sample = dataset[frame_idx]
|
||||
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": frame_idx,
|
||||
"episode_index": episode_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
video_features = processed["video_features"].to(device)
|
||||
text_features = processed["text_features"].to(device)
|
||||
state_features = processed.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
|
||||
# Ground truth
|
||||
# In stride visualization mode, ground-truth plots can be misleading
|
||||
# (only sparse points are available), so we skip GT.
|
||||
if stride == 1 and sd["target_key"] in processed:
|
||||
gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
|
||||
sd["viz_gt_stages"][local_idx] = int(gt_target)
|
||||
sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
|
||||
gt_target,
|
||||
num_stages=sd["num_stages"],
|
||||
temporal_proportions=sd["temporal_props"],
|
||||
subtask_names=sd["subtask_names"],
|
||||
)
|
||||
|
||||
# Predictions
|
||||
reward, stage_probs = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
return_stages=True,
|
||||
head_mode=scheme,
|
||||
)
|
||||
|
||||
# Handle both tensor and numpy outputs
|
||||
if isinstance(reward, torch.Tensor):
|
||||
reward = reward.cpu().numpy()
|
||||
stage_probs = stage_probs.cpu().numpy()
|
||||
|
||||
if reward.ndim == 2:
|
||||
sd["viz_progress"][local_idx] = reward[0, target_idx]
|
||||
sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
|
||||
else:
|
||||
sd["viz_progress"][local_idx] = reward[target_idx]
|
||||
sd["viz_stages"][local_idx] = stage_probs[target_idx, :]
|
||||
|
||||
# Clear GPU memory after each frame
|
||||
del processed, video_features, text_features
|
||||
if state_features is not None:
|
||||
del state_features
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Interpolate predictions back to per-frame arrays for smooth visualization.
|
||||
if stride > 1:
|
||||
all_local = np.arange(num_frames)
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
|
||||
valid = np.isfinite(sd["viz_progress"])
|
||||
valid_idx = np.where(valid)[0]
|
||||
if valid_idx.size >= 1:
|
||||
sd["viz_progress"] = interpolate_progress(
|
||||
valid_idx, sd["viz_progress"][valid_idx], all_local
|
||||
)
|
||||
|
||||
stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
|
||||
for s in range(sd["num_stages"]):
|
||||
stage_interp[:, s] = interpolate_progress(
|
||||
valid_idx, sd["viz_stages"][valid_idx, s], all_local
|
||||
)
|
||||
|
||||
stage_interp = np.clip(stage_interp, 0.0, 1.0)
|
||||
row_sums = stage_interp.sum(axis=1, keepdims=True)
|
||||
nz = row_sums.squeeze(-1) > 0
|
||||
stage_interp[nz] = stage_interp[nz] / row_sums[nz]
|
||||
sd["viz_stages"] = stage_interp
|
||||
else:
|
||||
# No valid points: keep NaNs/zeros; visualization will be empty.
|
||||
sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)
|
||||
|
||||
# Generate visualization for each head
|
||||
ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
|
||||
for scheme in schemes_to_viz:
|
||||
sd = scheme_data[scheme]
|
||||
stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
|
||||
viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"
|
||||
|
||||
visualize_episode(
|
||||
frames=np.array(ordered_viz_frames),
|
||||
progress_preds=sd["viz_progress"],
|
||||
stage_preds=sd["viz_stages"],
|
||||
title=f"{task} (Episode {episode_idx})",
|
||||
output_path=viz_path,
|
||||
stage_labels=stage_labels,
|
||||
gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
|
||||
gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
|
||||
)
|
||||
|
||||
# Clear memory between episodes
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logging.info(f"Visualizations saved to: {output_dir.absolute()}")
|
||||
|
||||
|
||||
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
|
||||
"""Generate all frame indices, ordered by offset for cache-friendly access.
|
||||
|
||||
Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
|
||||
This groups frames that share similar temporal windows together.
|
||||
"""
|
||||
num_frames = ep_end - ep_start
|
||||
indices = []
|
||||
for offset in range(frame_gap):
|
||||
for frame_rel in range(offset, num_frames, frame_gap):
|
||||
indices.append(ep_start + frame_rel)
|
||||
return indices
|
||||
|
||||
|
||||
def interpolate_progress(
|
||||
computed_indices: np.ndarray,
|
||||
computed_values: np.ndarray,
|
||||
all_indices: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
|
||||
computed_indices = np.asarray(computed_indices)
|
||||
computed_values = np.asarray(computed_values)
|
||||
all_indices = np.asarray(all_indices)
|
||||
|
||||
mask = np.isfinite(computed_values)
|
||||
if mask.sum() == 0:
|
||||
return np.full(all_indices.shape, np.nan, dtype=np.float32)
|
||||
if mask.sum() == 1:
|
||||
return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)
|
||||
|
||||
out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
def compute_sarm_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
output_path: str | None = None,
|
||||
head_mode: str = "sparse",
|
||||
device: str = "cuda",
|
||||
num_visualizations: int = 5,
|
||||
output_dir: str = "./sarm_viz",
|
||||
stride: int = 1,
|
||||
):
|
||||
"""
|
||||
Compute SARM progress predictions for all frames in a dataset.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HuggingFace dataset repo ID or local path
|
||||
reward_model_path: Path to pretrained SARM model
|
||||
output_path: Path to save results. If None, saves to dataset's cache directory
|
||||
head_mode: SARM head to use ("sparse", "dense", or "both")
|
||||
device: Device to use for inference
|
||||
num_visualizations: Number of episodes to visualize (0 to skip)
|
||||
output_dir: Directory to save visualizations
|
||||
stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
|
||||
"""
|
||||
dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)
|
||||
|
||||
# Set preprocessor to eval mode to disable augmentations
|
||||
if hasattr(preprocess, "eval"):
|
||||
preprocess.eval()
|
||||
for step in preprocess.steps:
|
||||
if hasattr(step, "eval"):
|
||||
step.eval()
|
||||
|
||||
image_key = reward_model.config.image_key
|
||||
state_key = reward_model.config.state_key
|
||||
frame_gap = reward_model.config.frame_gap
|
||||
num_episodes = dataset.num_episodes
|
||||
total_frames = dataset.num_frames
|
||||
logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")
|
||||
|
||||
# Determine which heads to compute
|
||||
dual_mode = reward_model.config.uses_dual_heads
|
||||
compute_sparse = head_mode in ("sparse", "both") or not dual_mode
|
||||
compute_dense = head_mode in ("dense", "both") and dual_mode
|
||||
|
||||
# Storage arrays
|
||||
all_indices = []
|
||||
all_episode_indices = []
|
||||
all_frame_indices = []
|
||||
all_progress_sparse = [] if compute_sparse else None
|
||||
all_progress_dense = [] if compute_dense else None
|
||||
|
||||
if stride > 1:
|
||||
logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")
|
||||
|
||||
# Process all episodes
|
||||
for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
|
||||
# Get task description
|
||||
task = dataset[ep_start].get("task", "perform the task")
|
||||
|
||||
# Generate frames to compute (with stride applied)
|
||||
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
|
||||
if stride > 1:
|
||||
# Only compute every stride-th frame (relative to episode start)
|
||||
compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
|
||||
# Always include last frame for better interpolation at episode end
|
||||
last_frame = ep_end - 1
|
||||
if last_frame not in compute_indices:
|
||||
compute_indices.append(last_frame)
|
||||
compute_indices = sorted(set(compute_indices))
|
||||
else:
|
||||
compute_indices = all_ep_indices
|
||||
|
||||
center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window
|
||||
|
||||
# Dictionary to collect results
|
||||
frame_results = {}
|
||||
|
||||
for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
|
||||
try:
|
||||
sample = dataset[query_idx]
|
||||
|
||||
batch = {
|
||||
image_key: sample[image_key],
|
||||
"task": task,
|
||||
"index": query_idx,
|
||||
"episode_index": episode_idx,
|
||||
}
|
||||
if state_key in sample:
|
||||
batch[state_key] = sample[state_key]
|
||||
|
||||
with torch.no_grad():
|
||||
processed = preprocess(batch)
|
||||
video_features = processed["video_features"].to(device)
|
||||
text_features = processed["text_features"].to(device)
|
||||
state_features = processed.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(device)
|
||||
lengths = processed.get("lengths")
|
||||
|
||||
sparse_val = np.nan
|
||||
dense_val = np.nan
|
||||
|
||||
# Compute sparse prediction for center frame
|
||||
if compute_sparse:
|
||||
sparse_progress = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="sparse",
|
||||
)
|
||||
sparse_val = float(
|
||||
sparse_progress[0, center_idx]
|
||||
if sparse_progress.ndim == 2
|
||||
else sparse_progress[center_idx]
|
||||
)
|
||||
|
||||
# Compute dense prediction for center frame
|
||||
if compute_dense:
|
||||
dense_progress = reward_model.calculate_rewards(
|
||||
text_embeddings=text_features,
|
||||
video_embeddings=video_features,
|
||||
state_features=state_features,
|
||||
lengths=lengths,
|
||||
return_all_frames=True,
|
||||
head_mode="dense",
|
||||
)
|
||||
dense_val = float(
|
||||
dense_progress[0, center_idx]
|
||||
if dense_progress.ndim == 2
|
||||
else dense_progress[center_idx]
|
||||
)
|
||||
|
||||
frame_results[query_idx] = (sparse_val, dense_val)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to process frame {query_idx}: {e}")
|
||||
|
||||
# Interpolate to get values for all frames
|
||||
computed_indices = np.array(sorted(frame_results.keys()))
|
||||
computed_sparse = (
|
||||
np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
|
||||
)
|
||||
computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None
|
||||
|
||||
# All frame indices for this episode
|
||||
all_frame_idx_array = np.arange(ep_start, ep_end)
|
||||
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
# Interpolate progress values
|
||||
if compute_sparse:
|
||||
interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
|
||||
if compute_dense:
|
||||
interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
|
||||
else:
|
||||
# No interpolation needed
|
||||
interp_sparse = computed_sparse if compute_sparse else None
|
||||
interp_dense = computed_dense if compute_dense else None
|
||||
|
||||
# Store results for all frames
|
||||
for i, frame_idx in enumerate(all_frame_idx_array):
|
||||
local_idx = frame_idx - ep_start
|
||||
all_indices.append(frame_idx)
|
||||
all_episode_indices.append(episode_idx)
|
||||
all_frame_indices.append(local_idx)
|
||||
if compute_sparse:
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
all_progress_sparse.append(float(interp_sparse[i]))
|
||||
elif frame_idx in frame_results:
|
||||
all_progress_sparse.append(frame_results[frame_idx][0])
|
||||
else:
|
||||
all_progress_sparse.append(np.nan)
|
||||
if compute_dense:
|
||||
if stride > 1 and len(computed_indices) > 1:
|
||||
all_progress_dense.append(float(interp_dense[i]))
|
||||
elif frame_idx in frame_results:
|
||||
all_progress_dense.append(frame_results[frame_idx][1])
|
||||
else:
|
||||
all_progress_dense.append(np.nan)
|
||||
|
||||
# Create output table
|
||||
table_data = {
|
||||
"index": np.array(all_indices, dtype=np.int64),
|
||||
"episode_index": np.array(all_episode_indices, dtype=np.int64),
|
||||
"frame_index": np.array(all_frame_indices, dtype=np.int64),
|
||||
}
|
||||
if compute_sparse:
|
||||
table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
|
||||
if compute_dense:
|
||||
table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)
|
||||
|
||||
# Sort by index
|
||||
df = pa.table(table_data).to_pandas()
|
||||
df = df.sort_values("index").reset_index(drop=True)
|
||||
final_table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
|
||||
# Add metadata with reward model path
|
||||
metadata = {b"reward_model_path": reward_model_path.encode()}
|
||||
final_table = final_table.replace_schema_metadata(metadata)
|
||||
|
||||
# Determine output path
|
||||
output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)
|
||||
|
||||
# Save
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(final_table, output_path)
|
||||
logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")
|
||||
|
||||
# Print statistics
|
||||
if "progress_sparse" in df.columns:
|
||||
valid = df["progress_sparse"].dropna()
|
||||
logging.info(
|
||||
f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
|
||||
f"min={valid.min():.4f}, max={valid.max():.4f}"
|
||||
)
|
||||
|
||||
if "progress_dense" in df.columns:
|
||||
valid = df["progress_dense"].dropna()
|
||||
logging.info(
|
||||
f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
|
||||
f"min={valid.min():.4f}, max={valid.max():.4f}"
|
||||
)
|
||||
|
||||
# Visualize episodes after processing
|
||||
if num_visualizations > 0:
|
||||
viz_episodes = list(range(min(num_visualizations, num_episodes)))
|
||||
logging.info(f"Generating {len(viz_episodes)} visualizations...")
|
||||
visualize_sarm_predictions(
|
||||
dataset=dataset,
|
||||
reward_model=reward_model,
|
||||
preprocess=preprocess,
|
||||
episode_indices=viz_episodes,
|
||||
head_mode=head_mode,
|
||||
output_dir=Path(output_dir),
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset repo ID or local path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output path for parquet. If not set, saves to dataset's cache directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--head-mode",
|
||||
type=str,
|
||||
default="sparse",
|
||||
choices=["sparse", "dense", "both"],
|
||||
help="SARM head to use (default: sparse)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device to use (default: cuda)",
|
||||
)
|
||||
# Visualization options
|
||||
parser.add_argument(
|
||||
"--visualize-only",
|
||||
action="store_true",
|
||||
help="Only visualize SARM predictions (no RA-BC computation)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-visualizations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of episodes to visualize (default: 5, set to 0 to skip)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./sarm_viz",
|
||||
help="Output directory for visualizations (default: ./sarm_viz)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload progress file to the dataset repo on HuggingFace Hub",
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
# Try to get reward_model_path from parquet metadata if not provided
|
||||
reward_model_path = args.reward_model_path
|
||||
if reward_model_path is None:
|
||||
# Load dataset to find parquet path
|
||||
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
|
||||
parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
|
||||
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
|
||||
if reward_model_path:
|
||||
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"--reward-model-path is required (no existing parquet with model metadata found)"
|
||||
)
|
||||
|
||||
# Handle visualize-only mode
|
||||
if args.visualize_only:
|
||||
dataset, reward_model, preprocess = load_sarm_resources(
|
||||
args.dataset_repo_id, reward_model_path, args.device
|
||||
)
|
||||
logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
|
||||
viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
|
||||
visualize_sarm_predictions(
|
||||
dataset=dataset,
|
||||
reward_model=reward_model,
|
||||
preprocess=preprocess,
|
||||
episode_indices=viz_episodes,
|
||||
head_mode=args.head_mode,
|
||||
output_dir=Path(args.output_dir),
|
||||
stride=args.stride,
|
||||
)
|
||||
print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
|
||||
return
|
||||
|
||||
# Full RABC computation (compute_sarm_progress loads model/dataset itself)
|
||||
output_path = compute_sarm_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=reward_model_path,
|
||||
output_path=args.output_path,
|
||||
head_mode=args.head_mode,
|
||||
device=args.device,
|
||||
num_visualizations=args.num_visualizations,
|
||||
output_dir=args.output_dir,
|
||||
stride=args.stride,
|
||||
)
|
||||
|
||||
print(f"\nSARM progress values saved to: {output_path}")
|
||||
|
||||
# Upload to Hub if requested
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = "sarm_progress.parquet"
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse # or dense")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse # or dense")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
|
||||
Paper: https://arxiv.org/abs/2509.25358
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
Supports three annotation modes:
|
||||
|
||||
1. single_stage (default): No annotations needed. Uses the episode's task description
|
||||
as a single stage covering the entire episode.
|
||||
|
||||
2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated
|
||||
single sparse "task" stage covering the full episode. The dense head learns detailed
|
||||
subtask progression while sparse provides overall task completion.
|
||||
|
||||
3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
|
||||
annotations from VLM. Both heads are trained on their respective annotations.
|
||||
|
||||
The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
|
||||
are loaded/generated during model initialization.
|
||||
"""
|
||||
|
||||
annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual"
|
||||
n_obs_steps: int = 8 # Number of observation history steps
|
||||
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
|
||||
max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation
|
||||
|
||||
# Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
|
||||
# During training with rewind: [obs_frames] + [rewind_frames]
|
||||
# During inference: [obs_frames] only
|
||||
|
||||
# Architecture params
|
||||
image_dim: int = 512
|
||||
text_dim: int = 512
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
max_state_dim: int = 32
|
||||
drop_n_last_frames: int = 1
|
||||
batch_size: int = 64
|
||||
clip_batch_size: int = 64
|
||||
dropout: float = 0.1
|
||||
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||
|
||||
rewind_probability: float = 0.8
|
||||
language_perturbation_probability: float = 0.2
|
||||
|
||||
# Sparse annotations (high-level stages)
|
||||
num_sparse_stages: int = 1
|
||||
sparse_subtask_names: list | None = None
|
||||
sparse_temporal_proportions: list | None = None
|
||||
|
||||
# Dense annotations (fine-grained stages)
|
||||
num_dense_stages: int | None = None
|
||||
dense_subtask_names: list | None = None
|
||||
dense_temporal_proportions: list | None = None
|
||||
|
||||
pretrained_model_path: str | None = None
|
||||
device: str | None = None
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
state_key: str = "observation.state"
|
||||
|
||||
# Populated by the processor (video_features, state_features, text_features)
|
||||
input_features: dict = field(default_factory=lambda: {})
|
||||
|
||||
# Output features (updated in __post_init__)
|
||||
output_features: dict = field(
|
||||
default_factory=lambda: {
|
||||
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
|
||||
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
|
||||
}
|
||||
)
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"LANGUAGE": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
|
||||
raise ValueError(
|
||||
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
|
||||
)
|
||||
|
||||
if self.annotation_mode == "single_stage":
|
||||
# Use task description as stage name, full episode as one stage
|
||||
self.num_sparse_stages = 1
|
||||
self.sparse_subtask_names = ["task"]
|
||||
self.sparse_temporal_proportions = [1.0]
|
||||
self.num_dense_stages = None
|
||||
self.dense_subtask_names = None
|
||||
self.dense_temporal_proportions = None
|
||||
|
||||
elif self.annotation_mode == "dense_only":
|
||||
self.num_sparse_stages = 1
|
||||
self.sparse_subtask_names = ["task"]
|
||||
self.sparse_temporal_proportions = [1.0]
|
||||
|
||||
self.input_features = {}
|
||||
self.output_features = {}
|
||||
|
||||
if self.image_key:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)
|
||||
|
||||
self.input_features[self.state_key] = PolicyFeature(
|
||||
shape=(self.max_state_dim,),
|
||||
type=FeatureType.STATE,
|
||||
)
|
||||
|
||||
# Update output features based on annotation_mode
|
||||
if self.annotation_mode in ["dense_only", "dual"]:
|
||||
self.output_features["sparse_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["sparse_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
dense_stages = self.num_dense_stages or self.num_sparse_stages
|
||||
self.output_features["dense_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["dense_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
else:
|
||||
self.output_features["sparse_stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["sparse_progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1), type=FeatureType.REWARD
|
||||
)
|
||||
|
||||
if self.max_rewind_steps >= self.n_obs_steps:
|
||||
raise ValueError(
|
||||
f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
|
||||
)
|
||||
if self.num_sparse_stages < 1:
|
||||
raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
|
||||
if (
|
||||
self.annotation_mode in ["dense_only", "dual"]
|
||||
and self.num_dense_stages is not None
|
||||
and self.num_dense_stages < 2
|
||||
):
|
||||
raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Get default optimizer configuration for SARM training."""
|
||||
return AdamWConfig(
|
||||
lr=5e-5,
|
||||
weight_decay=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Get default learning rate scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=5e-5,
|
||||
decay_lr=5e-6,
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def uses_dual_heads(self) -> bool:
|
||||
"""Whether the model uses dual heads (dense_only or dual annotation modes)."""
|
||||
return self.annotation_mode in ["dense_only", "dual"]
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Total number of frames in sequence.
|
||||
|
||||
For training: 1 + n_obs_steps + max_rewind_steps
|
||||
The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
|
||||
"""
|
||||
return 1 + self.n_obs_steps + self.max_rewind_steps
|
||||
|
||||
@property
|
||||
def max_length(self) -> int:
|
||||
return self.num_frames
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Bidirectional frame sampling centered on target frame.
|
||||
|
||||
Example with n_obs_steps=8, gap=30:
|
||||
Before: [-120, -90, -60, -30] (4 frames)
|
||||
Current: [0] (1 frame)
|
||||
After: [30, 60, 90, 120] (4 frames)
|
||||
Total: 9 frames
|
||||
"""
|
||||
half_steps = self.n_obs_steps // 2
|
||||
|
||||
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
|
||||
obs_deltas = past_deltas + [0] + future_deltas
|
||||
|
||||
# Rewind placeholders
|
||||
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
|
||||
|
||||
return obs_deltas + rewind_deltas
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
"""SARM is a reward model, not an action policy."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,793 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
|
||||
|
||||
Paper: https://arxiv.org/abs/2509.25358
|
||||
|
||||
- StageTransformer: Predicts stage classification (sparse/dense)
|
||||
- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
"""
|
||||
Stage classification transformer for SARM.
|
||||
|
||||
Predicts which stage/subtask the current frame belongs to.
|
||||
Supports both sparse (high-level) and dense (fine-grained) annotation schemes.
|
||||
|
||||
Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
|
||||
Output: stage logits (B, T, num_classes)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 512,
|
||||
vis_emb_dim: int = 512,
|
||||
text_emb_dim: int = 512,
|
||||
state_dim: int = 32,
|
||||
n_layers: int = 6,
|
||||
n_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
num_cameras: int = 1,
|
||||
num_classes_sparse: int = 4,
|
||||
num_classes_dense: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_cameras = num_cameras
|
||||
|
||||
# Projections
|
||||
self.lang_proj = nn.Linear(text_emb_dim, d_model)
|
||||
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
|
||||
self.state_proj = nn.Linear(state_dim, d_model)
|
||||
|
||||
# Encoder
|
||||
enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
|
||||
self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
|
||||
|
||||
# Positional bias on first visual frame
|
||||
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
|
||||
|
||||
# Shared fusion MLP
|
||||
# Fuses (num_cameras + 2) streams: cameras + lang + state
|
||||
fused_in = d_model * (num_cameras + 2)
|
||||
self.fusion_backbone = nn.Sequential(
|
||||
nn.LayerNorm(fused_in),
|
||||
nn.Linear(fused_in, d_model),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Scheme-specific heads
|
||||
self.heads = nn.ModuleDict(
|
||||
{
|
||||
"sparse": nn.Linear(d_model, num_classes_sparse),
|
||||
"dense": nn.Linear(d_model, num_classes_dense),
|
||||
}
|
||||
)
|
||||
|
||||
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
|
||||
"""
|
||||
Prepare language embeddings for fusion.
|
||||
|
||||
Accepts lang_emb of shape:
|
||||
- (B, text_emb_dim) -> broadcast across time
|
||||
- (B, T, text_emb_dim) -> per-timestep (dense annotation mode)
|
||||
|
||||
Returns: (B, 1, T, D)
|
||||
"""
|
||||
if lang_emb.dim() == 3:
|
||||
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
|
||||
lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
|
||||
else:
|
||||
# (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
|
||||
lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
|
||||
return lang_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,) - valid sequence lengths
|
||||
scheme: str = "sparse", # "sparse" or "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for stage classification.
|
||||
|
||||
Args:
|
||||
img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
|
||||
lang_emb: Language embeddings (B, E) or (B, T, E) for dense
|
||||
state: State features (B, T, state_dim)
|
||||
lengths: Valid sequence lengths (B,) for masking
|
||||
scheme: "sparse" or "dense" for head selection
|
||||
|
||||
Returns:
|
||||
Stage logits (B, T, num_classes)
|
||||
"""
|
||||
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
|
||||
|
||||
B, N, T, _ = img_seq.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
device = img_seq.device
|
||||
|
||||
# Project inputs
|
||||
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
|
||||
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
|
||||
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
|
||||
|
||||
# Concatenate streams
|
||||
# cameras + lang + state -> (B, N+2, T, D)
|
||||
x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)
|
||||
|
||||
# Add positional bias to first visual frame
|
||||
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
|
||||
|
||||
# Flatten to tokens for Transformer
|
||||
x_tokens = x.view(B, (N + 2) * T, D)
|
||||
L = x_tokens.size(1) # noqa: N806
|
||||
|
||||
# Create padding mask
|
||||
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T)
|
||||
mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)
|
||||
|
||||
# Create causal mask
|
||||
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
# Encode
|
||||
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
|
||||
|
||||
# Reshape and fuse
|
||||
h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
|
||||
fused = self.fusion_backbone(h) # (B, T, D)
|
||||
|
||||
# Scheme-specific logits
|
||||
logits = self.heads[scheme](fused) # (B, T, num_classes)
|
||||
return logits
|
||||
|
||||
|
||||
class SubtaskTransformer(nn.Module):
|
||||
"""
|
||||
Subtask progress regression transformer for SARM.
|
||||
|
||||
Predicts within-stage normalized progress (tau) conditioned on stage prior.
|
||||
The stage prior is a one-hot encoding passed from StageTransformer predictions.
|
||||
|
||||
Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
|
||||
Output: tau predictions (B, T) in [0, 1]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 512,
|
||||
vis_emb_dim: int = 512,
|
||||
text_emb_dim: int = 512,
|
||||
state_dim: int = 32,
|
||||
n_layers: int = 6,
|
||||
n_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
num_cameras: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_cameras = num_cameras
|
||||
|
||||
# Projections
|
||||
self.lang_proj = nn.Linear(text_emb_dim, d_model)
|
||||
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
|
||||
self.state_proj = nn.Linear(state_dim, d_model)
|
||||
|
||||
# Encoder
|
||||
enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
|
||||
self.transformer = nn.TransformerEncoder(enc, n_layers)
|
||||
|
||||
# Learned bias on first visual frame
|
||||
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
|
||||
|
||||
# Shared fusion backbone
|
||||
# Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
|
||||
fused_in = d_model * (num_cameras + 3)
|
||||
self.fusion_backbone = nn.Sequential(
|
||||
nn.LayerNorm(fused_in),
|
||||
nn.Linear(fused_in, d_model),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Scheme-specific regression heads
|
||||
self.heads = nn.ModuleDict(
|
||||
{
|
||||
"sparse": nn.Linear(d_model, 1),
|
||||
"dense": nn.Linear(d_model, 1),
|
||||
}
|
||||
)
|
||||
|
||||
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
|
||||
"""
|
||||
Prepare language embeddings for fusion.
|
||||
"""
|
||||
if lang_emb.dim() == 3:
|
||||
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
|
||||
return self.lang_proj(lang_emb).unsqueeze(1)
|
||||
else:
|
||||
# (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
|
||||
return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
|
||||
|
||||
def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Deterministic projection of one-hot stage to d_model by pad/truncate.
|
||||
|
||||
Args:
|
||||
stage_prior: One-hot stage embedding (B, 1, T, C)
|
||||
|
||||
Returns:
|
||||
Projected stage embedding (B, 1, T, d_model)
|
||||
"""
|
||||
B, one, T, C = stage_prior.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
if D == C:
|
||||
return stage_prior
|
||||
elif D > C:
|
||||
pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
|
||||
return torch.cat([stage_prior, pad], dim=-1)
|
||||
else:
|
||||
return stage_prior[..., :D]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,) - valid sequence lengths
|
||||
stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb
|
||||
scheme: str = "sparse", # "sparse" or "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for subtask progress regression.
|
||||
|
||||
Args:
|
||||
img_seq: Image embeddings (B, N, T, vis_emb_dim)
|
||||
lang_emb: Language embeddings (B, E) or (B, T, E)
|
||||
state: State features (B, T, state_dim)
|
||||
lengths: Valid sequence lengths (B,) for masking
|
||||
stage_prior: One-hot stage prior (B, 1, T, num_classes)
|
||||
scheme: "sparse" or "dense" for head selection
|
||||
|
||||
Returns:
|
||||
Tau predictions (B, T) in [0, 1] via sigmoid
|
||||
"""
|
||||
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
|
||||
|
||||
B, N, T, _ = img_seq.shape # noqa: N806
|
||||
D = self.d_model # noqa: N806
|
||||
device = img_seq.device
|
||||
|
||||
# Project inputs
|
||||
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
|
||||
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
|
||||
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
|
||||
stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D)
|
||||
|
||||
# Concatenate all streams
|
||||
# cameras + lang + state + stage_emb -> (B, N+3, T, D)
|
||||
x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)
|
||||
|
||||
# Add positional bias to first visual frame
|
||||
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
|
||||
|
||||
# Flatten to tokens
|
||||
x_tokens = x.view(B, (N + 3) * T, D)
|
||||
L = x_tokens.size(1) # noqa: N806
|
||||
|
||||
# Create padding mask
|
||||
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
|
||||
mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)
|
||||
|
||||
# Create causal mask
|
||||
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
||||
|
||||
# Encode
|
||||
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
|
||||
|
||||
# Reshape and fuse
|
||||
h = h.view(B, N + 3, T, D)
|
||||
h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
|
||||
fused = self.fusion_backbone(h_flat) # (B, T, D)
|
||||
|
||||
# Scheme-specific regression head -> sigmoid
|
||||
r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T)
|
||||
return r
|
||||
|
||||
|
||||
def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Generate one-hot stage embeddings from targets.
|
||||
|
||||
Args:
|
||||
num_classes: Number of stage classes
|
||||
targets: Target values (B, T) where integer part is stage index
|
||||
|
||||
Returns:
|
||||
One-hot stage embedding (B, 1, T, num_classes)
|
||||
"""
|
||||
# Integer part of float targets -> [0, C-1]
|
||||
idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T)
|
||||
C = num_classes # noqa: N806
|
||||
# Identity-lookup one-hot
|
||||
stage_onehot = torch.eye(C, device=targets.device)[idx] # (B, T, C)
|
||||
stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
return stage_onehot
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
Uses two separate transformer models:
|
||||
- StageTransformer: Classifies which stage/subtask
|
||||
- SubtaskTransformer: Predicts within-stage progress (tau)
|
||||
|
||||
Training uses 75%/25% GT/predicted stage conditioning (teacher forcing).
|
||||
"""
|
||||
|
||||
name = "sarm"
|
||||
config_class = SARMConfig
|
||||
|
||||
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(
|
||||
config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
# Load temporal proportions based on annotation_mode
|
||||
if config.annotation_mode == "single_stage":
|
||||
logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
|
||||
elif dataset_meta is not None:
|
||||
self._load_temporal_proportions(dataset_meta)
|
||||
|
||||
# Create two separate models
|
||||
self.stage_model = StageTransformer(
|
||||
d_model=config.hidden_dim,
|
||||
vis_emb_dim=config.image_dim,
|
||||
text_emb_dim=config.text_dim,
|
||||
state_dim=config.max_state_dim,
|
||||
n_layers=config.num_layers,
|
||||
n_heads=config.num_heads,
|
||||
dropout=config.dropout,
|
||||
num_cameras=1, # Single camera for now
|
||||
num_classes_sparse=config.num_sparse_stages,
|
||||
num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
|
||||
)
|
||||
|
||||
self.subtask_model = SubtaskTransformer(
|
||||
d_model=config.hidden_dim,
|
||||
vis_emb_dim=config.image_dim,
|
||||
text_emb_dim=config.text_dim,
|
||||
state_dim=config.max_state_dim,
|
||||
n_layers=config.num_layers,
|
||||
n_heads=config.num_heads,
|
||||
dropout=config.dropout,
|
||||
num_cameras=1,
|
||||
)
|
||||
|
||||
self.stage_model.to(self.device)
|
||||
self.subtask_model.to(self.device)
|
||||
|
||||
# GT/predicted stage ratio for teacher forcing
|
||||
self.gt_stage_ratio = 0.75
|
||||
|
||||
if config.uses_dual_heads:
|
||||
logging.info(
|
||||
f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
|
||||
f"{config.num_dense_stages} dense stages"
|
||||
)
|
||||
else:
|
||||
logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")
|
||||
|
||||
logging.info(f"SARM initialized on {self.device}")
|
||||
|
||||
def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
|
||||
"""Load temporal proportions from a JSON file (preserving order)."""
|
||||
if not path.exists():
|
||||
raise ValueError(
|
||||
f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
|
||||
f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
|
||||
)
|
||||
with open(path) as f:
|
||||
proportions_dict = json.load(f)
|
||||
names = list(proportions_dict.keys())
|
||||
logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
|
||||
logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
|
||||
return names, [proportions_dict[name] for name in names]
|
||||
|
||||
def _load_temporal_proportions(self, dataset_meta) -> None:
|
||||
"""Load temporal proportions based on annotation_mode."""
|
||||
meta_path = dataset_meta.root / "meta"
|
||||
|
||||
if self.config.annotation_mode == "dual":
|
||||
names, props = self._load_proportions_from_json(
|
||||
meta_path / "temporal_proportions_sparse.json", "sparse"
|
||||
)
|
||||
(
|
||||
self.config.num_sparse_stages,
|
||||
self.config.sparse_subtask_names,
|
||||
self.config.sparse_temporal_proportions,
|
||||
) = len(names), names, props
|
||||
|
||||
if self.config.annotation_mode in ["dense_only", "dual"]:
|
||||
names, props = self._load_proportions_from_json(
|
||||
meta_path / "temporal_proportions_dense.json", "dense"
|
||||
)
|
||||
(
|
||||
self.config.num_dense_stages,
|
||||
self.config.dense_subtask_names,
|
||||
self.config.dense_temporal_proportions,
|
||||
) = len(names), names, props
|
||||
if self.config.annotation_mode == "dense_only":
|
||||
logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
super().to(device)
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
self.stage_model.to(device)
|
||||
self.subtask_model.to(device)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
text_embeddings: np.ndarray | torch.Tensor,
|
||||
video_embeddings: np.ndarray | torch.Tensor,
|
||||
state_features: np.ndarray | torch.Tensor | None = None,
|
||||
lengths: np.ndarray | torch.Tensor | None = None,
|
||||
return_all_frames: bool = False,
|
||||
return_stages: bool = False,
|
||||
return_confidence: bool = False,
|
||||
head_mode: str | None = "sparse",
|
||||
frame_index: int | None = None,
|
||||
) -> np.ndarray | tuple:
|
||||
"""
|
||||
Calculate rewards for given text, video, and state representations.
|
||||
|
||||
This is the canonical method for SARM reward computation, used for:
|
||||
- Inference/visualization
|
||||
- RA-BC weight computation
|
||||
|
||||
Args:
|
||||
text_embeddings: Encoded text representations (batch_size, 512)
|
||||
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
|
||||
state_features: Joint state features (batch_size, num_frames, state_dim)
|
||||
lengths: Valid sequence lengths (batch_size,)
|
||||
return_all_frames: If True, return rewards for all frames
|
||||
return_stages: If True, also return stage predictions
|
||||
return_confidence: If True, also return stage confidence
|
||||
head_mode: Which head to use ("sparse" or "dense")
|
||||
frame_index: Index of the target frame to extract (default: n_obs_steps).
|
||||
|
||||
Returns:
|
||||
Rewards and optionally stage probs/confidence.
|
||||
"""
|
||||
if isinstance(text_embeddings, np.ndarray):
|
||||
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
|
||||
if isinstance(video_embeddings, np.ndarray):
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
if state_features is not None and isinstance(state_features, np.ndarray):
|
||||
state_features = torch.tensor(state_features, dtype=torch.float32)
|
||||
|
||||
# Handle single sample case
|
||||
if text_embeddings.dim() == 1:
|
||||
text_embeddings = text_embeddings.unsqueeze(0)
|
||||
video_embeddings = video_embeddings.unsqueeze(0)
|
||||
if state_features is not None:
|
||||
state_features = state_features.unsqueeze(0)
|
||||
single_sample = True
|
||||
else:
|
||||
single_sample = False
|
||||
|
||||
batch_size = video_embeddings.shape[0]
|
||||
seq_len = video_embeddings.shape[1]
|
||||
|
||||
scheme = head_mode
|
||||
|
||||
# Default lengths if not provided
|
||||
if lengths is None:
|
||||
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
|
||||
elif isinstance(lengths, np.ndarray):
|
||||
lengths = torch.tensor(lengths, dtype=torch.int32)
|
||||
|
||||
# Reshape video to (B, N, T, D) for multi-camera format
|
||||
# Currently single camera: (B, T, D) -> (B, 1, T, D)
|
||||
img_seq = video_embeddings.unsqueeze(1).to(self.device)
|
||||
lang_emb = text_embeddings.to(self.device)
|
||||
state = (
|
||||
state_features.to(self.device)
|
||||
if state_features is not None
|
||||
else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
|
||||
)
|
||||
lens = lengths.to(self.device)
|
||||
|
||||
# Pad state to max_state_dim
|
||||
state = pad_state_to_max_dim(state, self.config.max_state_dim)
|
||||
|
||||
# Get num_classes for this scheme
|
||||
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
|
||||
|
||||
# Run stage model
|
||||
stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
|
||||
stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes)
|
||||
stage_idx = stage_probs.argmax(dim=-1) # (B, T)
|
||||
stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T)
|
||||
|
||||
# Create one-hot stage prior
|
||||
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
|
||||
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
|
||||
# Run subtask model
|
||||
tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)
|
||||
|
||||
# Compute final reward: stage + tau
|
||||
raw_reward = stage_idx.float() + tau_pred # (B, T)
|
||||
|
||||
# Normalize to [0, 1] using temporal proportions for proper weighting
|
||||
if scheme == "sparse":
|
||||
normalized_reward = normalize_stage_tau(
|
||||
raw_reward,
|
||||
num_stages=num_classes,
|
||||
temporal_proportions=self.config.sparse_temporal_proportions,
|
||||
subtask_names=self.config.sparse_subtask_names,
|
||||
)
|
||||
else:
|
||||
normalized_reward = normalize_stage_tau(
|
||||
raw_reward,
|
||||
num_stages=num_classes,
|
||||
temporal_proportions=self.config.dense_temporal_proportions,
|
||||
subtask_names=self.config.dense_subtask_names,
|
||||
)
|
||||
|
||||
# Default frame index is n_obs_steps (last observation frame)
|
||||
if frame_index is None:
|
||||
frame_index = self.config.n_obs_steps
|
||||
|
||||
# Prepare outputs (batch mode or no smoothing)
|
||||
if return_all_frames:
|
||||
rewards = normalized_reward.cpu().numpy()
|
||||
else:
|
||||
rewards = normalized_reward[:, frame_index].cpu().numpy()
|
||||
|
||||
if single_sample:
|
||||
rewards = rewards[0] if not return_all_frames else rewards[0]
|
||||
|
||||
outputs = [rewards]
|
||||
if return_stages:
|
||||
probs = stage_probs.cpu().numpy()
|
||||
if single_sample:
|
||||
probs = probs[0]
|
||||
outputs.append(probs)
|
||||
if return_confidence:
|
||||
conf = stage_conf.cpu().numpy()
|
||||
if single_sample:
|
||||
conf = conf[0]
|
||||
outputs.append(conf)
|
||||
|
||||
return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode for both models."""
|
||||
super().train(mode)
|
||||
self.stage_model.train(mode)
|
||||
self.subtask_model.train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode for both models."""
|
||||
return self.train(False)
|
||||
|
||||
def parameters(self):
|
||||
"""Override to return trainable parameters from both models."""
|
||||
from itertools import chain
|
||||
|
||||
return chain(self.stage_model.parameters(), self.subtask_model.parameters())
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Override to return optimizer parameters from both models."""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _train_step(
|
||||
self,
|
||||
img_emb: torch.Tensor, # (B, N, T, D)
|
||||
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||
state: torch.Tensor, # (B, T, state_dim)
|
||||
lengths: torch.Tensor, # (B,)
|
||||
targets: torch.Tensor, # (B, T) - format: stage.tau
|
||||
scheme: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Single training step for one annotation scheme.
|
||||
|
||||
Implements 75%/25% GT/predicted stage conditioning.
|
||||
|
||||
Args:
|
||||
img_emb: Image embeddings (B, N, T, D)
|
||||
lang_emb: Language embeddings
|
||||
state: State features
|
||||
lengths: Valid sequence lengths
|
||||
targets: Target values where floor=stage, remainder=tau
|
||||
scheme: "sparse" or "dense"
|
||||
|
||||
Returns:
|
||||
Dict with stage_loss, subtask_loss, total_loss
|
||||
"""
|
||||
num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
|
||||
|
||||
# Ground truth: stage (integer) and tau (fractional)
|
||||
# Clamp stage indices to valid range [0, num_classes-1] to handle edge cases
|
||||
# where targets may exceed expected range (e.g., frames between subtasks)
|
||||
gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T)
|
||||
gt_tau = torch.remainder(targets, 1.0) # (B, T)
|
||||
|
||||
# Run stage model
|
||||
stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)
|
||||
|
||||
# 75%/25% GT/predicted stage conditioning
|
||||
if random.random() < self.gt_stage_ratio:
|
||||
# Mode 1: Use ground truth stage -> one-hot
|
||||
stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C)
|
||||
else:
|
||||
# Mode 2: Use predicted stage argmax -> one-hot
|
||||
stage_idx = stage_pred.argmax(dim=-1) # (B, T)
|
||||
stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
|
||||
stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
|
||||
|
||||
# Run subtask model with stage prior
|
||||
tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)
|
||||
|
||||
# Compute losses
|
||||
stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
|
||||
subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")
|
||||
|
||||
return {
|
||||
"stage_loss": stage_loss,
|
||||
"subtask_loss": subtask_loss,
|
||||
"total_loss": stage_loss + subtask_loss,
|
||||
}
|
||||
|
||||
def forward(self, batch):
|
||||
"""
|
||||
Forward pass for SARM reward model training.
|
||||
|
||||
Uses stage+tau target format where:
|
||||
- Integer part = stage index
|
||||
- Fractional part = within-stage progress (tau)
|
||||
|
||||
Training uses 75%/25% GT/predicted stage conditioning.
|
||||
|
||||
Args:
|
||||
batch: Dictionary with 'observation' containing:
|
||||
- 'video_features': (B, T, 512) pre-encoded video features
|
||||
- 'text_features': (B, 512) or (B, T, 512) text features
|
||||
- 'state_features': (B, T, state_dim) joint state features
|
||||
- 'lengths': (B,) valid sequence lengths
|
||||
- 'sparse_targets': (B, T) sparse targets (stage.tau format)
|
||||
- 'dense_targets': (B, T) dense targets (optional, for dual mode)
|
||||
|
||||
Returns:
|
||||
Tuple of (total_loss, output_dict with loss components)
|
||||
"""
|
||||
observation = batch.get("observation", batch)
|
||||
|
||||
# Extract features
|
||||
video_features = observation["video_features"].to(self.device)
|
||||
text_features = observation["text_features"].to(self.device)
|
||||
state_features = observation.get("state_features")
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(self.device)
|
||||
|
||||
batch_size = video_features.shape[0]
|
||||
seq_len = video_features.shape[1]
|
||||
|
||||
# Get lengths (default to full sequence)
|
||||
lengths = observation.get("lengths")
|
||||
if lengths is None:
|
||||
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
|
||||
else:
|
||||
lengths = lengths.to(self.device)
|
||||
|
||||
# Reshape video to (B, N, T, D) - single camera
|
||||
img_emb = video_features.unsqueeze(1)
|
||||
|
||||
# Pad state to max_state_dim
|
||||
if state_features is None:
|
||||
state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
|
||||
else:
|
||||
state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)
|
||||
|
||||
output_dict = {}
|
||||
total_loss = torch.tensor(0.0, device=self.device)
|
||||
|
||||
# Sparse training (always)
|
||||
sparse_targets = observation.get("sparse_targets")
|
||||
if sparse_targets is None:
|
||||
# Try legacy format
|
||||
sparse_targets = observation.get("targets")
|
||||
if sparse_targets is None:
|
||||
raise ValueError("sparse_targets (or targets) is required for SARM training")
|
||||
sparse_targets = sparse_targets.to(self.device)
|
||||
|
||||
sparse_result = self._train_step(
|
||||
img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
|
||||
)
|
||||
output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
|
||||
output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
|
||||
total_loss = total_loss + sparse_result["total_loss"]
|
||||
|
||||
# Dense training (if dual mode)
|
||||
if self.config.uses_dual_heads:
|
||||
dense_targets = observation.get("dense_targets")
|
||||
if dense_targets is not None:
|
||||
dense_targets = dense_targets.to(self.device)
|
||||
dense_result = self._train_step(
|
||||
img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
|
||||
)
|
||||
output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
|
||||
output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
|
||||
total_loss = total_loss + dense_result["total_loss"]
|
||||
|
||||
output_dict["total_loss"] = total_loss.item()
|
||||
return total_loss, output_dict
|
||||
|
||||
|
||||
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute cross-entropy loss for stage classification."""
|
||||
_, _, num_stages = stage_logits.shape
|
||||
stage_logits_flat = stage_logits.reshape(-1, num_stages)
|
||||
# Clamp target stage indices to valid range [0, num_stages-1]
|
||||
target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
|
||||
return F.cross_entropy(stage_logits_flat, target_stages_flat)
|
||||
@@ -1,518 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""SARM Processor for encoding images/text and generating stage+tau targets."""
|
||||
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from faker import Faker
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
from_tensor_to_numpy,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SARMConfig,
|
||||
image_key: str | None = None,
|
||||
dataset_meta=None,
|
||||
dataset_stats: dict | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.image_key = image_key or config.image_key
|
||||
self.dataset_meta = dataset_meta
|
||||
self.dataset_stats = dataset_stats
|
||||
self.annotation_mode = config.annotation_mode
|
||||
|
||||
# Helper to create temporal proportions dict
|
||||
def make_props_dict(names, props):
|
||||
return dict(zip(names, props, strict=True)) if names and props else None
|
||||
|
||||
# Sparse annotations (always needed)
|
||||
self.sparse_temporal_proportions = make_props_dict(
|
||||
config.sparse_subtask_names, config.sparse_temporal_proportions
|
||||
)
|
||||
self.sparse_subtask_names = config.sparse_subtask_names
|
||||
|
||||
# Dense annotations (only for dual mode)
|
||||
self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
|
||||
self.dense_temporal_proportions = (
|
||||
make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
|
||||
if config.uses_dual_heads
|
||||
else None
|
||||
)
|
||||
|
||||
self.device = torch.device(
|
||||
self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
|
||||
self.fake = Faker()
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
return ep_idx
|
||||
return 0
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
|
||||
"""Get episode indices for each frame index."""
|
||||
if episode_index is None:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
|
||||
|
||||
# If single episode but multiple frames, compute episode for each frame
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
return episode_indices
|
||||
|
||||
def _generate_perturbed_task(self) -> str:
|
||||
"""Generate a random perturbed task string for language perturbation."""
|
||||
num_words = random.randint(1, 5)
|
||||
verb = random.choice(self.verbs)
|
||||
phrase = " ".join([verb] + self.fake.words(nb=num_words))
|
||||
return phrase
|
||||
|
||||
def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
|
||||
"""Get global subtask names and temporal proportions for an annotation type."""
|
||||
if annotation_type == "dense":
|
||||
return self.dense_subtask_names, self.dense_temporal_proportions
|
||||
return self.sparse_subtask_names, self.sparse_temporal_proportions
|
||||
|
||||
def _load_episode_annotations(
|
||||
self,
|
||||
ep_idx: int,
|
||||
episodes_df: pd.DataFrame | None,
|
||||
annotation_type: str,
|
||||
global_names: list[str],
|
||||
) -> tuple[list | None, list | None, list | None]:
|
||||
"""Load subtask annotations for an episode from DataFrame."""
|
||||
# Single-stage mode: (linear progress 0→1)
|
||||
if episodes_df is None or len(global_names) == 1:
|
||||
return None, None, None
|
||||
|
||||
# Resolve column name with fallback
|
||||
def col(suffix):
|
||||
prefixed = f"{annotation_type}_{suffix}"
|
||||
return prefixed if prefixed in episodes_df.columns else suffix
|
||||
|
||||
col_names = col("subtask_names")
|
||||
if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
|
||||
return None, None, None
|
||||
|
||||
subtask_names = episodes_df.loc[ep_idx, col_names]
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
return None, None, None
|
||||
|
||||
return (
|
||||
subtask_names,
|
||||
episodes_df.loc[ep_idx, col("subtask_start_frames")],
|
||||
episodes_df.loc[ep_idx, col("subtask_end_frames")],
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Encode images, text, and normalize states in the transition.
|
||||
|
||||
Implements SARM training data preparation:
|
||||
- Applies language perturbation (20% probability)
|
||||
- Applies rewind augmentation (80% probability)
|
||||
- Generates stage+tau targets for all frames
|
||||
- Outputs lengths tensor for valid sequence masking
|
||||
"""
|
||||
new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
frame_index = comp_data.get("index")
|
||||
episode_index = comp_data.get("episode_index")
|
||||
|
||||
if frame_index is None:
|
||||
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
|
||||
if episode_index is None:
|
||||
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
|
||||
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
image = observation.get(self.image_key)
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
|
||||
# If 4D (T, C, H, W) from delta_timestamps, add batch dim
|
||||
# If 3D (C, H, W) single frame, add batch and time dims
|
||||
if image.ndim == 4:
|
||||
image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W)
|
||||
elif image.ndim == 3:
|
||||
image = image[np.newaxis, np.newaxis, ...] # (C, H, W) -> (1, 1, C, H, W)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
total_frames = image.shape[1] # Should be 13: 9 obs + 4 rewind placeholders
|
||||
n_obs_steps = self.config.n_obs_steps
|
||||
max_rewind_steps = self.config.max_rewind_steps
|
||||
n_obs_frames = 1 + n_obs_steps # 9 observation frames (including current)
|
||||
|
||||
# Rewind augmentation
|
||||
rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
|
||||
apply_rewind = self.training and random.random() < self.config.rewind_probability
|
||||
|
||||
if apply_rewind and self.dataset_meta is not None:
|
||||
for b_idx, (ep_idx, frame_idx) in enumerate(
|
||||
zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
|
||||
):
|
||||
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
|
||||
rewind_step, _ = apply_rewind_augmentation(
|
||||
frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
|
||||
)
|
||||
rewind_steps[b_idx] = rewind_step
|
||||
|
||||
# Compute valid lengths: n_obs_frames + rewind_steps
|
||||
lengths = n_obs_frames + rewind_steps # (B,)
|
||||
|
||||
# Apply rewind masking to images
|
||||
# For frames beyond valid length, we mask with zeros (or copy last valid frame)
|
||||
for b_idx in range(batch_size):
|
||||
valid_len = lengths[b_idx].item()
|
||||
if valid_len < total_frames:
|
||||
image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
|
||||
|
||||
# Encode images with CLIP
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation["video_features"] = video_features
|
||||
|
||||
state_key = self.config.state_key
|
||||
state_data = observation.get(state_key)
|
||||
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_tensor = state_data.float()
|
||||
else:
|
||||
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D)
|
||||
elif state_tensor.ndim == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D)
|
||||
|
||||
# Apply same rewind masking to state
|
||||
for b_idx in range(batch_size):
|
||||
valid_len = lengths[b_idx].item()
|
||||
if valid_len < state_tensor.shape[1]:
|
||||
state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
|
||||
|
||||
observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||
|
||||
task = comp_data.get("task")
|
||||
if isinstance(task, list):
|
||||
task = task[0] if task else ""
|
||||
|
||||
# Apply language perturbation during training (20% probability)
|
||||
# When perturbed, targets will be zeroed to train model to output low values for irrelevant text
|
||||
apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
|
||||
if apply_perturbation:
|
||||
task = self._generate_perturbed_task()
|
||||
|
||||
# Encode text with CLIP
|
||||
observation["text_features"] = self._encode_text_clip(task, batch_size)
|
||||
|
||||
# Store lengths for model
|
||||
observation["lengths"] = lengths
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
if apply_perturbation:
|
||||
# Zero targets when language is perturbed
|
||||
sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
else:
|
||||
sparse_targets = self._compute_batch_targets(
|
||||
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
|
||||
)
|
||||
observation["sparse_targets"] = sparse_targets
|
||||
|
||||
# Generate dense targets (for dual mode)
|
||||
if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
|
||||
if apply_perturbation:
|
||||
# Zero targets when language is perturbed
|
||||
dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
else:
|
||||
dense_targets = self._compute_batch_targets(
|
||||
frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
|
||||
)
|
||||
observation["dense_targets"] = dense_targets
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def _compute_batch_targets(
|
||||
self,
|
||||
frame_indices: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
lengths: torch.Tensor,
|
||||
rewind_steps: torch.Tensor,
|
||||
episodes_df: pd.DataFrame | None,
|
||||
annotation_type: str,
|
||||
) -> torch.Tensor:
|
||||
"""Compute stage+tau targets for a batch of samples."""
|
||||
batch_size = len(frame_indices)
|
||||
n_obs_steps = self.config.n_obs_steps
|
||||
max_rewind_steps = self.config.max_rewind_steps
|
||||
total_frames = 1 + n_obs_steps + max_rewind_steps
|
||||
frame_gap = self.config.frame_gap
|
||||
|
||||
global_names, temporal_props = self._get_annotation_config(annotation_type)
|
||||
targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
|
||||
|
||||
for b_idx in range(batch_size):
|
||||
ep_idx = int(episode_indices[b_idx])
|
||||
frame_idx = int(frame_indices[b_idx])
|
||||
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
|
||||
ep_idx, episodes_df, annotation_type, global_names
|
||||
)
|
||||
|
||||
# Compute observation frame indices
|
||||
obs_indices, _ = compute_absolute_indices(
|
||||
frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
|
||||
)
|
||||
obs_indices = obs_indices.tolist()
|
||||
|
||||
# Compute targets for observation frames
|
||||
for t_idx, abs_idx in enumerate(obs_indices):
|
||||
rel_frame = abs_idx - ep_start
|
||||
targets[b_idx, t_idx] = find_stage_and_tau(
|
||||
rel_frame,
|
||||
ep_length,
|
||||
subtask_names,
|
||||
subtask_start_frames,
|
||||
subtask_end_frames,
|
||||
global_names,
|
||||
temporal_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
# Compute targets for rewind frames (if any)
|
||||
rewind_step = rewind_steps[b_idx].item()
|
||||
if rewind_step > 0:
|
||||
_, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps,
|
||||
max_rewind_steps,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=rewind_step,
|
||||
)
|
||||
|
||||
for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
|
||||
rel_frame = max(0, abs_idx - ep_start)
|
||||
targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
|
||||
rel_frame,
|
||||
ep_length,
|
||||
subtask_names,
|
||||
subtask_start_frames,
|
||||
subtask_end_frames,
|
||||
global_names,
|
||||
temporal_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
return targets
|
||||
|
||||
@property
|
||||
def training(self) -> bool:
|
||||
return getattr(self, "_training_mode", True)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode for augmentation decisions."""
|
||||
self._training_mode = mode
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode (disable augmentations)."""
|
||||
return self.train(False)
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
|
||||
"""Encode a batch of images using CLIP.
|
||||
|
||||
Args:
|
||||
images: Batched images with shape: (B, T, C, H, W)
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, T, 512)
|
||||
"""
|
||||
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
|
||||
num_frames = images.shape[0]
|
||||
images_list = []
|
||||
for i in range(num_frames):
|
||||
img = images[i]
|
||||
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
|
||||
img = img.transpose(1, 2, 0)
|
||||
|
||||
# Handle single channel
|
||||
if img.shape[-1] == 1:
|
||||
img = np.repeat(img, 3, axis=-1)
|
||||
|
||||
if img.dtype != np.uint8:
|
||||
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
|
||||
|
||||
images_list.append(Image.fromarray(img))
|
||||
|
||||
all_embeddings = []
|
||||
for i in range(0, num_frames, self.config.clip_batch_size):
|
||||
batch_imgs = images_list[i : i + self.config.clip_batch_size]
|
||||
|
||||
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings
|
||||
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
embeddings = embeddings.unsqueeze(0)
|
||||
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
|
||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Task description text to encode
|
||||
batch_size: Batch size to replicate for
|
||||
|
||||
Returns:
|
||||
Encoded text features with shape (B, 512)
|
||||
"""
|
||||
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
return text_embedding
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Add encoded features to the observation features."""
|
||||
features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def make_sarm_pre_post_processors(
|
||||
config: SARMConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_meta=None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre-processor and post-processor pipelines for SARM."""
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
SARMEncodingProcessorStep(
|
||||
config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=[DeviceProcessorStep(device="cpu")],
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,295 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def find_stage_and_tau(
|
||||
current_frame: int,
|
||||
episode_length: int,
|
||||
subtask_names: list | None,
|
||||
subtask_start_frames: list | None,
|
||||
subtask_end_frames: list | None,
|
||||
global_subtask_names: list,
|
||||
temporal_proportions: dict,
|
||||
return_combined: bool = False,
|
||||
) -> tuple[int, float] | float:
|
||||
"""Find stage and within-stage progress (tau) for a frame.
|
||||
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
episode_length: Total frames in episode
|
||||
subtask_names: Subtask names for this episode (None for single_stage)
|
||||
subtask_start_frames: Subtask start frames
|
||||
subtask_end_frames: Subtask end frames
|
||||
global_subtask_names: Global list of all subtask names
|
||||
temporal_proportions: Dict of temporal proportions
|
||||
return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple
|
||||
|
||||
Returns:
|
||||
Float (stage.tau) if return_combined, else (stage_idx, tau) tuple
|
||||
"""
|
||||
stage_idx, tau = 0, 0.0
|
||||
num_stages = len(global_subtask_names)
|
||||
|
||||
# Single-stage mode: linear progress from 0 to 1
|
||||
if num_stages == 1:
|
||||
tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
|
||||
elif subtask_names is None:
|
||||
pass # stage_idx=0, tau=0.0
|
||||
elif current_frame < subtask_start_frames[0]:
|
||||
pass # Before first subtask: stage_idx=0, tau=0.0
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
stage_idx, tau = num_stages - 1, 0.999 # After last subtask
|
||||
else:
|
||||
# Find which subtask this frame belongs to
|
||||
found = False
|
||||
for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
|
||||
if start <= current_frame <= end:
|
||||
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
|
||||
tau = compute_tau(current_frame, start, end)
|
||||
found = True
|
||||
break
|
||||
# Frame between subtasks - use previous subtask's end state
|
||||
if not found:
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
|
||||
tau = 1.0
|
||||
break
|
||||
|
||||
if return_combined:
|
||||
# Clamp to avoid overflow at end
|
||||
if stage_idx >= num_stages - 1 and tau >= 1.0:
|
||||
return num_stages - 1 + 0.999
|
||||
return stage_idx + tau
|
||||
return stage_idx, tau
|
||||
|
||||
|
||||
def compute_absolute_indices(
|
||||
frame_idx: int,
|
||||
ep_start: int,
|
||||
ep_end: int,
|
||||
n_obs_steps: int,
|
||||
frame_gap: int = 30,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute absolute frame indices with clamping for bidirectional observation sequence.
|
||||
|
||||
Bidirectional sampling centered on target frame:
|
||||
- Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
|
||||
- Current: [0] (1 frame)
|
||||
- After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
|
||||
- Total: n_obs_steps + 1 frames
|
||||
|
||||
Out-of-bounds frames are clamped (duplicated from boundary).
|
||||
|
||||
Args:
|
||||
frame_idx: Target frame index (center frame of sequence)
|
||||
ep_start: Episode start index
|
||||
ep_end: Episode end index (exclusive)
|
||||
n_obs_steps: Number of observation steps (must be even for symmetric sampling)
|
||||
frame_gap: Gap between observation frames
|
||||
|
||||
Returns:
|
||||
Tuple of (indices, out_of_bounds_flags)
|
||||
"""
|
||||
half_steps = n_obs_steps // 2
|
||||
|
||||
# Bidirectional deltas: past + current + future
|
||||
past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
|
||||
delta_indices = past_deltas + [0] + future_deltas
|
||||
|
||||
frames = []
|
||||
out_of_bounds = []
|
||||
|
||||
for delta in delta_indices:
|
||||
target_idx = frame_idx + delta
|
||||
# Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
|
||||
clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
|
||||
frames.append(clamped_idx)
|
||||
# Flag as out-of-bounds if clamping occurred
|
||||
out_of_bounds.append(1 if target_idx != clamped_idx else 0)
|
||||
|
||||
return torch.tensor(frames), torch.tensor(out_of_bounds)
|
||||
|
||||
|
||||
def apply_rewind_augmentation(
|
||||
frame_idx: int,
|
||||
ep_start: int,
|
||||
n_obs_steps: int,
|
||||
max_rewind_steps: int,
|
||||
frame_gap: int = 30,
|
||||
rewind_step: int | None = None,
|
||||
) -> tuple[int, list[int]]:
|
||||
"""
|
||||
Generate rewind frame indices for temporal augmentation.
|
||||
|
||||
Rewind simulates going backwards through previously seen frames,
|
||||
starting from before the earliest observation frame (for bidirectional sampling).
|
||||
Appends reversed frames after the observation sequence.
|
||||
|
||||
Args:
|
||||
frame_idx: Target frame index (center of bidirectional observation window)
|
||||
ep_start: Episode start index
|
||||
n_obs_steps: Number of observation steps
|
||||
max_rewind_steps: Maximum rewind steps
|
||||
frame_gap: Gap between frames
|
||||
rewind_step: If provided, use this exact rewind step (for deterministic behavior).
|
||||
If None, sample randomly.
|
||||
|
||||
Returns:
|
||||
Tuple of (rewind_step, rewind_indices)
|
||||
"""
|
||||
# For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap
|
||||
half_steps = n_obs_steps // 2
|
||||
earliest_obs_frame = frame_idx - half_steps * frame_gap
|
||||
|
||||
# Required history: frames before earliest observation frame
|
||||
if earliest_obs_frame <= ep_start:
|
||||
return 0, [] # No history before observation window
|
||||
|
||||
# Max valid rewind steps based on available history before earliest obs frame
|
||||
available_history = earliest_obs_frame - ep_start
|
||||
max_valid_step = available_history // frame_gap
|
||||
max_rewind = min(max_rewind_steps, max(0, max_valid_step))
|
||||
|
||||
if max_rewind <= 0:
|
||||
return 0, []
|
||||
|
||||
# Sample rewind steps if not provided
|
||||
rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)
|
||||
|
||||
if rewind_step == 0:
|
||||
return 0, []
|
||||
|
||||
# Generate rewind indices going backwards from earliest obs frame
|
||||
# rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back
|
||||
rewind_indices = []
|
||||
for i in range(1, rewind_step + 1):
|
||||
idx = earliest_obs_frame - i * frame_gap
|
||||
idx = max(ep_start, idx) # Clamp to episode start
|
||||
rewind_indices.append(idx)
|
||||
|
||||
return rewind_step, rewind_indices
|
||||
|
||||
|
||||
def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
|
||||
"""Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
|
||||
duration = subtask_end - subtask_start
|
||||
if duration <= 0:
|
||||
return 1.0
|
||||
return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))
|
||||
|
||||
|
||||
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
|
||||
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
|
||||
current_dim = state.shape[-1]
|
||||
if current_dim >= max_state_dim:
|
||||
return state[..., :max_state_dim] # Truncate if larger
|
||||
|
||||
# Pad with zeros on the right
|
||||
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
||||
return F.pad(state, padding, mode="constant", value=0)
|
||||
|
||||
|
||||
def temporal_proportions_to_breakpoints(
|
||||
temporal_proportions: dict[str, float] | list[float] | None,
|
||||
subtask_names: list[str] | None = None,
|
||||
) -> list[float] | None:
|
||||
"""Convert temporal proportions to cumulative breakpoints for normalization."""
|
||||
if temporal_proportions is None:
|
||||
return None
|
||||
|
||||
if isinstance(temporal_proportions, dict):
|
||||
if subtask_names is not None:
|
||||
proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
|
||||
else:
|
||||
proportions = list(temporal_proportions.values())
|
||||
else:
|
||||
proportions = list(temporal_proportions)
|
||||
|
||||
total = sum(proportions)
|
||||
if total > 0 and abs(total - 1.0) > 1e-6:
|
||||
proportions = [p / total for p in proportions]
|
||||
|
||||
breakpoints = [0.0]
|
||||
cumsum = 0.0
|
||||
for prop in proportions:
|
||||
cumsum += prop
|
||||
breakpoints.append(cumsum)
|
||||
breakpoints[-1] = 1.0
|
||||
|
||||
return breakpoints
|
||||
|
||||
|
||||
def normalize_stage_tau(
|
||||
x: float | torch.Tensor,
|
||||
num_stages: int | None = None,
|
||||
breakpoints: list[float] | None = None,
|
||||
temporal_proportions: dict[str, float] | list[float] | None = None,
|
||||
subtask_names: list[str] | None = None,
|
||||
) -> float | torch.Tensor:
|
||||
"""
|
||||
Normalize stage+tau reward to [0, 1] with custom breakpoints.
|
||||
|
||||
Maps stage index + within-stage tau to normalized progress [0, 1].
|
||||
The breakpoints are designed to give appropriate weight to each stage
|
||||
based on their importance in the task (using temporal proportions).
|
||||
|
||||
Priority: breakpoints > temporal_proportions > linear fallback
|
||||
|
||||
Args:
|
||||
x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
|
||||
num_stages: Number of stages (required if breakpoints/proportions not provided)
|
||||
breakpoints: Optional custom breakpoints list of length num_stages + 1.
|
||||
temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
|
||||
subtask_names: Optional ordered list of subtask names (for dict proportions)
|
||||
|
||||
Returns:
|
||||
Normalized progress value ∈ [0, 1]
|
||||
"""
|
||||
if breakpoints is not None:
|
||||
num_stages = len(breakpoints) - 1
|
||||
elif temporal_proportions is not None:
|
||||
breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
|
||||
num_stages = len(breakpoints) - 1
|
||||
elif num_stages is not None:
|
||||
breakpoints = [i / num_stages for i in range(num_stages + 1)]
|
||||
else:
|
||||
raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")
|
||||
|
||||
if isinstance(x, torch.Tensor):
|
||||
result = torch.zeros_like(x)
|
||||
for i in range(num_stages):
|
||||
mask = (x >= i) & (x < i + 1)
|
||||
tau_in_stage = x - i
|
||||
result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
|
||||
result[x >= num_stages] = 1.0
|
||||
return result.clamp(0.0, 1.0)
|
||||
else:
|
||||
if x < 0:
|
||||
return 0.0
|
||||
if x >= num_stages:
|
||||
return 1.0
|
||||
stage = int(x)
|
||||
tau = x - stage
|
||||
return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
|
||||
@@ -231,7 +231,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -353,19 +352,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
|
||||
) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss.
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations and actions.
|
||||
noise: Optional noise tensor for flow matching.
|
||||
time: Optional time tensor for flow matching.
|
||||
reduction: How to reduce the loss. Options:
|
||||
- "mean": Return scalar mean loss (default, backward compatible)
|
||||
- "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
|
||||
"""
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
@@ -389,16 +377,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
@@ -544,7 +527,6 @@ class VLAFlowMatching(nn.Module):
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
device=self.config.device if self.config.device is not None else "auto",
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
@@ -801,15 +783,18 @@ class VLAFlowMatching(nn.Module):
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
@@ -833,11 +818,15 @@ class VLAFlowMatching(nn.Module):
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
@@ -65,7 +65,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -231,20 +231,11 @@ def validate_visual_features_consistency(
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Validation passes if EITHER:
|
||||
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
|
||||
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
|
||||
# Accept if either direction is a subset
|
||||
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
|
||||
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
|
||||
|
||||
if not (policy_subset_of_dataset or dataset_subset_of_policy):
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
@@ -47,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# WALL-OSS
|
||||
|
||||
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| ------------------ | ----------------------------------------------------- | --- |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zhai2025igniting,
|
||||
title = {Igniting VLMs Toward the Embodied Space},
|
||||
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
|
||||
journal = {arXiv preprint arXiv:2509.11766},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**.
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_wall_x import WallXConfig
|
||||
|
||||
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
|
||||
@@ -1,165 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("wall_x")
|
||||
@dataclass
|
||||
class WallXConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for Wall-X policy.
|
||||
|
||||
Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
|
||||
It supports cross-embodiment robotic control through unified action representations.
|
||||
|
||||
This config supports multi-modal learning with vision, language, and action data.
|
||||
"""
|
||||
|
||||
# ==================== Input / Output Structure ====================
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32 # action_horizon in wall-x
|
||||
n_action_steps: int = 32
|
||||
|
||||
# Action dimension - wall-x uses 20
|
||||
max_action_dim: int = 20
|
||||
max_state_dim: int = 20 # For proprioception
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# ==================== Action Prediction ====================
|
||||
# Pretrained model paths
|
||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||
|
||||
# Tokenizer settings
|
||||
action_tokenizer_path: str | None = "physical-intelligence/fast"
|
||||
|
||||
# Action prediction mode: "diffusion" or "fast"
|
||||
prediction_mode: str = "diffusion"
|
||||
|
||||
# Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
|
||||
# NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
|
||||
attn_implementation: str = "eager"
|
||||
|
||||
# ==================== Optimizer Presets ====================
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 1000
|
||||
scheduler_decay_steps: int = 100000
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Input validation
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
|
||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
# Assign use_fast_tokenizer based on prediction_mode
|
||||
if self.prediction_mode == "fast":
|
||||
self.use_fast_tokenizer = True
|
||||
elif self.prediction_mode == "diffusion":
|
||||
self.use_fast_tokenizer = False
|
||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||
else:
|
||||
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Wall-X policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Wall-X Constants and Configuration Data.
|
||||
"""
|
||||
|
||||
CAMERA_NAME_MAPPING = {
|
||||
"face_view": "front view",
|
||||
"left_wrist_view": "left wrist view",
|
||||
"right_wrist_view": "right wrist view",
|
||||
"move1_view": "move view",
|
||||
"move2_view": "move view",
|
||||
"wall_view": "wall view",
|
||||
"top_view": "top view",
|
||||
}
|
||||
|
||||
RESOLUTION = 256
|
||||
|
||||
# Parameters for preprocessing
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MIN_PIXELS = 4 * 28 * 28
|
||||
IMAGE_FACTOR = 28
|
||||
PRIORITY_ORDER = None
|
||||
GENERATE_SUBTASK_RATIO = 0.0
|
||||
MODEL_TYPE = "qwen2_5"
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 768
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_wall_x_pre_post_processors(
|
||||
config: WallXConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the Wall-X policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations
|
||||
2. Adding a batch dimension
|
||||
4. Normalizing input and output features based on dataset statistics
|
||||
5. Moving all data to the specified device
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Unnormalizing the output actions to their original scale
|
||||
2. Moving data to the CPU
|
||||
|
||||
Args:
|
||||
config: The configuration object for the Wall-X policy
|
||||
dataset_stats: A dictionary of statistics for normalization
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
WallXTaskProcessor(), # Process task description
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="wall_x_task_processor")
|
||||
class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
A processor step that ensures the task description is properly formatted for Wall-X.
|
||||
|
||||
This step handles task preprocessing similar to Qwen-VL requirements.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
# Provide default task if none specified
|
||||
complementary_data["task"] = "Execute the robot action."
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: ensure proper formatting
|
||||
if not task.endswith("."):
|
||||
new_complementary_data["task"] = f"{task}."
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: format each
|
||||
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -1,248 +0,0 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_5_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
hidden_size=3584,
|
||||
hidden_act="silu",
|
||||
intermediate_size=3420,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
tokens_per_second=4,
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.tokens_per_second = tokens_per_second
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
|
||||
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 29568):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 80):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
|
||||
|
||||
>>> # Initializing a Qwen2_5_VL style configuration
|
||||
>>> configuration = Qwen2_5_VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_vl"
|
||||
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen2_5_VL`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
num_experts=4,
|
||||
experts=None,
|
||||
dof_config=None,
|
||||
noise_scheduler=None,
|
||||
dim_inputs=(1536, 1536),
|
||||
attention_moe=False,
|
||||
mlp_moe=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
self.layer_types = ["dense"] * num_hidden_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.experts = experts
|
||||
self.dof_config = dof_config
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.dim_inputs = tuple(dim_inputs)
|
||||
self.attention_moe = attention_moe
|
||||
self.mlp_moe = mlp_moe
|
||||
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
@property
|
||||
def text_config(self):
|
||||
return self
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5_VLConfig"]
|
||||
@@ -1,631 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Wall-X Utility Functions.
|
||||
|
||||
Contains data processing utilities, text formatting functions, and helper classes
|
||||
for the Wall-X cross-embodiment robotic control model.
|
||||
"""
|
||||
|
||||
import random
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature
|
||||
|
||||
from lerobot.policies.wall_x.constant import (
|
||||
CAMERA_NAME_MAPPING,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@dataclass
|
||||
class X2RDataProcessingConfig:
|
||||
"""Configuration class for X2R data processing pipeline.
|
||||
|
||||
This class contains all the necessary parameters for processing robotic data
|
||||
including camera mappings, tactile sensor configurations, action predictions,
|
||||
and various processing options.
|
||||
"""
|
||||
|
||||
# Action prediction configuration
|
||||
predict_action_keys: list[str] = field(default_factory=list)
|
||||
obs_action_keys: list[str] = field(default_factory=list)
|
||||
|
||||
# Image resolution settings for different views
|
||||
resolution: dict[str, int] = field(
|
||||
default_factory=lambda: {
|
||||
"face_view": -1,
|
||||
"left_wrist_view": 128,
|
||||
"right_wrist_view": 128,
|
||||
}
|
||||
)
|
||||
|
||||
# Dataset splitting
|
||||
train_test_split: float = 0.9
|
||||
split_seed: int = 42
|
||||
|
||||
# Instruction handling
|
||||
priority_order: dict[str, float] | None = None
|
||||
|
||||
# Vision model parameters
|
||||
model_type: str = "qwen2_5"
|
||||
max_pixels: int = 16384 * 28 * 28
|
||||
min_pixels: int = 4 * 28 * 28
|
||||
image_factor: int = 28
|
||||
|
||||
generate_subtask_ratio: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialization validation and setup."""
|
||||
# Validate train/test split
|
||||
if not 0 < self.train_test_split < 1:
|
||||
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
"""Convert configuration to dictionary format.
|
||||
|
||||
Returns:
|
||||
Dict: Configuration as dictionary
|
||||
"""
|
||||
return self.__dict__
|
||||
|
||||
def update(self, **kwargs) -> "X2RDataProcessingConfig":
|
||||
"""Update configuration parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs to update
|
||||
|
||||
Returns:
|
||||
X2RDataProcessingConfig: Updated configuration instance
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
raise ValueError(f"Unknown configuration parameter: {key}")
|
||||
return self
|
||||
|
||||
|
||||
def preprocesser_call(
|
||||
processor,
|
||||
images: list | Any | None = None,
|
||||
text: str | list[str] | None = None,
|
||||
videos: list | Any | None = None,
|
||||
padding: bool | str = False,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchFeature:
|
||||
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
||||
|
||||
Processes inputs into format suitable for multimodal transformer models, including:
|
||||
- Text tokenization and special token handling
|
||||
- Image/video processing through image processor
|
||||
- Attention mask and label generation
|
||||
- Padding and truncation handling
|
||||
|
||||
Args:
|
||||
processor: Multimodal processor containing tokenizer and image processor
|
||||
images: Input images (PIL, numpy arrays, or torch tensors)
|
||||
text: Text or list of texts to tokenize
|
||||
videos: Input videos (numpy arrays or torch tensors)
|
||||
padding: Whether to pad sequences to same length
|
||||
truncation: Whether to truncate sequences longer than max_length
|
||||
max_length: Maximum length for truncation/padding
|
||||
return_tensors: Format for returned tensors ('pt', 'np', etc.)
|
||||
|
||||
Returns:
|
||||
BatchFeature containing processed inputs with keys:
|
||||
- input_ids: Tokenized text
|
||||
- attention_mask: Attention mask for text
|
||||
- pixel_values: Processed image pixels
|
||||
- pixel_values_videos: Processed video frames
|
||||
- image_grid_thw: Image grid dimensions for LLM
|
||||
- video_grid_thw: Video grid dimensions for LLM
|
||||
- labels: Training labels with masking
|
||||
"""
|
||||
# Process image inputs
|
||||
if images is not None and len(images) > 0:
|
||||
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
# Process video inputs
|
||||
if videos is not None:
|
||||
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
# Ensure text input is in list format
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
# Process image placeholder tokens in text
|
||||
if image_grid_thw is not None:
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while "<|image_pad|>" in text[i]:
|
||||
# Add bounds checking to avoid index overflow
|
||||
if index >= len(image_grid_thw):
|
||||
print(
|
||||
f"Warning: Number of image placeholders ({index + 1}) "
|
||||
f"exceeds actual images ({len(image_grid_thw)}), "
|
||||
f"skipping remaining placeholder processing"
|
||||
)
|
||||
break
|
||||
# Replace image placeholder with actual token count
|
||||
token_count = image_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||
|
||||
# Process video placeholder tokens in text
|
||||
if video_grid_thw is not None:
|
||||
merge_length = processor.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while "<|video_pad|>" in text[i]:
|
||||
# Replace video placeholder with actual token count
|
||||
token_count = video_grid_thw[index].prod() // merge_length
|
||||
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
||||
|
||||
# Tokenize complete input text
|
||||
text_inputs = processor.tokenizer(
|
||||
text,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
# Get pad token ID for label generation
|
||||
pad_token_id = processor.tokenizer.pad_token_id
|
||||
if pad_token_id is None:
|
||||
pad_token_id = processor.tokenizer.eos_token_id
|
||||
|
||||
# Generate labels for multi-turn dialogue, keeping only assistant response loss
|
||||
labels = torch.full_like(text_inputs.input_ids, -100)
|
||||
assistant_marker = "<|im_start|>assistant\n"
|
||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
|
||||
|
||||
for i in range(len(text)):
|
||||
assistant_regions = []
|
||||
parts = text[i].split(assistant_marker)
|
||||
|
||||
# Process each part to determine which tokens belong to assistant responses
|
||||
# Count left padding tokens
|
||||
num_left_pads = 0
|
||||
for token_id in text_inputs.input_ids[i]:
|
||||
if token_id == pad_token_id:
|
||||
num_left_pads += 1
|
||||
else:
|
||||
break
|
||||
current_pos = num_left_pads
|
||||
|
||||
for j, part in enumerate(parts):
|
||||
part_tokens = processor.tokenizer(part, add_special_tokens=False).input_ids
|
||||
if j == 0:
|
||||
# First part is system prompt or user question, all labels are -100
|
||||
current_pos += len(part_tokens)
|
||||
continue
|
||||
|
||||
# From second part onwards, each part starts with assistant response
|
||||
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
||||
if text_inputs.input_ids[i][k] == im_end_token_id:
|
||||
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
|
||||
break
|
||||
current_pos += len(part_tokens) + 3
|
||||
|
||||
# Set labels for assistant response regions
|
||||
for start, end in assistant_regions:
|
||||
labels[i][start:end] = text_inputs.input_ids[i][start:end]
|
||||
|
||||
# Mask special action tokens in labels
|
||||
action_token_id = processor.tokenizer.encode("<|action|>")[0]
|
||||
propri_token_id = processor.tokenizer.encode("<|propri|>")[0]
|
||||
labels[labels == action_token_id] = -100
|
||||
labels[labels == propri_token_id] = -100
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
# Set labels to None if all are invalid to skip cross entropy loss
|
||||
if (labels != -100).any().item():
|
||||
text_inputs["labels"] = labels
|
||||
else:
|
||||
text_inputs["labels"] = None
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
|
||||
def process_grounding_points(
|
||||
text: str,
|
||||
orig_height: int,
|
||||
orig_width: int,
|
||||
resized_height: int,
|
||||
resized_width: int,
|
||||
model_type: str,
|
||||
) -> str:
|
||||
"""Process grounding point coordinates in text based on image resizing.
|
||||
|
||||
Adjusts coordinate values in <point> tags to match resized image dimensions
|
||||
for different model types (qwen2, qwen2_5).
|
||||
|
||||
Args:
|
||||
text: Input text containing <point> tags with coordinates
|
||||
orig_height: Original image height
|
||||
orig_width: Original image width
|
||||
resized_height: Resized image height
|
||||
resized_width: Resized image width
|
||||
model_type: Model type for coordinate processing ('qwen2' or 'qwen2_5')
|
||||
|
||||
Returns:
|
||||
Text with adjusted coordinate values
|
||||
"""
|
||||
# Regex pattern to match <point> tags and their contents
|
||||
point_pattern = re.compile(r"<point>(.*?)</point>")
|
||||
|
||||
def process_match(match):
|
||||
"""Process a single point match and adjust coordinates."""
|
||||
coords_str = match.group(1)
|
||||
try:
|
||||
# Extract coordinates from string
|
||||
coords = list(map(int, re.findall(r"\d+", coords_str)))
|
||||
|
||||
# Calculate resize scale factors
|
||||
scale_w = resized_width / orig_width
|
||||
scale_h = resized_height / orig_height
|
||||
|
||||
if len(coords) == 2:
|
||||
x, y = coords
|
||||
if model_type == "qwen2_5":
|
||||
# Qwen2.5 uses pixel coordinates
|
||||
new_x = max(0, min(round(x * scale_w), resized_width - 1))
|
||||
new_y = max(0, min(round(y * scale_h), resized_height - 1))
|
||||
elif model_type == "qwen2":
|
||||
# Qwen2 normalizes to [0, 1000) range
|
||||
new_x = max(0, min(999.999, (x / orig_width) * 1000))
|
||||
new_y = max(0, min(999.999, (y / orig_height) * 1000))
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
coords = [new_x, new_y]
|
||||
|
||||
elif len(coords) == 4:
|
||||
x1, y1, x2, y2 = coords
|
||||
if model_type == "qwen2_5":
|
||||
new_x1 = max(0, min(round(x1 * scale_w), resized_width - 1))
|
||||
new_y1 = max(0, min(round(y1 * scale_h), resized_height - 1))
|
||||
new_x2 = max(0, min(round(x2 * scale_w), resized_width - 1))
|
||||
new_y2 = max(0, min(round(y2 * scale_h), resized_height - 1))
|
||||
elif model_type == "qwen2":
|
||||
new_x1 = max(0, min(999.999, (x1 / orig_width) * 1000))
|
||||
new_y1 = max(0, min(999.999, (y1 / orig_height) * 1000))
|
||||
new_x2 = max(0, min(999.999, (x2 / orig_width) * 1000))
|
||||
new_y2 = max(0, min(999.999, (y2 / orig_height) * 1000))
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
coords = [new_x1, new_y1, new_x2, new_y2]
|
||||
|
||||
# Return processed point tag
|
||||
return f"<point>[{', '.join(map(str, coords))}]</point>"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# Return original content if processing fails
|
||||
return match.group(0)
|
||||
|
||||
# Replace all matching point tags
|
||||
processed_text = point_pattern.sub(process_match, text)
|
||||
return processed_text
|
||||
|
||||
|
||||
def get_frame_instruction(
|
||||
instruction_info: dict[str, Any],
|
||||
frame_idx: int | None = None,
|
||||
truncate_keys: list[str] | None = None,
|
||||
) -> tuple[dict[str, Any], int | None]:
|
||||
"""Extract frame-specific instruction from instruction dictionary.
|
||||
|
||||
Args:
|
||||
instruction_info: Dictionary containing instruction components
|
||||
frame_idx: Current frame index
|
||||
truncate_keys: Keys that trigger truncation when found
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_instruction_dict, split_end_frame)
|
||||
"""
|
||||
if truncate_keys is None:
|
||||
truncate_keys = [
|
||||
"subtask_generation",
|
||||
"distribute",
|
||||
"subtask_generation_zh",
|
||||
"distribute_zh",
|
||||
]
|
||||
|
||||
instruction_for_frame = {}
|
||||
split_end = None
|
||||
|
||||
for key, value in instruction_info.items():
|
||||
if isinstance(value, dict):
|
||||
# Handle frame-range specific instructions
|
||||
for frame_range, frame_instruction in value.items():
|
||||
start_frame, end_frame = map(int, frame_range.split(" "))
|
||||
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
||||
instruction_for_frame[key] = frame_instruction
|
||||
if truncate_keys is not None and split_end is None and key in truncate_keys:
|
||||
split_end = end_frame + 1
|
||||
break
|
||||
else:
|
||||
instruction_for_frame[key] = value
|
||||
|
||||
return instruction_for_frame, split_end
|
||||
|
||||
|
||||
def get_task_instruction(
|
||||
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
|
||||
) -> str:
|
||||
"""Construct task instruction from available instruction fields using priority sampling.
|
||||
|
||||
Args:
|
||||
frame_instruction_info: Dictionary containing instruction fields
|
||||
priority_order: OrderedDict specifying sampling probability for each field
|
||||
|
||||
Returns:
|
||||
Combined instruction string with priority components
|
||||
"""
|
||||
# Default priority settings
|
||||
default_priority_order = OrderedDict(
|
||||
{
|
||||
"subtask_generation": 0.25,
|
||||
"subtask_generation_zh": 0.25,
|
||||
"distribute": 0.25,
|
||||
"distribute_zh": 0.25,
|
||||
}
|
||||
)
|
||||
|
||||
if priority_order is not None:
|
||||
priority_order = OrderedDict(priority_order)
|
||||
else:
|
||||
priority_order = default_priority_order
|
||||
|
||||
got_instruction = False
|
||||
task_instruction = ""
|
||||
|
||||
# Sample instruction components based on priority probabilities
|
||||
for key, prob in priority_order.items():
|
||||
if key in frame_instruction_info and frame_instruction_info[key] != "":
|
||||
if got_instruction:
|
||||
if random.random() >= prob:
|
||||
continue
|
||||
|
||||
task_instruction += f"\n{frame_instruction_info[key]}"
|
||||
got_instruction = True
|
||||
break
|
||||
|
||||
# Fall back to base instruction if no priority components found
|
||||
if not got_instruction:
|
||||
task_instruction = frame_instruction_info.get("instruction", "")
|
||||
|
||||
return task_instruction
|
||||
|
||||
|
||||
def get_wallx_normal_text(
|
||||
instruction_info: dict[str, Any],
|
||||
action_chunk_size: int,
|
||||
frame_idx: int,
|
||||
priority_order: OrderedDict | None = None,
|
||||
img_keys: list[str] | None = None,
|
||||
generate_subtask_ratio: float = 0.0,
|
||||
) -> tuple[str, bool]:
|
||||
"""Construct complete multimodal prompt text for Wall-X model.
|
||||
|
||||
Formats input using special tokens including:
|
||||
- System message
|
||||
- User observations (with image placeholders)
|
||||
- Task instructions
|
||||
- Proprioception prompts
|
||||
- Assistant responses (with action tokens)
|
||||
|
||||
Args:
|
||||
instruction_info: Dictionary containing instruction components
|
||||
action_chunk_size: Number of action tokens to generate
|
||||
frame_idx: Current frame index
|
||||
priority_order: Priority order for instruction sampling
|
||||
img_keys: List of image keys
|
||||
generate_subtask_ratio: Probability of generating subtask instead of actions
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_prompt_text, is_subtask_generation)
|
||||
"""
|
||||
# Special tokens for formatting
|
||||
role_start_symbol = "<|im_start|>"
|
||||
role_end_symbol = "<|im_end|>"
|
||||
vision_start_symbol = "<|vision_start|>"
|
||||
vision_end_symbol = "<|vision_end|>"
|
||||
image_pad_symbol = "<|image_pad|>"
|
||||
propri_symbol = "<|propri|>"
|
||||
action_symbol = "<|action|>"
|
||||
action_fast_symbol = "<|action_fast|>"
|
||||
|
||||
# System prologue
|
||||
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||
|
||||
# User request with observation
|
||||
user_request = f"{role_start_symbol}user\nObservation:"
|
||||
if img_keys:
|
||||
img_keys = img_key_mapping(img_keys)
|
||||
for key in img_keys:
|
||||
user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}"
|
||||
user_request += "\nInstruction:"
|
||||
|
||||
# Get frame-specific instruction
|
||||
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
|
||||
|
||||
generate_subtask = False
|
||||
priority_keys = ["subtask_generation", "distribute"]
|
||||
|
||||
# Decide whether to generate subtask or actions
|
||||
if (
|
||||
bool(set(frame_instruction_info.keys()) & set(priority_keys))
|
||||
and random.random() < generate_subtask_ratio
|
||||
):
|
||||
# Generate subtask (equivalent to VQA task)
|
||||
instruction = frame_instruction_info.get("instruction", "")
|
||||
text_prompt = "\nPredict the next action in language.\n"
|
||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||
|
||||
# Find output instruction from priority keys
|
||||
for key in priority_keys:
|
||||
if key in frame_instruction_info:
|
||||
output_instruction = frame_instruction_info[key]
|
||||
break
|
||||
|
||||
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||
generate_subtask = True
|
||||
else:
|
||||
# Generate actions
|
||||
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
|
||||
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
||||
|
||||
complete_text = prologue + user_message + assistant_output
|
||||
return complete_text, generate_subtask
|
||||
|
||||
|
||||
def img_key_mapping(img_keys: list[str]) -> list[str]:
|
||||
"""Map image keys to camera names.
|
||||
|
||||
Args:
|
||||
img_keys: List of image keys
|
||||
|
||||
Returns:
|
||||
List of camera names
|
||||
"""
|
||||
processed_img_keys = []
|
||||
for key in img_keys:
|
||||
key = key.replace(OBS_IMAGES + ".", "")
|
||||
if key in CAMERA_NAME_MAPPING:
|
||||
key = CAMERA_NAME_MAPPING[key]
|
||||
else:
|
||||
if "view" in key:
|
||||
key = key.replace("_", " ")
|
||||
else:
|
||||
key = key + " view"
|
||||
processed_img_keys.append(key)
|
||||
return processed_img_keys
|
||||
|
||||
|
||||
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
|
||||
"""Convert normalized actions to action token strings.
|
||||
|
||||
Args:
|
||||
normalized_actions: Normalized action arrays/tensors
|
||||
action_tokenizer: Tokenizer for converting actions to tokens
|
||||
|
||||
Returns:
|
||||
List of action token string lists for each sample
|
||||
"""
|
||||
if isinstance(normalized_actions, torch.Tensor):
|
||||
normalized_actions = normalized_actions.cpu().numpy()
|
||||
|
||||
all_action_tokens = []
|
||||
for i in range(len(normalized_actions)):
|
||||
if isinstance(normalized_actions[i], torch.Tensor):
|
||||
normalized_actions[i] = normalized_actions[i].cpu().numpy()
|
||||
|
||||
token_id = action_tokenizer(normalized_actions[i])
|
||||
action_tokens = [f"<|action_token_{j}|>" for j in token_id[0]]
|
||||
all_action_tokens.append(action_tokens)
|
||||
|
||||
return all_action_tokens
|
||||
|
||||
|
||||
def pad_action_token_strs(
|
||||
actions_token_lists: list[list[str]],
|
||||
pad_token: str = "<|endoftext|>", # nosec B107
|
||||
) -> list[str]:
|
||||
"""Pad action token lists to same length and join as strings.
|
||||
|
||||
Args:
|
||||
actions_token_lists: List of action token lists for each sample
|
||||
pad_token: Token used for padding
|
||||
|
||||
Returns:
|
||||
List of padded action token strings
|
||||
"""
|
||||
max_len = max(len(tokens) for tokens in actions_token_lists)
|
||||
padded_action_strs = []
|
||||
|
||||
for tokens in actions_token_lists:
|
||||
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||
padded_action_strs.append("".join(padded_tokens))
|
||||
|
||||
return padded_action_strs
|
||||
|
||||
|
||||
def replace_action_token(
|
||||
text: list[str],
|
||||
norm_action: torch.Tensor | None,
|
||||
action_tokenizer,
|
||||
dof_masks: torch.Tensor | None = None,
|
||||
) -> list[str]:
|
||||
"""Replace action placeholders in text with actual action tokens.
|
||||
|
||||
Args:
|
||||
text: List of text strings with action placeholders
|
||||
norm_action: Normalized action tensors
|
||||
action_tokenizer: Tokenizer for converting actions to tokens
|
||||
dof_masks: Masks for degrees of freedom
|
||||
|
||||
Returns:
|
||||
List of text strings with action tokens replaced
|
||||
"""
|
||||
if action_tokenizer is not None and norm_action is not None:
|
||||
# Extract actions based on chunk sizes and DOF masks
|
||||
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
|
||||
|
||||
# Convert to action tokens and pad
|
||||
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
||||
actions_fast_token_strs = pad_action_token_strs(actions_fast_tokens)
|
||||
|
||||
# Replace action placeholders with actual tokens
|
||||
actions_fast_token_idx = 0
|
||||
for i in range(len(text)):
|
||||
if "<|action_fast|>" in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
"<|action_fast|><|im_end|>\n",
|
||||
actions_fast_token_strs[actions_fast_token_idx],
|
||||
)
|
||||
actions_fast_token_idx += 1
|
||||
|
||||
# Remove remaining action placeholders
|
||||
text = [t.replace("<|action|>", "") for t in text]
|
||||
else:
|
||||
# Remove action placeholders when no tokenizer available
|
||||
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
||||
|
||||
return text
|
||||
@@ -1,6 +0,0 @@
|
||||
# register the processor steps
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
XVLAImageToFloatProcessorStep,
|
||||
)
|
||||
@@ -1,588 +0,0 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
|
||||
|
||||
|
||||
def register_action(name: str):
|
||||
"""Decorator for registering a new action space."""
|
||||
|
||||
def _wrap(cls):
|
||||
key = name.lower()
|
||||
if key in ACTION_REGISTRY:
|
||||
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
|
||||
ACTION_REGISTRY[key] = cls
|
||||
cls.name = key
|
||||
return cls
|
||||
|
||||
return _wrap
|
||||
|
||||
|
||||
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
|
||||
"""Instantiate a registered action space by name."""
|
||||
key = name.lower()
|
||||
if key not in ACTION_REGISTRY:
|
||||
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
|
||||
return ACTION_REGISTRY[key](**kwargs)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base class
|
||||
# =============================================================================
|
||||
class BaseActionSpace(nn.Module):
|
||||
"""
|
||||
Abstract base class for all action-space definitions.
|
||||
|
||||
Each subclass defines:
|
||||
- `dim_action`: dimension of the action vector.
|
||||
- `gripper_idx`: indices of gripper channels.
|
||||
- `compute_loss(pred, target)`: supervised loss for this space.
|
||||
- `preprocess(proprio, action, mode)`: pre-step modifications.
|
||||
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
dim_action: int = 0
|
||||
gripper_idx: tuple[int, ...] = ()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Core supervised loss
|
||||
# ---------------------------------------------------------------------
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""Alias for compute_loss."""
|
||||
return self.compute_loss(pred, target)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Space-level hooks
|
||||
# ---------------------------------------------------------------------
|
||||
def preprocess(
|
||||
self,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
mode: str = "train",
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Default: return unchanged."""
|
||||
return proprio, action
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Default: return unchanged."""
|
||||
return action
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utilities
|
||||
# =============================================================================
|
||||
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
|
||||
bad = [i for i in idx if i < 0 or i >= dim_action]
|
||||
if bad:
|
||||
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Implementations
|
||||
# =============================================================================
|
||||
@register_action("ee6d")
|
||||
class EE6DActionSpace(BaseActionSpace):
|
||||
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
|
||||
|
||||
dim_action = 20
|
||||
gripper_idx = (9, 19)
|
||||
GRIPPER_SCALE = 1.0
|
||||
XYZ_SCALE = 500.0
|
||||
ROT_SCALE = 10.0
|
||||
|
||||
POS_IDX_1 = (0, 1, 2)
|
||||
POS_IDX_2 = (10, 11, 12)
|
||||
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape, "pred/target shapes must match"
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
# Gripper BCE
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
# XYZ position
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
|
||||
# Rotation 6D
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
"position_loss": pos_loss,
|
||||
"rotate6D_loss": rot_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Zero-out gripper channels in proprio/action."""
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone()
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply sigmoid to gripper logits."""
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
|
||||
|
||||
@register_action("joint")
|
||||
class JointActionSpace(BaseActionSpace):
|
||||
"""Joint-space layout with joints + gripper only."""
|
||||
|
||||
dim_action = 14
|
||||
gripper_idx = (6, 13)
|
||||
GRIPPER_SCALE = 0.1
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
|
||||
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
|
||||
|
||||
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
|
||||
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Zero-out gripper channels in proprio/action."""
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone()
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply sigmoid to gripper logits."""
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
|
||||
|
||||
@register_action("agibot_ee6d")
|
||||
class AGIBOTEE6DActionSpace(BaseActionSpace):
|
||||
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
|
||||
|
||||
dim_action = 20
|
||||
gripper_idx = (9, 19)
|
||||
GRIPPER_SCALE = 10.0
|
||||
XYZ_SCALE = 500.0
|
||||
ROT_SCALE = 10.0
|
||||
POS_IDX_1 = (0, 1, 2)
|
||||
POS_IDX_2 = (10, 11, 12)
|
||||
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
|
||||
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
assert pred.shape == target.shape
|
||||
batch_size, seq_len, action_dim = pred.shape
|
||||
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
|
||||
|
||||
gripper_loss = (
|
||||
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
|
||||
)
|
||||
pos_loss = (
|
||||
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
|
||||
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
|
||||
) * self.XYZ_SCALE
|
||||
rot_loss = (
|
||||
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
|
||||
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
|
||||
) * self.ROT_SCALE
|
||||
|
||||
return {
|
||||
"position_loss": pos_loss,
|
||||
"rotate6D_loss": rot_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""No preprocessing applied in AGIBOT variant."""
|
||||
return proprio, action
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""AGIBOT does not postprocess."""
|
||||
return action
|
||||
|
||||
|
||||
@register_action("franka_joint7")
|
||||
class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Franka Panda joint-space: 7 joints, with gripper.
|
||||
|
||||
- Real robot action dim: 7
|
||||
- Model-facing dim: 20 (padded with zeros)
|
||||
compatible with pretrained VLA models expecting 20D.
|
||||
"""
|
||||
|
||||
dim_action = 20 # model dimension
|
||||
REAL_DIM = 7 # actual Franka joints
|
||||
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad 7 → 20 dims (zeros for the dummy channels)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.REAL_DIM:
|
||||
raise ValueError(
|
||||
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||
)
|
||||
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Trim model output 20 → 7 dims."""
|
||||
return x[..., : self.REAL_DIM]
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""
|
||||
pred : [B, T, 20]
|
||||
target : [B, T, 7] or [B, T, 20]
|
||||
|
||||
Only compute MSE on the first 7 dims.
|
||||
"""
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
|
||||
assert pred.shape == target.shape
|
||||
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, : self.REAL_DIM], # use only the first 7 joints
|
||||
target[:, :, : self.REAL_DIM],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""
|
||||
During training:
|
||||
- Pad [7] → [20]
|
||||
"""
|
||||
return proprio, self._pad_to_model_dim(action)
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
After model prediction:
|
||||
- Trim [20] → [7] for real robot control.
|
||||
"""
|
||||
return self._trim_to_real_dim(action)
|
||||
|
||||
|
||||
@register_action("auto")
|
||||
class AutoActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Auto-detecting action space that adapts to any action dimension.
|
||||
|
||||
- Auto-detects the real action dimension from the policy feature
|
||||
- Model outputs max_dim for compatibility with pretrained models
|
||||
- Loss is computed only on the first real_dim dimensions
|
||||
- Postprocess trims output back to real_dim
|
||||
|
||||
Args:
|
||||
real_dim: The actual action dimension from the dataset/policy feature
|
||||
max_dim: The model's output dimension for pretrained VLA compatibility
|
||||
"""
|
||||
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self, real_dim: int, max_dim: int):
|
||||
super().__init__()
|
||||
self.real_dim = real_dim
|
||||
self.dim_action = max_dim # Model-facing dimension
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad real_dim → max_dim (zeros for the dummy channels)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.real_dim:
|
||||
# If dimension doesn't match either, pad/trim to real_dim first
|
||||
if x.size(-1) < self.real_dim:
|
||||
pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
x = torch.cat([x, pad], dim=-1)
|
||||
else:
|
||||
x = x[..., : self.real_dim]
|
||||
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Trim model output max_dim → real_dim."""
|
||||
return x[..., : self.real_dim]
|
||||
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute loss only on the first real_dim dimensions.
|
||||
|
||||
pred: [B, T, max_dim] from the model
|
||||
target: [B, T, real_dim] or [B, T, max_dim]
|
||||
|
||||
Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
|
||||
"""
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
|
||||
|
||||
# only compute loss on the real dimensions
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, : self.real_dim],
|
||||
target[:, :, : self.real_dim],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
|
||||
"""
|
||||
Pad action from real_dim to max_dim for the model.
|
||||
"""
|
||||
return proprio, self._pad_to_model_dim(action)
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Trim model output from max_dim to real_dim for real robot control.
|
||||
"""
|
||||
return self._trim_to_real_dim(action)
|
||||
|
||||
|
||||
@register_action("so101_bimanual")
|
||||
class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
|
||||
|
||||
Layout (real robot):
|
||||
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
|
||||
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
|
||||
Real action dim: 12
|
||||
Model-facing dim: 20 (extra 8 dummy dims at the end)
|
||||
"""
|
||||
|
||||
# Model output / training dimension (to match pretrained policy)
|
||||
dim_action = 20
|
||||
|
||||
# Real robot action dimension
|
||||
REAL_DIM = 12
|
||||
|
||||
# Indices of real vs dummy channels
|
||||
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
||||
|
||||
# Grippers live in the real part
|
||||
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
|
||||
GRIPPER_SCALE = 1.0
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
# Indices for left and right arm joints (excluding grippers)
|
||||
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
|
||||
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
# ---------- helpers ----------
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.REAL_DIM:
|
||||
raise ValueError(
|
||||
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||
)
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
||||
return x[..., : self.REAL_DIM]
|
||||
|
||||
# ---------- loss ----------
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""
|
||||
pred: [B, T, 20] from the model
|
||||
target: [B, T, 12] or [B, T, 20]
|
||||
We pad target → 20 and compute loss only on the real dims.
|
||||
"""
|
||||
# Ensure both are [B, T, 20]
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
assert pred.shape == target.shape
|
||||
|
||||
# ---- MSE for all real dims (0–11) ----
|
||||
real_dims = 12
|
||||
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, :real_dims],
|
||||
target[:, :, :real_dims],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||
|
||||
gripper_loss = (
|
||||
self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
target[:, :, [5, 11]],
|
||||
)
|
||||
* self.GRIPPER_SCALE
|
||||
)
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
"left_arm_loss": left_arm_loss,
|
||||
"right_arm_loss": right_arm_loss,
|
||||
}
|
||||
|
||||
# ---------- preprocess / postprocess ----------
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""
|
||||
- If proprio/action are 12-dim, pad them to 20 for the model.
|
||||
- Zero-out gripper channels in proprio/action to focus learning on joints.
|
||||
"""
|
||||
proprio_m = self._pad_to_model_dim(proprio.clone())
|
||||
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
|
||||
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
if action_m is not None:
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
- Model outputs [*, 20]
|
||||
- Apply sigmoid to gripper logits
|
||||
- Return only the first 12 dims for the real robot:
|
||||
["left_shoulder_pan.pos",
|
||||
"left_shoulder_lift.pos",
|
||||
"left_elbow_flex.pos",
|
||||
"left_wrist_flex.pos",
|
||||
"left_wrist_roll.pos",
|
||||
"left_gripper.pos",
|
||||
"right_shoulder_pan.pos",
|
||||
"right_shoulder_lift.pos",
|
||||
"right_elbow_flex.pos",
|
||||
"right_wrist_flex.pos",
|
||||
"right_wrist_roll.pos",
|
||||
"right_gripper.pos"]
|
||||
"""
|
||||
# Ensure we at least have the real dims + grippers
|
||||
if action.size(-1) < self.REAL_DIM:
|
||||
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
|
||||
|
||||
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
|
||||
# Return only the real 12-dim control vector for the env
|
||||
return self._trim_to_real_dim(action)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
__all__ = [
|
||||
"BaseActionSpace",
|
||||
"build_action_space",
|
||||
"register_action",
|
||||
"EE6DActionSpace",
|
||||
"JointActionSpace",
|
||||
"AGIBOTEE6DActionSpace",
|
||||
"FrankaJoint7ActionSpace",
|
||||
"AutoActionSpace",
|
||||
"BimanualSO101ActionSpace",
|
||||
"ACTION_REGISTRY",
|
||||
]
|
||||
@@ -1,353 +0,0 @@
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Florence2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
||||
The dropout rate of the drop path layer.
|
||||
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
||||
The patch size of the image.
|
||||
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
||||
The patch stride of the image.
|
||||
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
||||
The patch padding of the image.
|
||||
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
|
||||
Whether to apply layer normalization before the patch embedding layer.
|
||||
enable_checkpoint (`bool`, *optional*, defaults to False):
|
||||
Whether to enable checkpointing.
|
||||
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
||||
The dimension of the embedding layer.
|
||||
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of attention heads.
|
||||
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of groups.
|
||||
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
||||
The depth of the model.
|
||||
window_size (`int`, *optional*, defaults to 12):
|
||||
The window size of the model.
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the projection layer.
|
||||
visual_temporal_embedding (`dict`, *optional*):
|
||||
The configuration of the visual temporal embedding.
|
||||
image_pos_embed (`dict`, *optional*):
|
||||
The configuration of the image position embedding.
|
||||
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
||||
The source of the image feature.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
||||
|
||||
>>> # Initializing a Florence2 Vision style configuration
|
||||
>>> configuration = Florence2VisionConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "davit"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
drop_path_rate=0.1,
|
||||
patch_size=None,
|
||||
patch_stride=None,
|
||||
patch_padding=None,
|
||||
patch_prenorm=None,
|
||||
enable_checkpoint=False,
|
||||
dim_embed=None,
|
||||
num_heads=None,
|
||||
num_groups=None,
|
||||
depths=None,
|
||||
window_size=12,
|
||||
projection_dim=1024,
|
||||
visual_temporal_embedding=None,
|
||||
image_pos_embed=None,
|
||||
image_feature_source=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
|
||||
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
|
||||
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
|
||||
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
|
||||
self.enable_checkpoint = enable_checkpoint
|
||||
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
|
||||
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
|
||||
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
|
||||
self.depths = depths if depths is not None else [1, 1, 9, 1]
|
||||
self.window_size = window_size
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
if visual_temporal_embedding is None:
|
||||
visual_temporal_embedding = {
|
||||
"type": "COSINE",
|
||||
"max_temporal_embeddings": 100,
|
||||
}
|
||||
self.visual_temporal_embedding = visual_temporal_embedding
|
||||
|
||||
if image_pos_embed is None:
|
||||
image_pos_embed = {
|
||||
"type": "learned_abs_2d",
|
||||
"max_pos_embeddings": 1000,
|
||||
}
|
||||
self.image_pos_embed = image_pos_embed
|
||||
|
||||
self.image_feature_source = (
|
||||
image_feature_source
|
||||
if image_feature_source is not None
|
||||
else ["spatial_avg_pool", "temporal_avg_pool"]
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Florence2LanguageConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the BART
|
||||
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
||||
d_model (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
num_labels (`int`, *optional*, defaults to 3):
|
||||
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
||||
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||
`eos_token_id`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
||||
|
||||
>>> # Initializing a Florence2 Language style configuration
|
||||
>>> configuration = Florence2LanguageConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2LanguageModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2_language"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=51289,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=12,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_attention_heads=16,
|
||||
decoder_layers=12,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layerdrop=0.0,
|
||||
decoder_layerdrop=0.0,
|
||||
activation_function="gelu",
|
||||
d_model=1024,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
use_cache=True,
|
||||
num_labels=3,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
decoder_start_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.init_std = init_std
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
||||
"The config can simply be saved and uploaded again to be fixed.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class Florence2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||
Florence-2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Florence2VisionConfig`, *optional*):
|
||||
Custom vision config or dict
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the multimodal projection space.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
||||
|
||||
>>> # Initializing a clip-like vision config
|
||||
>>> vision_config = CLIPVisionConfig()
|
||||
|
||||
>>> # Initializing a Bart config
|
||||
>>> text_config = BartConfig()
|
||||
|
||||
>>> # Initializing a Florence-2 configuration
|
||||
>>> configuration = Florence2Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the florence-2 configuration
|
||||
>>> model = Florence2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2"
|
||||
is_composition = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
vocab_size=51289,
|
||||
projection_dim=1024,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.vocab_size = vocab_size
|
||||
self.projection_dim = projection_dim
|
||||
if vision_config is not None:
|
||||
vision_config = Florence2VisionConfig(**vision_config)
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.text_config = text_config
|
||||
if text_config is not None:
|
||||
self.text_config = Florence2LanguageConfig(**text_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -1,203 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import XVLAAdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from .configuration_florence2 import Florence2Config
|
||||
else:
|
||||
Florence2Config = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("xvla")
|
||||
@dataclass
|
||||
class XVLAConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
|
||||
plug into the LeRobot training stack.
|
||||
|
||||
The config mirrors the knobs exposed in the original XVLA repository but also
|
||||
declares the input/output feature contract required by LeRobot.
|
||||
"""
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32
|
||||
n_action_steps: int = 32
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Florence2 backbone and tokenizer configuration
|
||||
florence_config: dict[str, Any] = field(default_factory=dict)
|
||||
tokenizer_name: str = "facebook/bart-large"
|
||||
tokenizer_max_length: int = 64
|
||||
tokenizer_padding_side: str = "right"
|
||||
pad_language_to: str = "max_length"
|
||||
|
||||
# Transformer head
|
||||
hidden_size: int = 1024
|
||||
depth: int = 24
|
||||
num_heads: int = 16
|
||||
mlp_ratio: float = 4.0
|
||||
num_domains: int = 30
|
||||
len_soft_prompts: int = 32
|
||||
dim_time: int = 32
|
||||
max_len_seq: int = 512
|
||||
use_hetero_proj: bool = False
|
||||
|
||||
# Action & proprioception
|
||||
action_mode: str = "ee6d"
|
||||
num_denoising_steps: int = 10
|
||||
use_proprio: bool = True
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
|
||||
domain_feature_key: str | None = None
|
||||
|
||||
# Vision preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] | None = None
|
||||
num_image_views: int | None = None
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Freezing options for VLM components
|
||||
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
|
||||
freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
|
||||
freeze_language_encoder: bool = False # Freeze VLM language encoder weights
|
||||
train_policy_transformer: bool = True # Allow policy transformer to train
|
||||
train_soft_prompts: bool = True # Allow soft prompts to train
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.99)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.0
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
# Soft-prompt LR settings (for optional warm-up)
|
||||
optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
|
||||
optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
if self.chunk_size <= 0:
|
||||
raise ValueError("`chunk_size` must be strictly positive.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||
)
|
||||
if self.num_image_views is not None and self.num_image_views <= 0:
|
||||
raise ValueError("`num_image_views` must be > 0 when specified.")
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
self._florence_config_obj: Florence2Config | None = None
|
||||
|
||||
def get_florence_config(self) -> Florence2Config:
|
||||
"""
|
||||
Build (and cache) the Florence2 transformer config that should back the VLM.
|
||||
"""
|
||||
if self._florence_config_obj is None:
|
||||
config_dict = dict(self.florence_config)
|
||||
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||
raise ValueError("vision_config is required")
|
||||
|
||||
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||
raise ValueError("text_config is required")
|
||||
self._florence_config_obj = Florence2Config(**config_dict)
|
||||
return self._florence_config_obj
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("XVLA requires at least one visual feature in the inputs.")
|
||||
if self.use_proprio and self.robot_state_feature is None:
|
||||
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
|
||||
if self.num_image_views is None:
|
||||
self.num_image_views = len(self.image_features) + self.empty_cameras
|
||||
else:
|
||||
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
|
||||
|
||||
if self.empty_cameras > 0:
|
||||
height, width = (480, 640)
|
||||
if self.resize_imgs_with_padding is not None:
|
||||
height, width = self.resize_imgs_with_padding
|
||||
for idx in range(self.empty_cameras):
|
||||
key = f"{OBS_IMAGES}.empty_camera_{idx}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, height, width),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> XVLAAdamWConfig:
|
||||
"""Return the XVLA-specific optimizer with differential learning rates.
|
||||
|
||||
This optimizer applies:
|
||||
- 1/10 LR for VLM parameters (stable optimization)
|
||||
- Full LR for transformer/action head
|
||||
- Configurable LR for soft-prompts (with optional warm-up)
|
||||
"""
|
||||
return XVLAAdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
|
||||
soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
@@ -1,548 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .soft_transformer import SoftPromptedTransformer
|
||||
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: XVLAConfig,
|
||||
florence_config: Florence2Config,
|
||||
proprio_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.chunk_size: int = config.chunk_size
|
||||
self.use_proprio: bool = config.use_proprio
|
||||
|
||||
# Build action space with auto-detection for "auto" mode
|
||||
if config.action_mode.lower() == "auto":
|
||||
# Auto-detect real action dim from config.action_feature
|
||||
real_dim = (
|
||||
config.action_feature.shape[-1]
|
||||
if config.action_feature is not None
|
||||
else config.max_action_dim
|
||||
)
|
||||
self.action_space = build_action_space(
|
||||
config.action_mode.lower(),
|
||||
real_dim=real_dim,
|
||||
max_dim=config.max_action_dim,
|
||||
)
|
||||
else:
|
||||
self.action_space = build_action_space(config.action_mode.lower())
|
||||
|
||||
self.dim_action = self.action_space.dim_action
|
||||
self.dim_proprio = proprio_dim
|
||||
|
||||
self.vlm = Florence2ForConditionalGeneration(florence_config)
|
||||
if hasattr(self.vlm, "language_model"):
|
||||
lm = self.vlm.language_model
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
|
||||
del lm.model.decoder
|
||||
if hasattr(lm, "lm_head"):
|
||||
del lm.lm_head
|
||||
|
||||
projection_dim = getattr(self.vlm.config, "projection_dim", None)
|
||||
if projection_dim is None:
|
||||
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||
|
||||
self.transformer = SoftPromptedTransformer(
|
||||
hidden_size=config.hidden_size,
|
||||
multi_modal_input_size=projection_dim,
|
||||
depth=config.depth,
|
||||
num_heads=config.num_heads,
|
||||
mlp_ratio=config.mlp_ratio,
|
||||
num_domains=config.num_domains,
|
||||
dim_action=self.dim_action,
|
||||
dim_propio=self.dim_proprio,
|
||||
len_soft_prompts=config.len_soft_prompts,
|
||||
dim_time=config.dim_time,
|
||||
max_len_seq=config.max_len_seq,
|
||||
use_hetero_proj=config.use_hetero_proj,
|
||||
)
|
||||
|
||||
# Apply freezing based on config
|
||||
self._apply_freezing()
|
||||
|
||||
# Apply dtype casting based on config
|
||||
self._apply_dtype()
|
||||
|
||||
def _get_target_dtype(self) -> torch.dtype:
|
||||
"""Get the target dtype based on config."""
|
||||
if self.config.dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
def _apply_dtype(self) -> None:
|
||||
"""
|
||||
Apply dtype casting to model components based on config.
|
||||
"""
|
||||
target_dtype = self._get_target_dtype()
|
||||
self.to(dtype=target_dtype)
|
||||
|
||||
def _apply_freezing(self) -> None:
|
||||
"""
|
||||
Freeze VLM vision and language encoders based on config options.
|
||||
Keep only policy transformer and soft prompts trainable.
|
||||
"""
|
||||
# Freeze vision encoder
|
||||
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
|
||||
for param in self.vlm.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze language encoder
|
||||
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
|
||||
lm = self.vlm.language_model
|
||||
# Freeze encoder
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
|
||||
for param in lm.model.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
# Freeze shared embeddings
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
|
||||
for param in lm.model.shared.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze or unfreeze policy transformer
|
||||
if not self.config.train_policy_transformer:
|
||||
for name, param in self.transformer.named_parameters():
|
||||
if "soft_prompts" not in name:
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze or unfreeze soft prompts
|
||||
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
|
||||
for param in self.transformer.soft_prompt_hub.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward_vlm(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Encode text and multi-view images via Florence2 encoder.
|
||||
"""
|
||||
batch_size, num_views = pixel_values.shape[:2]
|
||||
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||
flat_images = pixel_values.flatten(0, 1)
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
|
||||
valid_images = flat_images[flat_mask]
|
||||
valid_feats = self.vlm._encode_image(valid_images)
|
||||
tokens_per_view, hidden_dim = valid_feats.shape[1:]
|
||||
|
||||
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
|
||||
image_features[flat_mask] = valid_feats
|
||||
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
|
||||
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
||||
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
|
||||
image_features[:, 0],
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
enc_out = self.vlm.language_model.model.encoder(
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=merged_embeds,
|
||||
)[0]
|
||||
|
||||
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
|
||||
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Forward pass for the XVLA model.
|
||||
"""
|
||||
target_dtype = self._get_target_dtype()
|
||||
image_input = image_input.to(dtype=target_dtype)
|
||||
proprio = proprio.to(dtype=target_dtype)
|
||||
action = action.to(dtype=target_dtype)
|
||||
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
t = (
|
||||
torch.rand(1, device=input_ids.device, dtype=target_dtype)
|
||||
+ torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
|
||||
) % (1 - 1e-5)
|
||||
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
|
||||
pred_action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=action_noisy_m,
|
||||
t=t,
|
||||
proprio=proprio_m,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.compute_loss(pred_action, action)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
steps: int,
|
||||
) -> torch.Tensor:
|
||||
self.eval()
|
||||
|
||||
target_dtype = self._get_target_dtype()
|
||||
image_input = image_input.to(dtype=target_dtype)
|
||||
proprio = proprio.to(dtype=target_dtype)
|
||||
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
action_dim = self.dim_action
|
||||
|
||||
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
|
||||
action = torch.zeros_like(x1)
|
||||
|
||||
steps = max(1, int(steps))
|
||||
for i in range(steps, 0, -1):
|
||||
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
|
||||
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||
action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=x_t_m,
|
||||
proprio=proprio_m,
|
||||
t=t,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.postprocess(action)
|
||||
|
||||
|
||||
class XVLAPolicy(PreTrainedPolicy):
|
||||
"""LeRobot-compliant wrapper built around the XVLA model."""
|
||||
|
||||
config_class = XVLAConfig
|
||||
name = "xvla"
|
||||
|
||||
def __init__(self, config: XVLAConfig, **kwargs):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
florence_config = config.get_florence_config()
|
||||
proprio_dim = config.max_state_dim if config.use_proprio else 0
|
||||
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
"""Return trainable named parameters for optimization.
|
||||
|
||||
Returns a dict of name -> param for all trainable parameters.
|
||||
This enables the xvla-adamw optimizer to apply differential learning rates
|
||||
based on parameter names (e.g., 1/10 LR for VLM components).
|
||||
"""
|
||||
return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
if not self.config.use_proprio or OBS_STATE not in batch:
|
||||
return torch.zeros(batch_size, 0, device=device)
|
||||
state = batch[OBS_STATE]
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
return pad_vector(state, self.model.dim_proprio)
|
||||
|
||||
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
"All image features are missing from the batch. "
|
||||
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
|
||||
)
|
||||
|
||||
images = []
|
||||
masks = []
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
|
||||
images.append(img)
|
||||
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
|
||||
|
||||
stacked_imgs = torch.stack(images, dim=1)
|
||||
stacked_masks = torch.stack(masks, dim=1)
|
||||
|
||||
total_views = self.config.num_image_views or stacked_imgs.size(1)
|
||||
total_views = max(total_views, stacked_imgs.size(1))
|
||||
num_pad = total_views - stacked_imgs.size(1)
|
||||
if num_pad > 0:
|
||||
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
|
||||
pad_imgs = stacked_imgs.new_zeros(pad_shape)
|
||||
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
|
||||
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
|
||||
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
|
||||
|
||||
return stacked_imgs, stacked_masks
|
||||
|
||||
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
candidate = None
|
||||
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
|
||||
candidate = batch[self.config.domain_feature_key]
|
||||
elif "domain_id" in batch:
|
||||
candidate = batch["domain_id"]
|
||||
|
||||
if candidate is None:
|
||||
return torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
if not isinstance(candidate, torch.Tensor):
|
||||
candidate = torch.as_tensor(candidate, device=device)
|
||||
else:
|
||||
candidate = candidate.to(device=device)
|
||||
|
||||
if candidate.ndim == 0:
|
||||
candidate = candidate.expand(batch_size)
|
||||
if candidate.ndim > 1:
|
||||
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
|
||||
if candidate.shape[0] != batch_size:
|
||||
candidate = candidate.expand(batch_size)
|
||||
return candidate.to(dtype=torch.long)
|
||||
|
||||
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
if ACTION not in batch:
|
||||
raise ValueError("Batch is missing action targets required for training.")
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
|
||||
if actions.shape[-1] != self.model.dim_action:
|
||||
actions = pad_vector(actions, self.model.dim_action)
|
||||
return actions
|
||||
|
||||
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
input_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
batch_size = input_ids.shape[0]
|
||||
images, image_mask = self._prepare_images(batch)
|
||||
domain_id = self._get_domain_id(batch, batch_size, images.device)
|
||||
proprio = self._prepare_state(batch, batch_size, images.device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"image_input": images,
|
||||
"image_mask": image_mask,
|
||||
"domain_id": domain_id,
|
||||
"proprio": proprio,
|
||||
}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
targets = self._prepare_action_targets(batch)
|
||||
losses = self.model(action=targets, **inputs)
|
||||
total_loss = sum(losses.values())
|
||||
|
||||
log_dict = {k: v.detach().item() for k, v in losses.items()}
|
||||
log_dict["loss"] = total_loss.detach().item()
|
||||
return total_loss, log_dict
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
return self._get_action_chunk(batch)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self._get_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads XVLA model weights with:
|
||||
- automatic prefix 'model.' added to all keys
|
||||
- skip list for layers that should remain randomly initialized
|
||||
"""
|
||||
import safetensors.torch
|
||||
|
||||
# step 1: load config
|
||||
# TODO: jadechoghari, fix this
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
# step 2: locate model.safetensors
|
||||
if os.path.isdir(model_id):
|
||||
logging.info("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, "model.safetensors")
|
||||
else:
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="model.safetensors",
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||
|
||||
logging.info(f"Loading checkpoint from {model_file}")
|
||||
# step 3: load state dict
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||
if encoder_key in state_dict:
|
||||
state_dict[shared_key] = state_dict[encoder_key]
|
||||
# or deepcopy
|
||||
# step 4: load into instance
|
||||
instance.load_state_dict(state_dict, strict=True)
|
||||
logging.info("Loaded XVLA checkpoint")
|
||||
# step 5: finalize
|
||||
# Reapply dtype after loading state dict
|
||||
instance.model._apply_dtype()
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
|
||||
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
|
||||
|
||||
current_height, current_width = img.shape[2:]
|
||||
if current_height == height and current_width == width:
|
||||
return img
|
||||
|
||||
ratio = max(current_width / width, current_height / height)
|
||||
resized_height = int(current_height / ratio)
|
||||
resized_width = int(current_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, height - resized_height)
|
||||
pad_width = max(0, width - resized_width)
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
if new_dim == 0:
|
||||
shape = list(vector.shape)
|
||||
shape[-1] = 0
|
||||
return vector.new_zeros(*shape)
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = vector.new_zeros(*shape)
|
||||
length = min(current_dim, new_dim)
|
||||
new_vector[..., :length] = vector[..., :length]
|
||||
return new_vector
|
||||
|
||||
|
||||
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
|
||||
current_len = tensor.size(dim)
|
||||
if current_len == target_len:
|
||||
return tensor
|
||||
if current_len > target_len:
|
||||
slices = [slice(None)] * tensor.dim()
|
||||
slices[dim] = slice(0, target_len)
|
||||
return tensor[tuple(slices)]
|
||||
pad_shape = list(tensor.shape)
|
||||
pad_shape[dim] = target_len - current_len
|
||||
pad_tensor = tensor.new_zeros(pad_shape)
|
||||
return torch.cat([tensor, pad_tensor], dim=dim)
|
||||