diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 7423495de..9f602de30 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -12,57 +12,83 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "\U0001F41B Bug Report" -description: Submit a bug report to help us improve LeRobot +name: "πŸš€ Issue / Bug / Request" +description: Report a bug, suggest an improvement, or ask a technical question. body: - type: markdown attributes: value: | - Thanks for taking the time to submit a bug report! πŸ› - If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb). + ### Thanks for contributing to LeRobot! πŸ™Œ + Please choose the most relevant sections below. If this is a general "how-to" question, consider our [Discord](https://discord.gg/s3KuuzsPFb) for faster community support. + + - type: dropdown + id: issue-type + attributes: + label: Ticket Type + description: What kind of ticket are you opening? + options: + - "πŸ› Bug Report (Something isn't working)" + - "πŸ’‘ Feature Request / Improvement" + - "❓ Technical Question" + - "🧹 Maintenance / Documentation" + validations: + required: true - type: textarea id: system-info attributes: - label: System Info - description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below. + label: Environment & System Info + description: | + For bugs or technical questions, please run `lerobot-info` and paste the output. + (Optional for feature requests). render: Shell - placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration + placeholder: lerobot version, OS, python version, etc. + + - type: textarea + id: description validations: required: true + attributes: + label: Description + description: | + Provide a clear summary of the issue or your proposal. + - **Bugs:** What is happening? + - **Features:** What is the goal/use case? + - **Questions:** What are you trying to achieve? + placeholder: | + A clear and concise description of the issue or suggestion. + + - type: textarea + id: context-repro + attributes: + label: Context & Reproduction + description: | + Provide a code snippet, steps to reproduce a bug, or technical details about your proposal. + Please use code blocks for scripts and CLI commands. + placeholder: | + Steps to reproduce / Usage example: + 1. + 2. + 3. + + - type: textarea + id: logs + attributes: + label: Relevant logs or stack trace + description: If applicable, paste relevant error logs here. + render: Shell - type: checkboxes - id: information-scripts-examples + id: extras attributes: - label: Information - description: 'The problem arises when using:' + label: Checklist options: - - label: "One of the scripts in the examples/ folder of LeRobot" - - label: "My own task or dataset (give details below)" + - label: I have searched existing tickets to ensure this isn't a duplicate. + - label: I am using the latest version of the `main` branch. + - label: I have verified this is not an environment-specific problem. - type: textarea - id: reproduction - validations: - required: true + id: workaround attributes: - label: Reproduction - description: | - If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. - Sharing error messages or stack traces could be useful as well! - Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting - Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting. - - placeholder: | - Steps to reproduce the behavior: - - 1. - 2. - 3. - - - type: textarea - id: expected-behavior - validations: - required: true - attributes: - label: Expected behavior - description: "A clear and concise description of what you would expect to happen." + label: Additional Info / Workarounds + description: Anything else we should know? If you have a workaround, please share it! diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d37b1a92f..ec5ac4372 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,41 +1,54 @@ -## What this does +## Title -Explain what this PR does. Feel free to tag your PR with the appropriate label(s). +Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions. -Examples: -| Title | Label | -|----------------------|-----------------| -| Fixes #[issue] | (πŸ› Bug) | -| Adds new dataset | (πŸ—ƒοΈ Dataset) | -| Optimizes something | (⚑️ Performance) | +## Type / Scope -## How it was tested +- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore) +- **Scope**: (optional β€” name of module or package affected) -Explain/show how you tested your changes. +## Summary / Motivation -Examples: +- One-paragraph description of what changes and why. +- Why this change is needed and any trade-offs or design notes. -- Added `test_something` in `tests/test_stuff.py`. -- Added `new_feature` and checked that training converges with policy X on dataset/environment Y. -- Optimized `some_function`, it now runs X times faster than previously. +## Related issues -## How to checkout & try? (for the reviewer) +- Fixes / Closes: # (if any) +- Related: # (if any) -Provide a simple way for the reviewer to try out your changes. +## What changed -Examples: +- Short, concrete bullets of the modifications (files/behaviour). +- Short note if this introduces breaking changes and migration steps. -```bash -pytest -sx tests/test_stuff.py::test_something -``` +## How was this tested -```bash -lerobot-train --some.option=true -``` +- Tests added: list new tests or test files. +- Manual checks / dataset runs performed. -## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR +## How to run locally (reviewer) -**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag -members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people. +- Run the relevant tests: -**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr). + ```bash + pytest -q tests/ -k + ``` + +- Run a quick example or CLI (if applicable): + + ```bash + lerobot-train --some.option=true + ``` + +## Checklist (required before merge) + +- [ ] Linting/formatting run (`pre-commit run -a`) +- [ ] All tests pass locally (`pytest`) +- [ ] Documentation updated +- [ ] CI is green + +## Reviewer notes + +- Anything the reviewer should focus on (performance, edge-cases, specific files) or general notes. +- Anyone in the community is free to review the PR. diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 000000000..d3c5cc622 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,69 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CI: + - changed-files: + - any-glob-to-any-file: + - '.github/**' + - 'docker/**' + +github_actions: + - changed-files: + - any-glob-to-any-file: '.github/**' + +documentation: + - changed-files: + - any-glob-to-any-file: + - '**/*.md' + - '**/*.mdx' + - 'docs/**' + +examples: + - changed-files: + - any-glob-to-any-file: 'examples/**' + +tests: + - changed-files: + - any-glob-to-any-file: 'tests/**' + +sensors: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/cameras/**' + +configuration: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/configs/**' + +dataset: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/datasets/**' + +evaluation: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/envs/**' + +robots: + - changed-files: + - any-glob-to-any-file: + - 'src/lerobot/teleoperators/**' + - 'src/lerobot/robots/**' + - 'src/lerobot/motors/**' + +policies: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/policies/**' + +processor: + - changed-files: + - any-glob-to-any-file: 'src/lerobot/processor/**' diff --git a/.github/workflows/documentation-upload-pr.yml b/.github/workflows/documentation-upload-pr.yml index 22ba11cbb..6ee2a5caa 100644 --- a/.github/workflows/documentation-upload-pr.yml +++ b/.github/workflows/documentation-upload-pr.yml @@ -31,7 +31,8 @@ jobs: name: Upload Preview and Comment if: > github.event.workflow_run.event == 'pull_request' && - github.event.workflow_run.conclusion == 'success' + github.event.workflow_run.conclusion == 'success' && + github.repository == 'huggingface/lerobot' uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main with: package_name: lerobot diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 96005af3f..3007578fc 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -42,7 +42,9 @@ jobs: # This job builds and deploys the official documentation. build_main_docs: name: Build Main Docs - if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + if: > + (github.event_name == 'push' || github.event_name == 'workflow_dispatch') && + github.repository == 'huggingface/lerobot' permissions: contents: read uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main @@ -58,7 +60,7 @@ jobs: # The result of this job triggers the 'Upload PR Documentation' workflow. build_pr_docs: name: Build PR Docs - if: github.event_name == 'pull_request' + if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot' permissions: contents: read pull-requests: write diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index a39773b4e..ffd3195c2 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -45,7 +45,6 @@ permissions: env: UV_VERSION: "0.8.0" PYTHON_VERSION: "3.10" - DOCKER_IMAGE_NAME: huggingface/lerobot-gpu # Ensures that only the latest commit for a PR or branch is built, canceling older runs. concurrency: diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 0dba5e1db..ad222b04f 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -85,7 +85,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional + run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv --maxfail=10 diff --git a/.github/workflows/issue_labeler.yml b/.github/workflows/issue_labeler.yml new file mode 100644 index 000000000..27ca2b5f9 --- /dev/null +++ b/.github/workflows/issue_labeler.yml @@ -0,0 +1,89 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow automatically labels issues based on their content. +name: Issue Labeler +on: + # Trigger on new issues and edits to existing issues + issues: + types: [opened, edited] + +permissions: + contents: read + issues: write + +jobs: + label-issue: + name: Auto Label Issue + runs-on: ubuntu-latest + if: github.repository == 'huggingface/lerobot' + steps: + - uses: actions/github-script@v8 + with: + script: | + // Setup Input Text + const body = (context.payload.issue.body || ''); + const title = (context.payload.issue.title || ''); + const cleanBody = body.replace(/```[\s\S]*?```/g, ''); + const text = `${title}\n${cleanBody}`.toLowerCase(); + const labelsToAdd = new Set(); + const matches = (re) => re.test(text); + + // Keyword Heuristics + + // Domain Specific + if (matches(/\b(bug|error|issue|fault|crash|exception)\b/i)) labelsToAdd.add('bug'); + if (matches(/\b(feature|enhancement|improvement|support|implement|proposal)\b/i)) labelsToAdd.add('enhancement'); + if (matches(/\b(question|help|how to||clarify|explain|unclear)\b/i)) labelsToAdd.add('question'); + if (matches(/\b(maintenance|documentation|docs|readme|tutorial|guide|wiki)\b/i)) labelsToAdd.add('documentation'); + if (matches(/\b(example|script|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples'); + if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset'); + if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation'); + if (matches(/\b(train|training|loss|optimizer|backward|gradient|wandb|sac)\b/i)) labelsToAdd.add('training'); + if (matches(/\b(rerun|plot|video|render|visualiz|gif)/i)) labelsToAdd.add('visualization'); + if (matches(/\b(camera|realsense|lidar|depth|sensor|imu|microphone|rgbd)\b/i)) labelsToAdd.add('sensors'); + if (matches(/\b(aloha|koch|so-100|so100|mobile|teleop|manipulator|robots?)\b/i)) labelsToAdd.add('robots'); + if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators'); + if (matches(/\b(policy|policies|p0licy)\b/i)) labelsToAdd.add('policies'); + if (matches(/\b(processors?|pipeline)\b/i)) labelsToAdd.add('processor'); + if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmark)\b/i)) labelsToAdd.add('evaluation'); + + // Infrastructure & Code Quality + if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests'); + if (matches(/\b(ci|github actions|workflow|gha|actions?|pipeline)\b/i)) { + labelsToAdd.add('CI'); + labelsToAdd.add('github_actions'); + } + if (matches(/\b(perf|latency|throughput|fps|speed|performance)\b/i)) labelsToAdd.add('performance'); + if (matches(/\b(dependency|requirements|pip|conda|install error|importerror|package not found)\b/i)) labelsToAdd.add('dependencies'); + if (matches(/\b(python|pyproject|requirements(\.txt)?|pip install|typing error)\b/i)) labelsToAdd.add('python'); + + // Documentation & Meta + if (matches(/\b(doc|documentation|docs|readme|typo|how to)\b/i)) labelsToAdd.add('documentation'); + if (matches(/\b(refactor|cleanup|restructure|rename|modernize code)\b/i)) labelsToAdd.add('refactor'); + if (matches(/\b(release|changelog|version bump|cut a release|tag v)\b/i)) labelsToAdd.add('release'); + if (matches(/\b(breaking change|major change)\b/i)) labelsToAdd.add('breaking change'); + + // Apply Labels + const labels = Array.from(labelsToAdd).filter(Boolean); + + if (labels.length > 0) { + console.log(`Adding labels: ${labels.join(', ')}`); + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels, + }); + } diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index be8b5c094..94d5cc9f2 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -43,6 +43,7 @@ jobs: name: Build CPU Docker for Nightly runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }} steps: @@ -77,6 +78,7 @@ jobs: name: Build GPU Docker for Nightly runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }} steps: diff --git a/.github/workflows/pr_labeler.yml b/.github/workflows/pr_labeler.yml new file mode 100644 index 000000000..177c20959 --- /dev/null +++ b/.github/workflows/pr_labeler.yml @@ -0,0 +1,39 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow labels pull requests based on the files that were changed. +name: Pull Request Labeler + +on: + # Allows labeling pull requests when they are opened or updated + # zizmor: ignore[dangerous-triggers] Needed to label PRs from forks + pull_request_target: + branches: + - main + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + pull-requests: write + +jobs: + triage: + name: Label PR + runs-on: ubuntu-latest + if: github.repository == 'huggingface/lerobot' && !github.event.pull_request.draft + steps: + - uses: actions/labeler@v6 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + sync-labels: true # Removes labels if files are removed from the PR diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4891707ac..7b159dd17 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,6 +29,7 @@ jobs: build-and-publish: name: Build and publish Python distributions runs-on: ubuntu-latest + if: github.repository == 'huggingface/lerobot' outputs: version: ${{ steps.extract_info.outputs.tag_version }} permissions: diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 06fc69fc4..4dc119b5e 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -45,6 +45,7 @@ jobs: stale: name: Close Stale Issues and PRs runs-on: ubuntu-latest + if: github.repository == 'huggingface/lerobot' permissions: actions: write contents: write # only for delete-branch option diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 92271ba8e..95562d0dd 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -43,6 +43,7 @@ jobs: full-tests: name: Full Unbound Tests runs-on: ubuntu-latest + if: github.repository == 'huggingface/lerobot' env: MUJOCO_GL: egl HF_HOME: /mnt/cache/.cache/huggingface @@ -77,7 +78,7 @@ jobs: echo "Dependencies unbound:" && cat pyproject.toml - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional + run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 896a0c10b..bfa3340d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -87,7 +87,7 @@ repos: # TODO(Steven): Uncomment when ready to use ##### Static Analysis & Typing ##### - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.18.2 + rev: v1.19.1 hooks: - id: mypy args: [--config-file=pyproject.toml] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c0fdac843..305ffa276 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -52,7 +52,7 @@ decisions when appropriate. This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. -Examples of representing our community include using an official email address, +Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. @@ -60,7 +60,7 @@ representative at an online or offline event. Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at -[feedback@huggingface.co](mailto:feedback@huggingface.co). +feedback@huggingface.co. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dcb5c03d8..abca0d821 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,323 +1,83 @@ -# How to contribute to πŸ€— LeRobot? +# How to contribute to πŸ€— LeRobot -Everyone is welcome to contribute, and we value everybody's contribution. Code -is thus not the only way to help the community. Answering questions, helping -others, reaching out and improving the documentations are immensely valuable to -the community. +Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable. -It also helps us if you spread the word: reference the library from blog posts -on the awesome projects it made possible, shout out on Twitter when it has -helped you, or simply ⭐️ the repo to say "thank you". +Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md). -Whichever way you choose to contribute, please be mindful to respect our -[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md). +## Ways to Contribute -## You can contribute in so many ways! +You can contribute in many ways: -Some of the ways you can contribute to πŸ€— LeRobot: +- **Fixing issues:** Resolve bugs or improve existing code. +- **New features:** Develop new features. +- **Extend:** Implement new models/policies, robots, or simulation environments and upload datasets to the Hugging Face Hub. +- **Documentation:** Improve examples, guides, and docstrings. +- **Feedback:** Submit tickets related to bugs or desired new features. -- Fixing outstanding issues with the existing code. -- Implementing new models, datasets or simulation environments. -- Contributing to the examples or to the documentation. -- Submitting issues related to bugs or desired new features. +If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw). -Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co). +## Development Setup -If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46) +To contribute code, you need to set up a development environment. -## Submitting a new issue or feature request +### 1. Fork and Clone -Do your best to follow these guidelines when submitting an issue or a feature -request. It will make it easier for us to come back to you quickly and with good -feedback. - -### Did you find a bug? - -The πŸ€— LeRobot library is robust and reliable thanks to the users who notify us of -the problems they encounter. So thank you for reporting an issue. - -First, we would really appreciate it if you could **make sure the bug was not -already reported** (use the search bar on Github under Issues). - -Did not find it? :( So we can act quickly on it, please follow these steps: - -- Include your **OS type and version**, the versions of **Python** and **PyTorch**. -- A short, self-contained, code snippet that allows us to reproduce the bug in - less than 30s. -- The full traceback if an exception is raised. -- Attach any other additional information, like screenshots, you think may help. - -### Do you want a new feature? - -A good feature request addresses the following points: - -1. Motivation first: - -- Is it related to a problem/frustration with the library? If so, please explain - why. Providing a code snippet that demonstrates the problem is best. -- Is it related to something you would need for a project? We'd love to hear - about it! -- Is it something you worked on and think could benefit the community? - Awesome! Tell us what problem it solved for you. - -2. Write a _paragraph_ describing the feature. -3. Provide a **code snippet** that demonstrates its future use. -4. In case this is related to a paper, please attach a link. -5. Attach any additional information (drawings, screenshots, etc.) you think may help. - -If your issue is well written we're already 80% of the way there by the time you -post it. - -## Adding new policies, datasets or environments - -Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/), -environments ([aloha](https://github.com/huggingface/gym-aloha), -[pusht](https://github.com/huggingface/gym-pusht)) -and follow the same api design. - -When implementing a new dataset loadable with LeRobotDataset follow these steps: - -- Update `available_datasets_per_env` in `lerobot/__init__.py` - -When implementing a new environment (e.g. `gym_aloha`), follow these steps: - -- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` - -When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: - -- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py` -- Set the required `name` class attribute. -- Update variables in `tests/test_available.py` by importing your new Policy class - -## Submitting a pull request (PR) - -Before writing code, we strongly advise you to search through the existing PRs or -issues to make sure that nobody is already working on the same thing. If you are -unsure, it is always a good idea to open an issue to get some feedback. - -You will need basic `git` proficiency to be able to contribute to -πŸ€— LeRobot. `git` is not the easiest tool to use but it has the greatest -manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro -Git](https://git-scm.com/book/en/v2) is a very good reference. - -Follow these steps to start contributing: - -1. Fork the [repository](https://github.com/huggingface/lerobot) by - clicking on the 'Fork' button on the repository's page. This creates a copy of the code - under your GitHub user account. - -2. Clone your fork to your local disk, and add the base repository as a remote. The following command - assumes you have your public SSH key uploaded to GitHub. See the following guide for more - [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository). - - ```bash - git clone git@github.com:/lerobot.git - cd lerobot - git remote add upstream https://github.com/huggingface/lerobot.git - ``` - -3. Create a new branch to hold your development changes, and do this for every new PR you work on. - - Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)): - - ```bash - git checkout main - git fetch upstream - git rebase upstream/main - ``` - - Once your `main` branch is synchronized, create a new branch from it: - - ```bash - git checkout -b a-descriptive-name-for-my-changes - ``` - - 🚨 **Do not** work on the `main` branch. - -4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies. - Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already. - - Set up a development environment with conda: - - ```bash - conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev - ``` - - If you're using `uv`, it can manage python versions so you can instead do: - - ```bash - uv venv --python 3.10 && source .venv/bin/activate - ``` - - To develop on πŸ€— LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library: - - using `poetry` - - ```bash - poetry sync --extras "dev test" - ``` - - using `uv` - - ```bash - uv sync --extra dev --extra test - ``` - - You can also install the project with all its dependencies (including environments): - - using `poetry` - - ```bash - poetry sync --all-extras - ``` - - using `uv` - - ```bash - uv sync --all-extras - ``` - - > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing. - - Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies. - - The equivalent of `pip install some-package`, would just be: - - using `poetry` - - ```bash - poetry add some-package - ``` - - using `uv` - - ```bash - uv add some-package - ``` - - When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies. - using `poetry` - - ```bash - poetry lock - ``` - - using `uv` - - ```bash - uv lock - ``` - -5. Develop the features on your branch. - - As you work on the features, you should make sure that the test suite - passes. You should run the tests impacted by your changes like this (see - below an explanation regarding the environment variable): - - ```bash - pytest tests/.py - ``` - -6. Follow our style. - - `lerobot` relies on `ruff` to format its source code - consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks - automatically as Git commit hooks. - - Install `pre-commit` hooks: - - ```bash - pre-commit install - ``` - - You can run these hooks whenever you need on staged files with: - - ```bash - pre-commit - ``` - - Once you're happy with your changes, add changed files using `git add` and - make a commit with `git commit` to record your changes locally: - - ```bash - git add modified_file.py - git commit - ``` - - Note, if you already committed some changes that have a wrong formatting, you can use: - - ```bash - pre-commit run --all-files - ``` - - Please write [good commit messages](https://chris.beams.io/posts/git-commit/). - - It is a good idea to sync your copy of the code with the original - repository regularly. This way you can quickly account for changes: - - ```bash - git fetch upstream - git rebase upstream/main - ``` - - Push the changes to your account using: - - ```bash - git push -u origin a-descriptive-name-for-my-changes - ``` - -7. Once you are satisfied (**and the checklist below is happy too**), go to the - webpage of your fork on GitHub. Click on 'Pull request' to send your changes - to the project maintainers for review. - -8. It's ok if maintainers ask you for changes. It happens to core contributors - too! So everyone can see the changes in the Pull request, work in your local - branch and push the changes to your fork. They will automatically appear in - the pull request. - -### Checklist - -1. The title of your pull request should be a summary of its contribution; -2. If your pull request addresses an issue, please mention the issue number in - the pull request description to make sure they are linked (and people - consulting the issue know you are working on it); -3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark - the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate - it from PRs ready to be merged; -4. Make sure existing tests pass; - -### Tests - -An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests). - -Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). - -On Mac: +Fork the repository on GitHub, then clone your fork: ```bash -brew install git-lfs -git lfs install +git clone https://github.com//lerobot.git +cd lerobot +git remote add upstream https://github.com/huggingface/lerobot.git ``` -On Ubuntu: +### 2. Environment Installation + +Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source. + +## Running Tests & Quality Checks + +### Code Style (Pre-commit) + +Install `pre-commit` hooks to run checks automatically before you commit: ```bash -sudo apt-get install git-lfs -git lfs install +pre-commit install ``` -Pull artifacts if they're not in [tests/artifacts](tests/artifacts) +To run checks manually on all files: ```bash +pre-commit run --all-files +``` + +### Running Tests + +We use `pytest`. First, ensure you have test artifacts by installing **git-lfs**: + +```bash +git lfs install git lfs pull ``` -We use `pytest` in order to run the tests. From the root of the -repository, here's how to run tests with `pytest` for the library: +Run the full suite (this may require extras installed): ```bash -python -m pytest -sv ./tests +pytest -sv ./tests ``` -You can specify a smaller set of tests in order to test only the feature -you're working on. +Or run a specific test file during development: + +```bash +pytest -sv tests/test_specific_feature.py +``` + +## Submitting Issues & Pull Requests + +Use the templates for required fields and examples. + +- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml). +- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md). + +One member of the LeRobot team will then review your contribution. + +Thank you for contributing to LeRobot! diff --git a/README.md b/README.md index 964af4c1d..02652d1c9 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@

- LeRobot, Hugging Face Robotics Library -
-
+ LeRobot, Hugging Face Robotics Library

@@ -12,323 +10,130 @@ [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) -[![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb) - -
-

-

- Build Your Own HopeJR Robot!

-

+**LeRobot** aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry so that everyone can contribute to and benefit from shared datasets and pretrained models. -
- HopeJR robot +πŸ€— A hardware-agnostic, Python-native interface that standardizes control across diverse platforms, from low-cost arms (SO-100) to humanoids. -

Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!

-

Control it with exoskeletons and gloves for precise hand movements.

-

Perfect for advanced manipulation tasks! πŸ€–

+πŸ€— A standardized, scalable LeRobotDataset format (Parquet + MP4 or images) hosted on the Hugging Face Hub, enabling efficient storage, streaming and visualization of massive robotic datasets. -

- See the full HopeJR tutorial here.

-
+πŸ€— State-of-the-art policies that have been shown to transfer to the real-world ready for training and deployment. -
+πŸ€— Comprehensive support for the open-source ecosystem to democratize physical AI. -

-

- Build Your Own SO-101 Robot!

-

+## Quick Start -
- - - - - -
SO-101 follower armSO-101 leader arm
- -

Meet the updated SO100, the SO-101 – Just €114 per arm!

-

Train it in minutes with a few simple moves on your laptop.

-

Then sit back and watch your creation act autonomously! 🀯

- -

- See the full SO-101 tutorial here.

- -

Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!

-

Check out the LeKiwi tutorial and bring your robot to life on wheels.

- - LeKiwi mobile robot -
- -
- -

-

LeRobot: State-of-the-art AI for real-world robotics

-

- ---- - -πŸ€— LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models. - -πŸ€— LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning. - -πŸ€— LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there. - -πŸ€— LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot) - -#### Examples of pretrained models on simulation environments - - - - - - - - - - - - -
ACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT env
ACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT env
- -## Installation - -LeRobot works with Python 3.10+ and PyTorch 2.2+. - -### Environment Setup - -Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/): - -```bash -conda create -y -n lerobot python=3.10 -conda activate lerobot -``` - -When using `conda`, install `ffmpeg` in your environment: - -```bash -conda install ffmpeg -c conda-forge -``` - -> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can: -> -> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: -> -> ```bash -> conda install ffmpeg=7.1.1 -c conda-forge -> ``` -> -> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. - -### Install LeRobot πŸ€— - -#### From Source - -First, clone the repository and navigate into the directory: - -```bash -git clone https://github.com/huggingface/lerobot.git -cd lerobot -``` - -Then, install the library in editable mode. This is useful if you plan to contribute to the code. - -```bash -pip install -e . -``` - -> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: -> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) - -For simulations, πŸ€— LeRobot comes with gymnasium environments that can be installed as extras: - -- [aloha](https://github.com/huggingface/gym-aloha) -- [xarm](https://github.com/huggingface/gym-xarm) -- [pusht](https://github.com/huggingface/gym-pusht) - -For instance, to install πŸ€— LeRobot with aloha and pusht, use: - -```bash -pip install -e ".[aloha, pusht]" -``` - -### Installation from PyPI - -**Core Library:** -Install the base package with: +LeRobot can be installed directly from PyPI. ```bash pip install lerobot +lerobot-info ``` -_This installs only the default dependencies._ +> [!IMPORTANT] +> For detailed installation guide, please see the [Installation Documentation](https://huggingface.co/docs/lerobot/installation). -**Extra Features:** -To install additional functionality, use one of the following: +## Robots & Control + +
+ Reachy 2 Demo +
+ +LeRobot provides a unified `Robot` class interface that decouples control logic from hardware specifics. It supports a wide range of robots and teleoperation devices. + +```python +from lerobot.robots.myrobot import MyRobot + +# Connect to a robot +robot = MyRobot(config=...) +robot.connect() + +# Read observation and send action +obs = robot.get_observation() +action = model.select_action(obs) +robot.send_action(action) +``` + +**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1. + +While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot. + +For detailed hardware setup guides, see the [Hardware Documentation](https://huggingface.co/docs/lerobot/integrate_hardware). + +## LeRobot Dataset + +To solve the data fragmentation problem in robotics, we utilize the **LeRobotDataset** format. + +- **Structure:** Synchronized MP4 videos (or images) for vision and Parquet files for state/action data. +- **HF Hub Integration:** Explore thousands of robotics datasets on the [Hugging Face Hub](https://huggingface.co/lerobot). +- **Tools:** Seamlessly delete episodes, split by indices/fractions, add/remove features, and merge multiple datasets. + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load a dataset from the Hub +dataset = LeRobotDataset("lerobot/aloha_mobile_cabinet") + +# Access data (automatically handles video decoding) +episode_index=0 +print(f"{dataset[episode_index]['action'].shape=}\n") +``` + +Learn more about it in the [LeRobotDataset Documentation](https://huggingface.co/docs/lerobot/lerobot-dataset-v3) + +## SoTA Models + +LeRobot implements state-of-the-art policies in pure PyTorch, covering Imitation Learning, Reinforcement Learning, and Vision-Language-Action (VLA) models, with more coming soon. It also provides you with the tools to instrument and inspect your training process. + +

+ Gr00t Architecture +

+ +Training a policy is as simple as running a script configuration: ```bash -pip install 'lerobot[all]' # All available features -pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) -pip install 'lerobot[feetech]' # Feetech motor support +lerobot-train \ + --policy=act \ + --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -_Replace `[...]` with your desired features._ +| Category | Models | +| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | -**Available Tags:** -For a full list of optional dependencies, see: -https://pypi.org/project/lerobot/ +Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub -> [!NOTE] -> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`. -> -> This will be solved in the next patch release +For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). -### Weights & Biases +## Inference & Evaluation -To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with +Evaluate your policies in simulation or on real hardware using the unified evaluation script. LeRobot supports standard benchmarks like **LIBERO**, **MetaWorld** and more to come. ```bash -wandb login +# Evaluate a policy on the LIBERO benchmark +lerobot-eval \ + --policy.path=lerobot/pi0_libero_finetuned \ + --env.type=libero \ + --env.task=libero_object \ + --eval.n_episodes=10 ``` -(note: you will also need to enable WandB in the configuration. See below.) +Learn how to implement your own simulation environment or benchmark and distribute it from the HF Hub by following the [EnvHub Documentation](https://huggingface.co/docs/lerobot/envhub) -### Visualize datasets +## Resources -Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. - -You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: - -```bash -lerobot-dataset-viz \ - --repo-id lerobot/pusht \ - --episode-index 0 -``` - -or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`) - -```bash -lerobot-dataset-viz \ - --repo-id lerobot/pusht \ - --root ./my_local_data_dir \ - --mode local \ - --episode-index 0 -``` - -It will open `rerun.io` and display the camera streams, robot states and actions, like this: - -https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144 - -Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions. - -### The `LeRobotDataset` format - -A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model. - -A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`. - -Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor. - -Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects: - -``` -dataset attributes: - β”œ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example: - β”‚ β”œ observation.images.cam_high (VideoFrame): - β”‚ β”‚ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video} - β”‚ β”œ observation.state (list of float32): position of an arm joints (for instance) - β”‚ ... (more observations) - β”‚ β”œ action (list of float32): goal position of an arm joints (for instance) - β”‚ β”œ episode_index (int64): index of the episode for this sample - β”‚ β”œ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode - β”‚ β”œ timestamp (float32): timestamp in the episode - β”‚ β”œ next.done (bool): indicates the end of an episode ; True for the last frame in each episode - β”‚ β”” index (int64): general index in the whole dataset - β”œ meta: a LeRobotDatasetMetadata object containing: - β”‚ β”œ info: a dictionary of metadata on the dataset - β”‚ β”‚ β”œ codebase_version (str): this is to keep track of the codebase version the dataset was created with - β”‚ β”‚ β”œ fps (int): frame per second the dataset is recorded/synchronized to - β”‚ β”‚ β”œ features (dict): all features contained in the dataset with their shapes and types - β”‚ β”‚ β”œ total_episodes (int): total number of episodes in the dataset - β”‚ β”‚ β”œ total_frames (int): total number of frames in the dataset - β”‚ β”‚ β”œ robot_type (str): robot type used for recording - β”‚ β”‚ β”œ data_path (str): formattable string for the parquet files - β”‚ β”‚ β”” video_path (str): formattable string for the video files (if using videos) - β”‚ β”œ episodes: a DataFrame containing episode metadata with columns: - β”‚ β”‚ β”œ episode_index (int): index of the episode - β”‚ β”‚ β”œ tasks (list): list of tasks for this episode - β”‚ β”‚ β”œ length (int): number of frames in this episode - β”‚ β”‚ β”œ dataset_from_index (int): start index of this episode in the dataset - β”‚ β”‚ β”” dataset_to_index (int): end index of this episode in the dataset - β”‚ β”œ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance - β”‚ β”‚ β”œ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.} - β”‚ β”‚ β”” ... - β”‚ β”” tasks: a DataFrame containing task information with task names as index and task_index as values - β”œ root (Path): local directory where the dataset is stored - β”œ image_transforms (Callable): optional image transformations to apply to visual modalities - β”” delta_timestamps (dict): optional delta timestamps for temporal queries -``` - -A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely: - -- hf_dataset stored using Hugging Face datasets library serialization to parquet -- videos are stored in mp4 format to save space -- metadata are stored in plain json/jsonl files - -Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location. - -#### Reproduce state-of-the-art (SOTA) - -We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances. -You can reproduce their training by loading the config from their run. Simply running: - -```bash -lerobot-train --config_path=lerobot/diffusion_pusht -``` - -reproduces SOTA results for Diffusion Policy on the PushT task. - -## Contribute - -If you would like to contribute to πŸ€— LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). - -### Add a pretrained policy - -Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)). - -You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain: - -- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config). -- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format. -- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility. - -To upload these to the hub, run the following: - -```bash -huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model -``` - -See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy. - -### Acknowledgment - -- The LeRobot team πŸ€— for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). -- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). -- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). -- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). -- Thanks to Antonio Loquercio and Ashish Kumar for their early support. -- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). +- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. +- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community. +- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. +- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. ## Citation -If you want, you can cite this work with: +If you use LeRobot in your research, please cite: ```bibtex @misc{cadene2024lerobot, @@ -339,6 +144,14 @@ If you want, you can cite this work with: } ``` -## Star History +## Contribute -[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) +We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's nextβ€”thank you for your support! + +

+ SO101 Video +

+ +
+Built by the LeRobot team at Hugging Face with ❀️ +
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index aae7372fa..7766b3472 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -41,7 +41,13 @@ title: NVIDIA GR00T N1.5 - local: xvla title: X-VLA + - local: walloss + title: WALL-OSS title: "Policies" +- sections: + - local: sarm + title: SARM + title: "Reward Models" - sections: - local: async title: Use Async Inference diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 93a6bf72e..0bc1ca681 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -201,7 +201,8 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun -from lerobot.record import record_loop +from lerobot.scripts.lerobot_record import record_loop +from lerobot.processor import make_default_processors NUM_EPISODES = 5 FPS = 30 @@ -209,12 +210,19 @@ EPISODE_TIME_SEC = 60 RESET_TIME_SEC = 10 TASK_DESCRIPTION = "My task description" -# Create the robot and teleoperator configurations -camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +# Create robot configuration robot_config = SO100FollowerConfig( - port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config + id="my_awesome_follower_arm", + cameras={ + "front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error. + }, + port="/dev/tty.usbmodem58760434471", +) + +teleop_config = SO100LeaderConfig( + id="my_awesome_leader_arm", + port="/dev/tty.usbmodem585A0077581", ) -teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") # Initialize the robot and teleoperator robot = SO100Follower(robot_config) @@ -243,6 +251,9 @@ init_rerun(session_name="recording") robot.connect() teleop.connect() +# Create the required processors +teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + episode_idx = 0 while episode_idx < NUM_EPISODES and not events["stop_recording"]: log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") @@ -251,6 +262,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: robot=robot, events=events, fps=FPS, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, dataset=dataset, control_time_s=EPISODE_TIME_SEC, @@ -265,6 +279,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: robot=robot, events=events, fps=FPS, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx new file mode 100644 index 000000000..321097692 --- /dev/null +++ b/docs/source/sarm.mdx @@ -0,0 +1,586 @@ +# SARM: Stage-Aware Reward Modeling + +SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC). + +**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358) + +## Why Reward Models? + +Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0β†’1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement. + +## Overview + +SARM has following features: + +1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage +2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels +3. **Temporal proportions**: Computes dataset-level priors (Ξ±Μ…\_k) for each subtask to normalize progress across variable-length demonstrations + +SARM trains on a compact **stage+tau** target for each frame: + +- **stage**: integer stage index `k ∈ {0, ..., K-1}` +- **Ο„ (tau)**: within-stage progress `Ο„ ∈ [0, 1]` +- **target encoding**: `y = k + Ο„` (this is what the dataset processor produces) + +At inference time (and in downstream RA-BC), SARM converts the raw `k + Ο„` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `Ξ±Μ…_k` (stored in `meta/temporal_proportions_*.json`). + +This matches **Formula (2)** from the paper: + +``` +progress_t = P_{k-1} + Ξ±Μ…_k Γ— Ο„_t +``` + +Where: + +- `Ο„_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time +- `P_{k-1}` is cumulative prior (sum of previous subtask proportions) +- `Ξ±Μ…_k` is the temporal proportion for subtask k + +This ensures identical task states map to consistent progress values, even across demonstrations of different lengths. + +## Inputs and Targets (What the new code expects) + +SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which: + +- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features` +- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`) +- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + Ο„` +- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation) + +At minimum, each training sample needs: + +- `task` (string): task description +- `policy.image_key` images and `policy.state_key` states from the dataset + +--- + +## Annotation Modes + +You can choose from **3 annotation modes** that determine how progress labels are computed: + +| Mode | Annotations Required | Heads | Use Case | +| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ | +| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed | +| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages | +| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities | + +### Mode Details + + + + +**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration. + +- **Sparse head**: 1 stage ("task"), linear progress +- **Dense head**: Not used +- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available + +## Set Up Your Environment + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install SARM dependencies by running: + +```bash +pip install -e ".[sarm]" +``` + +Workflow: + +``` +1. Train SARM β†’ 2. Visualize predictions β†’ 3. (Optional) Train policy with RA-BC +``` + + + + +**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression. + +- **Sparse head**: 1 stage ("task"), linear progress (auto-generated) +- **Dense head**: Multiple fine-grained stages from VLM annotations +- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages + +Workflow: + +``` +1. Annotate (dense) β†’ 2. Verify β†’ 3. Train SARM β†’ 4. Visualize β†’ 5. (Optional) Train policy with RA-BC +``` + + + + +**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions. + +- **Sparse head**: High-level stages from VLM annotations +- **Dense head**: Fine-grained stages from VLM annotations +- **Best for**: Complex multi-stage tasks where both granularities are useful + +Workflow: + +``` +1. Annotate (sparse+dense) β†’ 2. Verify β†’ 3. Train SARM β†’ 4. Visualize β†’ 5. (Optional) Train policy with RA-BC +``` + + + + +--- + +## Step 1: Subtask Annotation + + + + +**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically. + + + + +Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated. + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --dense-only \ + --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \ + --video-key observation.images.base \ + --num-workers 4 \ + --push-to-hub +``` + +**What gets saved:** + +- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`) +- `meta/temporal_proportions_dense.json` - Dense temporal proportions +- Per-episode columns in `episodes/*.parquet`: + - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames` + - (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`) + + + + +Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM. + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \ + --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \ + --video-key observation.images.base \ + --num-workers 4 \ + --push-to-hub +``` + +**What gets saved:** + +- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions +- `meta/temporal_proportions_dense.json` - Dense temporal proportions +- Per-episode columns in `episodes/*.parquet`: + - `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames` + - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames` + - (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`) + + + + +### Annotation Arguments + +| Argument | Description | +| ---------------------- | ------------------------------------------------------------------------------- | +| `--repo-id` | HuggingFace dataset repository ID | +| `--sparse-subtasks` | Comma-separated list of high-level subtask names | +| `--dense-subtasks` | Comma-separated list of fine-grained subtask names | +| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) | +| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) | +| `--num-workers` | Number of parallel GPU workers (default: 1) | +| `--episodes` | Specific episode indices to annotate (default: all) | +| `--skip-existing` | Skip episodes that already have annotations | +| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) | +| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) | + +> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step. + +--- + +## Step 2: Verify Annotations + + + + +**No verification needed!** Skip this step. + + + + +Visualize annotations using the `--visualize-only` flag: + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --visualize-only \ + --visualize-type dense \ + --num-visualizations 5 \ + --video-key observation.images.base \ + --output-dir ./subtask_viz +``` + + + + +Visualize annotations using the `--visualize-only` flag: + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --visualize-only \ + --visualize-type both \ + --num-visualizations 5 \ + --video-key observation.images.base \ + --output-dir ./subtask_viz +``` + + + + +This generates visualizations showing video frames with subtask boundaries overlaid and timeline of subtasks. + +### Visualization Arguments + +| Argument | Description | +| ---------------------- | -------------------------------------------------------------- | +| `--visualize-only` | Only visualize existing annotations (no generation) | +| `--num-visualizations` | Number of episodes to visualize (default: 5) | +| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` | + +**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run. + +--- + +## Step 3: Train SARM + + + + +Train with **no annotations** - uses linear progress from 0 to 1: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=single_stage \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_single \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +Train with **dense annotations only** (sparse auto-generated): + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=dense_only \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_dense \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +Train with **both sparse and dense annotations**: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=dual \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_dual \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +### Multi-GPU Training + +Add `accelerate launch --multi_gpu --num_processes=4` to use multiple GPUs for training. + +### Training Arguments + +| Argument | Description | Default | +| -------------------------- | ----------------------------------------------------------------- | ------------------------ | +| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` | +| `--policy.image_key` | Camera key for images | `observation.images.top` | +| `--policy.state_key` | Key for joint states | `observation.state` | +| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` | +| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 β‰ˆ 1s) | `30` | + +--- + +## Step 4: Visualize Predictions + +Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file. + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode sparse \ + --output-dir ./sarm_viz +``` + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode dense \ + --output-dir ./sarm_viz +``` + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode both \ + --output-dir ./sarm_viz +``` + + + + +The visualization shows: + +- **Progress plot**: Predicted progress (and optional annotation-derived β€œGT” when available and `--stride 1`) +- **Stage probabilities**: Stacked area plot of predicted stage probabilities +- **Sample frames**: Key frames from the episode with progress/stage labels + +### Visualization Arguments + +| Argument | Description | +| ---------------------- | --------------------------------------------------------- | +| `--visualize-only` | Only visualize predictions (no RABC computation) | +| `--num-visualizations` | Number of episodes to visualize (default: 5) | +| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | +| `--stride` | Compute every N frames, interpolate the rest (default: 1) | + +--- + +## Step 5 (Optional): Train Policy with RA-BC + +Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps: + +1. **Precompute progress values** for all frames using the trained SARM model +2. **Train policy** with RA-BC weighting using the precomputed values + +### How RA-BC Works + +For each training sample, RA-BC computes the progress delta: + +``` +r_i = Ο†(o_{t+Ξ”}) - Ο†(o_t) +``` + +Where `Ο†` is the SARM progress prediction and `Ξ”` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted. + +The weighting follows **Equations 8-9** from the paper: + +- **Soft weight**: `wΜƒ_i = clip((r_i βˆ’ (ΞΌ βˆ’ 2Οƒ)) / (4Οƒ + Ξ΅), 0, 1)` +- **Final weight**: `w_i = πŸ™{r_i > ΞΊ} + πŸ™{0 ≀ r_i ≀ ΞΊ} Γ— wΜƒ_i` + +### Step 5a: Compute SARM Progress Values + +First, run the SARM model on all frames in your dataset to compute progress values: + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --head-mode sparse \ + --num-visualizations 5 \ + --push-to-hub +``` + +This script: + +- Processes all frames and computes progress values +- Saves progress values to a parquet file next to the dataset on disk (defaults to `/sarm_progress.parquet`) +- Generates visualizations of the first N episodes (default: 5) + +**Arguments:** + +| Argument | Description | Default | +| ---------------------- | -------------------------------------------------------------- | ---------- | +| `--reward-model-path` | Path to trained SARM model | (required) | +| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` | +| `--device` | Device for inference | `cuda` | +| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` | +| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` | + +**Output format** (`sarm_progress.parquet`): + +| Column | Description | +| ----------------- | ---------------------------------------------- | +| `index` | Global frame index in dataset | +| `episode_index` | Episode number | +| `frame_index` | Local frame index within episode | +| `progress_sparse` | Sparse head progress value [0, 1] | +| `progress_dense` | Dense head progress value [0, 1] (if computed) | + +### Step 5b: Train Policy with RA-BC + +Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=pi0 \ + --use_rabc=true \ + --rabc_head_mode=sparse \ + --rabc_kappa=0.01 \ + --output_dir=outputs/train/policy_rabc \ + --batch_size=32 \ + --steps=40000 +``` + +The training script automatically: + +- Loads the precomputed progress values from the parquet file +- Uses the policy's `chunk_size` to compute progress deltas (Ξ”) +- Computes sample weights based on progress improvement +- Applies weighted loss during training + +**RA-BC Arguments:** + +| Argument | Description | Default | +| ---------------------- | ---------------------------------------------------------- | ---------------------------------- | +| `--use_rabc` | Enable RA-BC sample weighting | `false` | +| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset | +| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | +| `--rabc_kappa` | Threshold ΞΊ for high-quality samples | `0.01` | + +### Tuning RA-BC Kappa + +The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively. + +**How the weighting works:** + +| Condition | Weight | +| ------------------- | ----------------------- | +| `delta > kappa` | 1.0 (hard threshold) | +| `0 ≀ delta ≀ kappa` | Soft weight from Eq. 8 | +| `delta < 0` | 0.0 (negative progress) | + +**Diagnosing kappa issues:** + +Monitor these WandB metrics during training: + +| Metric | Healthy Range | Problem Indicator | +| ------------------ | ------------- | ------------------------- | +| `rabc_mean_weight` | 0.3 - 0.8 | β‰ˆ 1.0 means kappa too low | +| `rabc_delta_mean` | > 0 | Should be positive | +| `rabc_delta_std` | > 0 | Variance in data quality | + +**If `rabc_mean_weight β‰ˆ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. + +**Setting kappa based on your data:** + +The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`: + +``` +# If delta_mean β‰ˆ 0.03 and delta_std β‰ˆ 0.02: +# Most deltas fall in range [0.01, 0.05] + +# Option 1: Set kappa = delta_mean (medium selectivity) +--rabc_kappa=0.03 + +# Option 2: Set kappa = delta_mean + delta_std (high selectivity) +--rabc_kappa=0.05 + +# Option 3: Set kappa = delta_mean + 2*delta_std (very selective) +--rabc_kappa=0.07 +``` + +**When RA-BC may not help:** + +If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter. + +### Multi-GPU Training with RA-BC + +```bash +accelerate launch \ + --multi_gpu \ + --num_processes=4 \ + src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=pi0 \ + --use_rabc=true \ + --rabc_kappa=0.01 \ + --output_dir=outputs/train/policy_rabc \ + --batch_size=32 \ + --steps=40000 +``` + +--- + +## Tips & Best Practices + +### Choosing a Mode + +- **Start with `single_stage`** for quick experiments - no annotation overhead +- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages +- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful + +### Annotation Quality + +1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center" +2. **Verify with visualization**: Always check a few episodes before training +3. **Consistent naming**: Use the same subtask names across all episodes + +### RA-BC + +1. **Train SARM first**: RA-BC quality depends entirely on SARM quality +2. **Monitor `rabc_mean_weight`**: If it's β‰ˆ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) + +--- + +## Citation + +```bibtex +@article{chen2025sarm, + title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation}, + author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp}, + journal={arXiv preprint arXiv:2509.25358}, + year={2025} +} +``` diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 8f91e7791..af06fd742 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -4,11 +4,12 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i ## About the Unitree G1 -We offer support for both 29 and 23 DOF G1. In this first PR we introduce: +We offer support for both 29 and 23 DOF G1. We introduce: - **`unitree g1` robot class, handling low level communication with the humanoid** - **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin - **GR00T locomotion policy** for bipedal walking and balance +- **MuJoCo simulation mode** for testing policies without the physical robot --- @@ -191,6 +192,10 @@ Press `Ctrl+C` to stop the policy. --- +## Extra: Running in Simulation Mode (MuJoCo) + +You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config. + ## Additional Resources - [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python) diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx index affca0ee5..29e16ea0a 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets: 3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` 4. **Add Features** - Add new features to a dataset 5. **Remove Features** - Remove features from a dataset +6. **Convert to Video** - Convert image-based datasets to video format for efficient storage The core implementation is in `lerobot.datasets.dataset_tools`. An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. ## Command-Line Tool: lerobot-edit-dataset -`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features. +`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format. Run `lerobot-edit-dataset --help` for more information on the configuration of each operation. @@ -86,9 +87,71 @@ lerobot-edit-dataset \ --operation.feature_names "['observation.images.top']" ``` +#### Convert to Video + +Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format. + +```bash +# Local-only: Save to a custom output directory (no hub push) +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_to_video \ + --operation.output_dir /path/to/output/pusht_video + +# Save with new repo_id (local storage) +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --new_repo_id lerobot/pusht_video \ + --operation.type convert_to_video + +# Convert and push to Hugging Face Hub +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --new_repo_id lerobot/pusht_video \ + --operation.type convert_to_video \ + --push_to_hub true + +# Convert with custom video codec and quality settings +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_to_video \ + --operation.output_dir outputs/pusht_video \ + --operation.vcodec libsvtav1 \ + --operation.pix_fmt yuv420p \ + --operation.g 2 \ + --operation.crf 30 + +# Convert only specific episodes +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_to_video \ + --operation.output_dir outputs/pusht_video \ + --operation.episode_indices "[0, 1, 2, 5, 10]" + +# Convert with multiple workers for parallel processing +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_to_video \ + --operation.output_dir outputs/pusht_video \ + --operation.num_workers 8 +``` + +**Parameters:** + +- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`) +- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`) +- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`) +- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2) +- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30) +- `fast_decode`: Fast decode tuning option (default: 0) +- `episode_indices`: List of specific episodes to convert (default: all episodes) +- `num_workers`: Number of parallel workers for processing (default: 4) + +**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved. + ### Push to Hub -Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: +Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: ```bash lerobot-edit-dataset \ @@ -96,7 +159,45 @@ lerobot-edit-dataset \ --new_repo_id lerobot/pusht_after_deletion \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" \ - --push_to_hub + --push_to_hub true ``` There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`. + +# Dataset Visualization + +## Online Visualization + +When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at: +https://huggingface.co/spaces/lerobot/visualize_dataset + +## Local Visualization + +You can also visualize episodes from a dataset locally using our command-line tool. + +**From the Hugging Face Hub:** + +```bash +lerobot-dataset-viz \ + --repo-id lerobot/pusht \ + --episode-index 0 +``` + +**From a local folder:** +Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`: + +```bash +lerobot-dataset-viz \ + --repo-id lerobot/pusht \ + --root ./my_local_data_dir \ + --mode local \ + --episode-index 0 +``` + +Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode. + +For advanced usageβ€”including visualizing datasets stored on a remote serverβ€”run: + +```bash +lerobot-dataset-viz --help +``` diff --git a/docs/source/walloss.mdx b/docs/source/walloss.mdx new file mode 100644 index 000000000..12e9b1fc7 --- /dev/null +++ b/docs/source/walloss.mdx @@ -0,0 +1,74 @@ +# WALL-OSS + +WALL-OSS is an open-source foundation model for embodied intelligence, proposed by the [XSquare Robot](https://x2robot.com/en/research/68bc2cde8497d7f238dde690) team in 2025. The LeRobot implementation is adapted from their open-source [WallX](https://github.com/X-Square-Robot/wall-x) repository. + +X Square Robot’s WALL-OSS is now integrated into Hugging Face’s LeRobot ecosystem. This is an exciting collaborative project between the LeRobot and X Square Robot teams. You can now post-train, evaluate, and deploy WALL-OSS directly through LeRobot. With this, we’re aiming to make it easier for the open-source robotics community to customize and deploy WALL-OSS foundation models. Read and explore WALL-OSS [paper](https://arxiv.org/pdf/2509.11766) and [code](https://github.com/X-Square-Robot/wall-x). + +## Model Overview + +The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabledβ€”one that understands not just how the world works, but how to act effectively within it. + +Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration β†’ Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include: + +- **Embodied perception–enhanced multimodal pretraining**: Large-scale training on unified vision–language–action data to strengthen spatial, causal, and manipulation understanding. +- **Unified Cross-Level Chain-of-Thought (Uni-CoT)**: A single differentiable framework that unifies high-level instruction reasoning, sub-task decomposition, and fine-grained action synthesis, forming a continuous chain from β€œunderstanding” to β€œexecution.” +- **Mixture-of-Experts (MoE) action heads**: Dynamically activating experts depending on the task phase and modeling actions in discrete or continuous space to maintain stable VLM priors. +- **Two-stage training paradigm**: + - **Inspiration stage**: Injecting discrete action priors to strengthen spatial understanding and semantic-action alignment. + - **Integration stage**: Using flow matching to achieve high-frequency continuous control. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install WallX dependencies by running: + + ```bash + pip install -e ".[wallx]" + ``` + +## Usage + +To use WallX in LeRobot, specify the policy type as: + +```python +policy.type=wall_x +``` + +## Training + +For training WallX, you can use the standard LeRobot training script with the appropriate configuration: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your_dataset \ + --policy.type=wall_x \ + --output_dir=./outputs/wallx_training \ + --job_name=wallx_training \ + --policy.repo_id=your_repo_id \ + --policy.pretrained_name_or_path=x-square-robot/wall-oss-flow \ + --policy.prediction_mode=diffusion \ + --policy.attn_implementation=eager \ + --steps=3000 \ + --policy.device=cuda \ + --batch_size=32 +``` + +### Training Arguments + +| Argument | Description | +| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--dataset.repo_id` | The Hugging Face Hub repository ID for your training dataset (e.g., `lerobot/aloha_sim_insertion_human`) | +| `--policy.type` | Specifies using the WallX policy architecture | +| `--output_dir` | Local directory where training checkpoints and logs will be saved | +| `--job_name` | A name identifier for this training run (used in logging/tracking) | +| `--policy.repo_id` | Your Hugging Face Hub repo ID where the trained model will be pushed | +| `--policy.pretrained_path` | Path to pretrained WallX weights to initialize from (the official WALL-OSS checkpoint) | +| `--policy.prediction_mode` | The action prediction strategy: `diffusion` or `fast` - `diffusion` uses iterative denoising for action generation, `fast` uses next token prediction instead | +| `--policy.attn_implementation` | Attention implementation backend - `eager` uses standard PyTorch attention (alternatives include `flash_attention_2` or `sdpa`) | +| `--steps` | Total number of training steps to run | +| `--policy.device` | Device to train on (`cuda` for GPU, `cpu` for CPU) | +| `--batch_size` | Number of samples per training batch | + +## License + +This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x). diff --git a/docs/source/xvla.mdx b/docs/source/xvla.mdx index aa974477f..dd7d1ef57 100644 --- a/docs/source/xvla.mdx +++ b/docs/source/xvla.mdx @@ -24,7 +24,7 @@ Built from pure Transformer encoders, X-VLA scales naturally with model size and XVLA Architecture 2

@@ -120,7 +120,7 @@ Adapted for Google Robot platforms. ### Recommended Training Configuration -When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy: +When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors. ```bash lerobot-train \ @@ -129,25 +129,26 @@ lerobot-train \ --job_name=xvla_training \ --policy.path="lerobot/xvla-base" \ --policy.repo_id="HF_USER/xvla-your-robot" \ - --steps=3000 \ + --policy.dtype=bfloat16 \ + --policy.action_mode=auto \ + --steps=20000 \ --policy.device=cuda \ - --policy.freeze_vision_encoder=True \ - --policy.freeze_language_encoder=True \ - --policy.train_policy_transformer=True \ - --policy.train_soft_prompts=True \ - --policy.action_mode=YOUR_ACTION_MODE + --policy.freeze_vision_encoder=false \ + --policy.freeze_language_encoder=false \ + --policy.train_policy_transformer=true \ + --policy.train_soft_prompts=true \ ``` ### Training Parameters Explained -| Parameter | Default | Description | -| -------------------------- | ------- | ---------------------------------------- | -| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights | -| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights | -| `train_policy_transformer` | `True` | Allow policy transformer layers to train | -| `train_soft_prompts` | `True` | Allow soft prompts to train | +| Parameter | Default | Description | +| -------------------------- | ------- | ---------------------------------------------- | +| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights | +| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights | +| `train_policy_transformer` | `true` | Allow policy transformer layers to train | +| `train_soft_prompts` | `true` | Allow soft prompts to train | -**πŸ’‘ Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute. +**πŸ’‘ Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts. ### Example: Training on Bimanual Robot @@ -157,14 +158,15 @@ lerobot-train \ --output_dir=./outputs/xvla_bimanual \ --job_name=xvla_so101_training \ --policy.path="lerobot/xvla-base" \ + --policy.dtype=bfloat16 \ --policy.repo_id="YOUR_USERNAME/xvla-biso101" \ --steps=3000 \ --policy.device=cuda \ --policy.action_mode=so101_bimanual \ - --policy.freeze_vision_encoder=True \ - --policy.freeze_language_encoder=True \ - --policy.train_policy_transformer=True \ - --policy.train_soft_prompts=True + --policy.freeze_vision_encoder=false \ + --policy.freeze_language_encoder=false \ + --policy.train_policy_transformer=true \ + --policy.train_soft_prompts=true ``` πŸ’‘ **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy: @@ -172,71 +174,7 @@ lerobot-train \ **πŸ”₯ Full-finetune all components with a custom learning-rate scheme** To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR. -This LR ratio is crucial for achieving strong and stable finetuning performance. -To enable this behavior, you must: - -1. Implement a custom optimizer and register it in your training config - -``` -from dataclasses import dataclass, asdict -from lerobot.optim.optimizers import OptimizerConfig -import torch - -@OptimizerConfig.register_subclass("xvla-adamw") -@dataclass -class XVLAAdamW(OptimizerConfig): - lr: float = 1e-4 - betas: tuple[float, float] = (0.9, 0.99) - eps: float = 1e-8 - weight_decay: float = 0.0 - grad_clip_norm: float = 10.0 - - def build(self, params: dict) -> torch.optim.Optimizer: - """ - Expect `named_parameters()` as input. - Apply lr = lr / 10 for all VLM-related parameters. - """ - assert isinstance(params, dict), \ - "Custom LR optimizer requires `named_parameters()` as inputs." - kwargs = asdict(self) - kwargs.pop("grad_clip_norm") - vlm_group, other_group = [], [] - for name, p in params.items(): - if not p.requires_grad: - continue - if "vlm" in name.lower(): - vlm_group.append(p) - else: - other_group.append(p) - - param_groups = [ - {"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1}, - {"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay}, - ] - - return torch.optim.AdamW(param_groups, **kwargs) -``` - -2. Modify X-VLA’s get_optim_params to return named parameters - -Replace: - -``` -def get_optim_params(self) -> dict: - """Return only trainable parameters for optimization.""" - return filter(lambda p: p.requires_grad, self.parameters()) -``` - -with: - -``` -def get_optim_params(self): - """Return trainable named parameters.""" - return filter(lambda kv: kv[1].requires_grad, self.named_parameters()) -``` - -This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule. - +This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default. ❕Note Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements. @@ -326,6 +264,26 @@ domain_id = 3 The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline. +The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration: + +#### Fine-tuning Datasets + +| Dataset Name | Domain ID | +| ---------------- | --------- | +| Bridge | 0 | +| RT1 | 1 | +| Calvin | 2 | +| libero | 3 | +| widowx-air | 4 | +| AIR-AGILEX-HQ | 5 | +| robotwin2_abs_ee | 6 | +| robotwin2_clean | 6 | +| robocasa-human | 7 | +| VLABench | 8 | +| AGIBOT-challenge | 9 | +| AIR-AGILEX | 10 | +| AIRBOT | 18 | + ### 3. Processor Steps X-VLA requires specific preprocessing and postprocessing steps for proper operation. diff --git a/media/gym/aloha_act.gif b/media/gym/aloha_act.gif deleted file mode 100644 index 0285a3dd1..000000000 Binary files a/media/gym/aloha_act.gif and /dev/null differ diff --git a/media/gym/pusht_diffusion.gif b/media/gym/pusht_diffusion.gif deleted file mode 100644 index 2c0129048..000000000 Binary files a/media/gym/pusht_diffusion.gif and /dev/null differ diff --git a/media/gym/simxarm_tdmpc.gif b/media/gym/simxarm_tdmpc.gif deleted file mode 100644 index fc7a19b14..000000000 Binary files a/media/gym/simxarm_tdmpc.gif and /dev/null differ diff --git a/media/hope_jr/hopejr.png b/media/hope_jr/hopejr.png deleted file mode 100644 index 4186547a2..000000000 Binary files a/media/hope_jr/hopejr.png and /dev/null differ diff --git a/media/lekiwi/kiwi.webp b/media/lekiwi/kiwi.webp deleted file mode 100644 index 2dd7d9256..000000000 Binary files a/media/lekiwi/kiwi.webp and /dev/null differ diff --git a/media/lerobot-logo-light.png b/media/lerobot-logo-light.png deleted file mode 100644 index 9a93b50da..000000000 Binary files a/media/lerobot-logo-light.png and /dev/null differ diff --git a/media/readme/VLA_architecture.jpg b/media/readme/VLA_architecture.jpg new file mode 100644 index 000000000..146decc86 Binary files /dev/null and b/media/readme/VLA_architecture.jpg differ diff --git a/media/lerobot-logo-thumbnail.png b/media/readme/lerobot-logo-thumbnail.png similarity index 100% rename from media/lerobot-logo-thumbnail.png rename to media/readme/lerobot-logo-thumbnail.png diff --git a/media/readme/robots_control_video.webp b/media/readme/robots_control_video.webp new file mode 100644 index 000000000..eed7dc4ce Binary files /dev/null and b/media/readme/robots_control_video.webp differ diff --git a/media/readme/so100_video.webp b/media/readme/so100_video.webp new file mode 100644 index 000000000..200e3fe53 Binary files /dev/null and b/media/readme/so100_video.webp differ diff --git a/media/so100/leader_follower.webp b/media/so100/leader_follower.webp deleted file mode 100644 index 83cf4b231..000000000 Binary files a/media/so100/leader_follower.webp and /dev/null differ diff --git a/media/so101/so101-leader.webp b/media/so101/so101-leader.webp deleted file mode 100644 index 22ff3a4bc..000000000 Binary files a/media/so101/so101-leader.webp and /dev/null differ diff --git a/media/so101/so101.webp b/media/so101/so101.webp deleted file mode 100644 index ce65e94bc..000000000 Binary files a/media/so101/so101.webp and /dev/null differ diff --git a/media/wandb.png b/media/wandb.png deleted file mode 100644 index 8adc3d2ae..000000000 Binary files a/media/wandb.png and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index 050b604e8..48e071d32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] -transformers-dep = ["transformers>=4.53.0,<5.0.0"] +transformers-dep = ["transformers>=4.57.1,<5.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) # Motors @@ -120,6 +120,13 @@ intelrealsense = [ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"] # Policies +wallx = [ + "transformers==4.49.0", + "peft==0.17.1", + "scipy==1.15.3", + "torchdiffeq==0.2.5", + "qwen_vl_utils==0.0.11" +] pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ @@ -133,6 +140,7 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] +sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] @@ -140,7 +148,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] # Development -dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] +dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] @@ -159,6 +167,7 @@ all = [ "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", + # "lerobot[wallx]", "lerobot[pi]", "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn @@ -173,6 +182,7 @@ all = [ "lerobot[phone]", "lerobot[libero]", "lerobot[metaworld]", + "lerobot[sarm]" ] [project.scripts] @@ -227,6 +237,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F403"] +"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original [tool.ruff.lint.isort] combine-as-imports = true @@ -263,6 +274,7 @@ default.extend-ignore-identifiers-re = [ "ein", "thw", "inpt", + "ROBOTIS", ] # TODO: Uncomment when ready to use @@ -317,9 +329,9 @@ disallow_untyped_defs = true disallow_incomplete_defs = true check_untyped_defs = true -# [[tool.mypy.overrides]] -# module = "lerobot.optim.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.optim.*" +ignore_errors = false [[tool.mypy.overrides]] module = "lerobot.model.*" @@ -369,3 +381,40 @@ ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" # ignore_errors = false + +[tool.uv] +# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0 +conflicts = [ + [ + { extra = "wallx" }, + { extra = "transformers-dep" }, + ], + [ + { extra = "wallx" }, + { extra = "pi" }, + ], + [ + { extra = "wallx" }, + { extra = "smolvla" }, + ], + [ + { extra = "wallx" }, + { extra = "groot" }, + ], + [ + { extra = "wallx" }, + { extra = "xvla" }, + ], + [ + { extra = "wallx" }, + { extra = "hilserl" }, + ], + [ + { extra = "wallx" }, + { extra = "libero" }, + ], + [ + { extra = "wallx" }, + { extra = "all" }, + ], +] diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index 1b1dac0f5..f8b6d7bb3 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2 SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] # TODO: Add all other robots -SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"] +SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"] diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index f9d70a64e..d32aa6a21 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401 bi_so100_follower, koch_follower, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index d17915c36..cee9dfdf9 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin): steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 + tolerance_s: float = 1e-4 save_checkpoint: bool = True # Checkpoint is saved every `save_freq` training iterations and after the last training step. save_freq: int = 20_000 @@ -64,9 +65,17 @@ class TrainPipelineConfig(HubMixin): scheduler: LRSchedulerConfig | None = None eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) - checkpoint_path: Path | None = field(init=False, default=None) + + # RA-BC (Reward-Aligned Behavior Cloning) parameters + use_rabc: bool = False # Enable reward-weighted training + rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file + rabc_kappa: float = 0.01 # Hard threshold for high-quality samples + rabc_epsilon: float = 1e-6 # Small constant for numerical stability + rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense" + # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) + checkpoint_path: Path | None = field(init=False, default=None) def validate(self) -> None: # HACK: We parse again the cli args here to get the pretrained paths if there was some. @@ -130,6 +139,14 @@ class TrainPipelineConfig(HubMixin): "'policy.repo_id' argument missing. Please specify it to push the model to the hub." ) + if self.use_rabc and not self.rabc_progress_path: + # Auto-detect from dataset path + repo_id = self.dataset.repo_id + if self.dataset.root: + self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet") + else: + self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet" + @classmethod def __get_path_fields__(cls) -> list[str]: """This enables the parser to load config from the policy using `--policy.path=local/dir`""" diff --git a/src/lerobot/data_processing/__init__.py b/src/lerobot/data_processing/__init__.py new file mode 100644 index 000000000..2f76d5676 --- /dev/null +++ b/src/lerobot/data_processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/data_processing/sarm_annotations/__init__.py b/src/lerobot/data_processing/sarm_annotations/__init__.py new file mode 100644 index 000000000..2f76d5676 --- /dev/null +++ b/src/lerobot/data_processing/sarm_annotations/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py new file mode 100644 index 000000000..67e37bab8 --- /dev/null +++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py @@ -0,0 +1,1202 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SARM Subtask Annotation using local GPU (Qwen3-VL). + +This script implements the annotation approach from the SARM paper using local GPU inference: +"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation" +Paper: https://arxiv.org/pdf/2509.25358 + +What it does: +1. Takes videos from a LeRobot dataset +2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur +3. Saves subtask timestamps to the dataset metadata +4. Optionally pushes the annotated dataset to HuggingFace Hub + +SARM trains reward models that predict: + - Stage: Which subtask is currently being executed (discrete classification) + - Progress: How far along the subtask we are (continuous 0-1) + +Supports three annotation modes: + 1. No annotations (no args): Auto-creates single sparse "task" stage covering full episode. + Use with SARM config annotation_mode="single_stage" for simple tasks. + + 2. Dense-only (--dense-only --dense-subtasks): Dense annotations from VLM, auto-generated + single sparse "task" stage. Use with annotation_mode="dense_only". + + 3. Dual mode (--sparse-subtasks + --dense-subtasks): Both sparse and dense annotations + from VLM. Use with annotation_mode="dual". + +Requirements: + - GPU with sufficient VRAM (16GB+ recommended for 30B model) + - `pip install transformers, torch, qwen-vl-utils` + +Run with: +```bash +python examples/dataset_annotation/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --sparse-subtasks "Do ..." \ + --dense-subtasks "Do task 1, Do task 2, Do task 3" \ + --video-key observation.images.base \ + --push-to-hub +``` +""" + +import argparse +import json +import multiprocessing as mp +import random +import re +import subprocess +import tempfile +import textwrap +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Any + +import cv2 +import numpy as np +import pandas as pd +import torch +from pydantic import BaseModel, Field +from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +# Pydantic Models for SARM Subtask Annotation +class Timestamp(BaseModel): + """Timestamp in MM:SS or SS format""" + + start: str = Field(description="Start timestamp (MM:SS or just seconds)") + end: str = Field(description="End timestamp (MM:SS or just seconds)") + + +class Subtask(BaseModel): + """Individual subtask/stage - must use EXACT names from provided list""" + + name: str = Field(description="Subtask name - MUST match one from the predefined list exactly") + timestamps: Timestamp + + +class SubtaskAnnotation(BaseModel): + """Complete annotation for a robot manipulation episode""" + + subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order") + + +def compute_temporal_proportions( + annotations: dict[int, Any], fps: int = 30, subtask_order: list[str] | None = None +) -> dict[str, float]: + """ + Compute dataset-level temporal proportions (priors) for each subtask. + + Implements SARM Paper Formula (1): αΎ±_k = (1/M) Γ— Ξ£_i (L_{i,k} / T_i) + + Args: + annotations: Dict mapping episode index to SubtaskAnnotation object. + fps: Frames per second (unused, kept for API compatibility) + subtask_order: Optional list defining the output order of subtasks. + + Returns: + Dict mapping subtask name to its temporal proportion (αΎ±_k), ordered by subtask_order if provided. + """ + subtask_proportions: dict[str, list[float]] = {} + + for annotation in annotations.values(): + total_duration = 0 + durations: dict[str, int] = {} + + for subtask in annotation.subtasks: + start_parts = subtask.timestamps.start.split(":") + end_parts = subtask.timestamps.end.split(":") + + start_seconds = ( + int(start_parts[0]) * 60 + int(start_parts[1]) + if len(start_parts) == 2 + else int(start_parts[0]) + ) + end_seconds = ( + int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0]) + ) + + duration = end_seconds - start_seconds + durations[subtask.name] = duration + total_duration += duration + + if total_duration > 0: + for name, duration in durations.items(): + if name not in subtask_proportions: + subtask_proportions[name] = [] + subtask_proportions[name].append(duration / total_duration) + + if not subtask_proportions: + return {} + + avg_proportions = {name: sum(props) / len(props) for name, props in subtask_proportions.items()} + + total = sum(avg_proportions.values()) + if total > 0: + avg_proportions = {name: prop / total for name, prop in avg_proportions.items()} + + # Reorder according to subtask_order if provided + if subtask_order: + avg_proportions = { + name: avg_proportions.get(name, 0.0) for name in subtask_order if name in avg_proportions + } + + return avg_proportions + + +def create_sarm_prompt(subtask_list: list[str]) -> str: + subtask_str = "\n".join([f" - {name}" for name in subtask_list]) + + return textwrap.dedent(f"""\ + # Role + You are a Robotics Vision System specializing in temporal action localization for robot manipulation. Your job is to segment a single demonstration video into distinct, non-overlapping atomic actions from a fixed subtask list. + + # Subtask Label Set (Closed Vocabulary) + You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones: + + [ + {subtask_str} + ] + + The video shows one successful execution of all subtasks in a logical order. + + # Ground-Truth Semantics (Very Important) + Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks. + + - A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask. + - A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration. + + If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask. + + # Hard Constraints & Logic + 1. **Continuous Coverage (No Gaps):** + - The entire video duration from "00:00" to the final timestamp must be covered by subtasks. + - There can be no gaps between subtasks. + - If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it. + + 2. **Boundary Consistency:** + - The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask. + - Boundaries must coincide with a real visual state transition, not just a convenient time split. + + 3. **Chronological Order, One Occurrence Each:** + - This is a single successful demonstration. + - Each subtask from the vocabulary appears **exactly once**, in the correct logical order. + - **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video. + + 4. **Reject Uniform Segmentation (Important):** + - Do NOT simply divide the video into equal or nearly equal time chunks. + - If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries. + - Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare). + + 5. **Timestamps:** + - Timestamps must be in `"MM:SS"` format. + - The first subtask always starts at `"00:00"`. + - The last subtask ends at the final visible frame of the video. + + # Step 1 β€” Textual Timeline (must do this first) + First, write a extensive and detailed textual timeline describing what happens in the video with approximate timestamps. + For each subtask, include: + - its name + - an approximate start and end time, + - an description of the visual event at the boundary (e.g. "shirt fully folded to the left", "robot rotates folded shirt 90 degrees"). + + Format this as a bullet list. + + # Step 2 β€” JSON Output (final answer) + After the textual timeline, output **only** valid JSON with this structure. + The JSON **must** be consistent with the textual timeline above: + + {{ + "subtasks": [ + {{ + "name": "EXACT_NAME_FROM_LIST", + "timestamps": {{ + "start": "MM:SS", + "end": "MM:SS" + }} + }}, + {{ + "name": "EXACT_NAME_FROM_LIST", + "timestamps": {{ + "start": "MM:SS", + "end": "MM:SS" + }} + }} + ] + }} + + Do not add any extra keys to the JSON. + """) + + +class VideoAnnotator: + """Annotates robot manipulation videos using local Qwen3-VL model on GPU""" + + def __init__( + self, + subtask_list: list[str], + model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct", + device: str = "cuda", + torch_dtype: torch.dtype = torch.bfloat16, + model: Qwen3VLMoeForConditionalGeneration | None = None, # noqa: F821 + processor: AutoProcessor | None = None, # noqa: F821 + ): + """ + Initialize the video annotator with local model. + + Args: + subtask_list: List of allowed subtask names (for consistency) + model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct) + device: Device to use (cuda, cpu) + torch_dtype: Data type for model (bfloat16, float16, float32) + model: Pre-loaded model instance (optional, to share between annotators) + processor: Pre-loaded processor instance (optional, to share between annotators) + """ + self.subtask_list = subtask_list + self.prompt = create_sarm_prompt(subtask_list) + self.device = device + + # Use provided model/processor or load new ones + if model is not None and processor is not None: + self.model = model + self.processor = processor + print(f"Using shared model on {device}") + else: + from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + print(f"Loading model: {model_name}...") + + self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True + ) + + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + print(f"Model loaded successfully on {device}") + + def extract_episode_segment( + self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1 + ) -> Path: + """ + Extract a specific episode segment from concatenated video. + Uses minimal compression to preserve quality for local inference. + + Args: + file_path: Path to the concatenated video file + start_timestamp: Starting timestamp in seconds (within this video file) + end_timestamp: Ending timestamp in seconds (within this video file) + target_fps: Target FPS (default: 1 for faster processing) + + Returns: + Path to extracted video file + """ + # Create temporary file for extracted video + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + try: + # Check if ffmpeg is available + subprocess.run( # nosec B607 + ["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True + ) + except (subprocess.CalledProcessError, FileNotFoundError) as err: + raise RuntimeError("ffmpeg not found, cannot extract episode segment") from err + + try: + # Calculate duration + duration = end_timestamp - start_timestamp + + print(f"Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)") + + # Use ffmpeg to extract segment with minimal quality loss + cmd = [ + "ffmpeg", + "-i", + str(file_path), + "-ss", + str(start_timestamp), + "-t", + str(duration), + "-r", + str(target_fps), + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-crf", + "23", + "-an", + "-y", + str(tmp_path), + ] + + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True) + + # Verify the output file was created and is not empty + if not tmp_path.exists() or tmp_path.stat().st_size == 0: + print("Video extraction failed (0 bytes) - skipping episode") + if tmp_path.exists(): + tmp_path.unlink() + raise RuntimeError("FFmpeg produced empty video file") + + # Show extraction results + file_size_mb = tmp_path.stat().st_size / (1024 * 1024) + + # Fail if file is too small (< 100KB likely means extraction failed) + if file_size_mb < 0.1: + print(f"Extracted video too small ({file_size_mb:.2f}MB) - skipping episode") + tmp_path.unlink() + raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)") + + print(f"Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)") + + return tmp_path + + except subprocess.CalledProcessError as e: + raise RuntimeError(f"ffmpeg failed ({e})") from e + + def annotate( + self, + file_path: str | Path, + fps: int, + start_timestamp: float = 0.0, + end_timestamp: float | None = None, + max_retries: int = 3, + ) -> SubtaskAnnotation: + """Annotate a video segment using local GPU.""" + from qwen_vl_utils import process_vision_info + + file_path = Path(file_path) + + if end_timestamp is None: + cap = cv2.VideoCapture(str(file_path)) + end_timestamp = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / (cap.get(cv2.CAP_PROP_FPS) or 1) + cap.release() + + duration = end_timestamp - start_timestamp + duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}" + + extracted_path = self.extract_episode_segment(file_path, start_timestamp, end_timestamp, 1) + is_extracted = extracted_path != file_path + + try: + messages = [ + {"role": "system", "content": [{"type": "text", "text": self.prompt}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": str(extracted_path), "fps": 1.0}, + { + "type": "text", + "text": f"Video is {duration_str} (~{duration:.1f}s). Follow instructions.", + }, + ], + }, + ] + + for attempt in range(max_retries): + try: + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_inputs, video_inputs = process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate( + **inputs, max_new_tokens=1024, do_sample=True, temperature=0.7 + ) + + response = self.processor.batch_decode( + [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)], + skip_special_tokens=True, + )[0].strip() + + # Extract JSON + if "```json" in response: + response = response.split("```json")[1].split("```")[0] + elif "```" in response: + response = response.split("```")[1].split("```")[0] + + try: + return SubtaskAnnotation.model_validate(json.loads(response)) + except json.JSONDecodeError: + match = re.search(r"\{.*\}", response, re.DOTALL) + if match: + return SubtaskAnnotation.model_validate(json.loads(match.group())) + raise ValueError("No JSON found") from None + except Exception as e: + if attempt == max_retries - 1: + raise RuntimeError(f"Failed after {max_retries} attempts") from e + time.sleep(1) + finally: + if is_extracted and extracted_path.exists(): + extracted_path.unlink() + + +def display_annotation(annotation: SubtaskAnnotation, episode_idx: int, fps: int, prefix: str = ""): + """Display annotation summary.""" + subtask_summary = ", ".join( + f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks + ) + print(f"Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}") + + +def timestamp_to_seconds(timestamp: str) -> float: + """Convert MM:SS or SS timestamp to seconds""" + parts = timestamp.split(":") + if len(parts) == 2: + return int(parts[0]) * 60 + int(parts[1]) + else: + return int(parts[0]) + + +def extract_frame(video_path: Path, timestamp: float) -> np.ndarray | None: + """Extract a single frame from video at given timestamp.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + return None + cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000) + ret, frame = cap.read() + cap.release() + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None + + +def draw_timeline(ax, subtasks, total_duration, colors): + """Draw a timeline with color-coded subtask segments.""" + import matplotlib.patches as mpatches + + bar_height, bar_y = 0.6, 0.5 + + for i, subtask in enumerate(subtasks): + start = timestamp_to_seconds(subtask.timestamps.start) + end = timestamp_to_seconds(subtask.timestamps.end) + color = colors[i % len(colors)] + + rect = mpatches.FancyBboxPatch( + (start, bar_y - bar_height / 2), + end - start, + bar_height, + boxstyle="round,pad=0.02,rounding_size=0.1", + facecolor=color, + edgecolor="white", + linewidth=1.5, + alpha=0.85, + ) + ax.add_patch(rect) + + # Add label if segment is wide enough + duration = end - start + if duration > total_duration * 0.06: + ax.text( + (start + end) / 2, + bar_y, + subtask.name, + ha="center", + va="center", + fontsize=8, + fontweight="bold", + color="white", + rotation=0 if duration > total_duration * 0.12 else 45, + ) + + if i > 0: + ax.axvline(x=start, ymin=0.1, ymax=0.9, color="white", linestyle="--", linewidth=1.5, alpha=0.7) + + ax.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2, alpha=0.9) + if subtasks: + ax.axvline( + x=timestamp_to_seconds(subtasks[-1].timestamps.end), + ymin=0.1, + ymax=0.9, + color="white", + linestyle="--", + linewidth=1.5, + alpha=0.7, + ) + + ax.set_xlim(-total_duration * 0.02, total_duration * 1.02) + ax.set_ylim(-0.1, 1.1) + ax.set_xlabel("Time (seconds)", fontsize=10, color="white", labelpad=5) + for spine in ["top", "right", "left"]: + ax.spines[spine].set_visible(False) + ax.spines["bottom"].set_color("#444444") + ax.tick_params(axis="x", colors="#888888", labelsize=8) + ax.tick_params(axis="y", left=False, labelleft=False) + + +def visualize_episode( + ep_idx: int, + annotation: SubtaskAnnotation, + video_path: Path, + video_start: float, + video_end: float, + output_path: Path, + video_key: str, + ann_type: str, +): + """Create visualization for a single episode with frames and timeline.""" + import matplotlib.pyplot as plt + + if annotation is None: + print(f"No {ann_type} annotation for episode {ep_idx}") + return + + subtasks = annotation.subtasks + if not subtasks: + print(f"No subtasks for episode {ep_idx}") + return + + colors = plt.cm.tab10(np.linspace(0, 1, max(len(subtasks), 10))) + total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end) + + # Extract middle frame from each subtask + sample_frames, frame_times = [], [] + for subtask in subtasks: + start = timestamp_to_seconds(subtask.timestamps.start) + end = timestamp_to_seconds(subtask.timestamps.end) + mid = (start + end) / 2 + frame_times.append(mid) + sample_frames.append(extract_frame(video_path, video_start + mid)) + + # Create figure + fig_width = max(16, len(subtasks) * 2.5) + fig = plt.figure(figsize=(fig_width, 10)) + fig.patch.set_facecolor("#1a1a2e") + + gs = fig.add_gridspec( + 2, + max(len(subtasks), 1), + height_ratios=[2, 1], + hspace=0.3, + wspace=0.1, + left=0.05, + right=0.95, + top=0.88, + bottom=0.1, + ) + + fig.suptitle( + f"Episode {ep_idx} - {ann_type.capitalize()} Annotations", + fontsize=18, + fontweight="bold", + color="white", + y=0.96, + ) + fig.text( + 0.5, + 0.91, + f"Camera: {video_key} | Duration: {video_end - video_start:.1f}s | {len(subtasks)} subtasks", + ha="center", + fontsize=11, + color="#888888", + ) + + # Plot frames + for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks, strict=True)): + ax = fig.add_subplot(gs[0, i]) + ax.set_facecolor("#16213e") + if frame is not None: + ax.imshow(frame) + else: + ax.text( + 0.5, 0.5, "N/A", ha="center", va="center", fontsize=12, color="white", transform=ax.transAxes + ) + ax.set_title(subtask.name, fontsize=10, fontweight="bold", color=colors[i % len(colors)], pad=8) + ax.axis("off") + ax.text( + 0.5, + -0.08, + f"t={frame_times[i]:.1f}s", + ha="center", + fontsize=9, + color="#888888", + transform=ax.transAxes, + ) + + # Plot timeline + ax_timeline = fig.add_subplot(gs[1, :]) + ax_timeline.set_facecolor("#16213e") + draw_timeline(ax_timeline, subtasks, total_duration, colors) + + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight") + plt.close() + print(f"Saved: {output_path}") + + +def visualize_annotations( + dataset: LeRobotDataset, + sparse_annotations: dict[int, SubtaskAnnotation], + dense_annotations: dict[int, SubtaskAnnotation] | None, + video_key: str, + output_dir: Path, + num_episodes: int = 5, + annotation_type: str = "sparse", + episode_indices: list[int] | None = None, +): + """ + Visualize subtask annotations for a set of episodes. + + Args: + dataset: LeRobotDataset instance + sparse_annotations: Dict mapping episode index to sparse annotations + dense_annotations: Dict mapping episode index to dense annotations (or None) + video_key: Camera/video key to use + output_dir: Directory to save visualization images + num_episodes: Number of episodes to visualize (ignored if episode_indices provided) + annotation_type: "sparse", "dense", or "both" + episode_indices: Specific episode indices to visualize (optional) + """ + # Determine available episodes based on annotation type + if annotation_type == "sparse": + available = set(sparse_annotations.keys()) + elif annotation_type == "dense": + available = set(dense_annotations.keys()) if dense_annotations else set() + else: # both + sparse_set = set(sparse_annotations.keys()) + dense_set = set(dense_annotations.keys()) if dense_annotations else set() + available = sparse_set | dense_set + + if not available: + print("Error: No annotations found to visualize.") + return + + # Select episodes to visualize + if episode_indices: + episodes = sorted([e for e in episode_indices if e in available]) + missing = set(episode_indices) - available + if missing: + print(f"Episodes not found in annotations: {sorted(missing)}") + else: + episodes = sorted(random.sample(list(available), min(num_episodes, len(available)))) + print(f"Visualizing {len(episodes)} episodes: {episodes}") + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate visualizations + for i, ep_idx in enumerate(episodes, 1): + print(f"Processing episode {ep_idx} ({i}/{len(episodes)})") + video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key) + if not video_path.exists(): + print(f"Video not found: {video_path}") + continue + + video_start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx]) + video_end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx]) + + if annotation_type == "both": + # Visualize both sparse and dense + for ann_type, annotations in [("sparse", sparse_annotations), ("dense", dense_annotations)]: + if annotations and ep_idx in annotations: + output_path = output_dir / f"episode_{ep_idx:04d}_{ann_type}.png" + visualize_episode( + ep_idx, + annotations.get(ep_idx), + video_path, + video_start, + video_end, + output_path, + video_key, + ann_type, + ) + else: + annotations = sparse_annotations if annotation_type == "sparse" else dense_annotations + if annotations and ep_idx in annotations: + output_path = output_dir / f"episode_{ep_idx:04d}_{annotation_type}.png" + visualize_episode( + ep_idx, + annotations.get(ep_idx), + video_path, + video_start, + video_end, + output_path, + video_key, + annotation_type, + ) + + print(f"Visualizations saved to: {output_dir.absolute()}") + + +def save_annotations_to_dataset( + dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse" +): + """Save annotations to LeRobot dataset parquet format.""" + from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes + + episodes_dataset = load_episodes(dataset_path) + if not episodes_dataset or len(episodes_dataset) == 0: + return + + episodes_df = episodes_dataset.to_pandas() + cols = [ + f"{prefix}_{c}" + for c in [ + "subtask_names", + "subtask_start_times", + "subtask_end_times", + "subtask_start_frames", + "subtask_end_frames", + ] + ] + for col in cols: + episodes_df[col] = None + + for ep_idx, ann in annotations.items(): + if ep_idx >= len(episodes_df): + continue + names, starts, ends, start_frames, end_frames = [], [], [], [], [] + for s in ann.subtasks: + names.append(s.name) + st, et = timestamp_to_seconds(s.timestamps.start), timestamp_to_seconds(s.timestamps.end) + starts.append(st) + ends.append(et) + start_frames.append(int(st * fps)) + end_frames.append(int(et * fps)) + episodes_df.at[ep_idx, cols[0]] = names + episodes_df.at[ep_idx, cols[1]] = starts + episodes_df.at[ep_idx, cols[2]] = ends + episodes_df.at[ep_idx, cols[3]] = start_frames + episodes_df.at[ep_idx, cols[4]] = end_frames + + # Group by file and write + for ep_idx in episodes_df.index: + key = ( + episodes_df.loc[ep_idx, "meta/episodes/chunk_index"], + episodes_df.loc[ep_idx, "meta/episodes/file_index"], + ) + path = dataset_path / DEFAULT_EPISODES_PATH.format(chunk_index=key[0], file_index=key[1]) + if path.exists(): + file_df = pd.read_parquet(path) + for col in cols + ( + [ + "subtask_names", + "subtask_start_times", + "subtask_end_times", + "subtask_start_frames", + "subtask_end_frames", + ] + if prefix == "sparse" + else [] + ): + if col not in file_df.columns: + file_df[col] = None + if ep_idx in annotations: + for col in cols: + file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col] + if prefix == "sparse": # Legacy columns + for i, legacy in enumerate( + [ + "subtask_names", + "subtask_start_times", + "subtask_end_times", + "subtask_start_frames", + "subtask_end_frames", + ] + ): + file_df.at[ep_idx, legacy] = episodes_df.loc[ep_idx, cols[i]] + file_df.to_parquet(path, engine="pyarrow", compression="snappy") + + +def generate_auto_sparse_annotations( + dataset: LeRobotDataset, episode_indices: list[int], video_key: str +) -> dict[int, SubtaskAnnotation]: + """Auto-generate single 'task' stage annotations for all episodes.""" + annotations = {} + for ep_idx in episode_indices: + start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx]) + end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx]) + duration = end - start + end_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}" + annotations[ep_idx] = SubtaskAnnotation( + subtasks=[Subtask(name="task", timestamps=Timestamp(start="00:00", end=end_str))] + ) + return annotations + + +def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]: + """Load annotations from LeRobot dataset parquet files.""" + from lerobot.datasets.utils import load_episodes + + episodes_dataset = load_episodes(dataset_path) + if not episodes_dataset or len(episodes_dataset) == 0: + return {} + + col_names = f"{prefix}_subtask_names" + col_start = f"{prefix}_subtask_start_times" + col_end = f"{prefix}_subtask_end_times" + + # Fall back to legacy columns for sparse + if col_names not in episodes_dataset.column_names: + if prefix == "sparse" and "subtask_names" in episodes_dataset.column_names: + col_names, col_start, col_end = "subtask_names", "subtask_start_times", "subtask_end_times" + else: + return {} + + df = episodes_dataset.to_pandas() + annotations = {} + for ep_idx in df.index: + names = df.loc[ep_idx, col_names] + if names is None or (isinstance(names, float) and pd.isna(names)): + continue + starts, ends = df.loc[ep_idx, col_start], df.loc[ep_idx, col_end] + annotations[int(ep_idx)] = SubtaskAnnotation( + subtasks=[ + Subtask( + name=n, + timestamps=Timestamp( + start=f"{int(s) // 60:02d}:{int(s) % 60:02d}", + end=f"{int(e) // 60:02d}:{int(e) % 60:02d}", + ), + ) + for n, s, e in zip(names, starts, ends, strict=True) + ] + ) + return annotations + + +def process_single_episode( + ep_idx: int, + dataset_root: Path, + dataset_meta, + video_key: str, + fps: int, + annotator: VideoAnnotator, +) -> tuple[int, SubtaskAnnotation | None, str | None]: + """Process a single episode annotation.""" + try: + video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key) + if not video_path.exists(): + return ep_idx, None, f"Video not found: {video_path}" + + start = float(dataset_meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx]) + end = float(dataset_meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx]) + return ep_idx, annotator.annotate(video_path, fps, start, end), None + except Exception as e: + return ep_idx, None, str(e) + + +def worker_process_episodes( + worker_id: int, + gpu_id: int, + episode_indices: list[int], + repo_id: str, + video_key: str, + sparse_subtask_list: list[str], + dense_subtask_list: list[str] | None, + model_name: str, + torch_dtype: torch.dtype, +) -> tuple[dict, dict | None]: + """Worker for parallel processing across GPUs.""" + device = f"cuda:{gpu_id}" + dataset = LeRobotDataset(repo_id, download_videos=False) + + sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype) + dense_annotator = ( + VideoAnnotator( + dense_subtask_list, + model_name, + device, + torch_dtype, + sparse_annotator.model, + sparse_annotator.processor, + ) + if dense_subtask_list + else None + ) + + sparse_annotations, dense_annotations = {}, {} if dense_subtask_list else None + + for ep_idx in episode_indices: + _, sparse_ann, err = process_single_episode( + ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator + ) + if sparse_ann: + sparse_annotations[ep_idx] = sparse_ann + + if dense_annotator: + _, dense_ann, _ = process_single_episode( + ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator + ) + if dense_ann: + dense_annotations[ep_idx] = dense_ann + + return sparse_annotations, dense_annotations + + +def main(): + parser = argparse.ArgumentParser(description="SARM-style subtask annotation using local GPU (Qwen3-VL)") + parser.add_argument("--repo-id", type=str, required=True, help="HuggingFace dataset repository ID") + parser.add_argument( + "--sparse-subtasks", type=str, default=None, help="Comma-separated sparse subtask names" + ) + parser.add_argument( + "--dense-subtasks", type=str, default=None, help="Comma-separated dense subtask names" + ) + parser.add_argument( + "--dense-only", action="store_true", help="Dense-only mode with auto-generated sparse 'task' stage" + ) + parser.add_argument("--episodes", type=int, nargs="+", default=None, help="Episode indices to annotate") + parser.add_argument("--model", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="VLM model") + parser.add_argument("--skip-existing", action="store_true", help="Skip already annotated episodes") + parser.add_argument("--video-key", type=str, default=None, help="Video key (default: first available)") + parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub") + parser.add_argument("--output-repo-id", type=str, default=None, help="Output repo ID for push") + parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"]) + parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU") + parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use") + # Visualization options + parser.add_argument( + "--visualize-only", + action="store_true", + help="Only visualize existing annotations (no generation)", + ) + parser.add_argument( + "--num-visualizations", + type=int, + default=5, + help="Number of episodes to visualize (default: 5)", + ) + parser.add_argument( + "--visualize-type", + type=str, + default="sparse", + choices=["sparse", "dense", "both"], + help="Type of annotations to visualize (default: sparse)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./subtask_viz", + help="Output directory for visualizations (default: ./subtask_viz)", + ) + + args = parser.parse_args() + + # Load dataset first (needed for both annotation and visualization) + print(f"Loading dataset: {args.repo_id}") + dataset = LeRobotDataset(args.repo_id, download_videos=True) + fps = dataset.fps + + if not dataset.meta.video_keys: + raise ValueError("No video keys found") + + video_key = ( + args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0] + ) + print(f"Using camera: {video_key}, FPS: {fps}") + + # Handle visualization-only mode + if args.visualize_only: + print("Visualization-only mode") + sparse_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse") + dense_annotations = load_annotations_from_dataset(dataset.root, prefix="dense") + + if not sparse_annotations and not dense_annotations: + return print("Error: No annotations found. Run annotation first.") + + print(f"Found {len(sparse_annotations)} sparse, {len(dense_annotations)} dense annotations") + + visualize_annotations( + dataset=dataset, + sparse_annotations=sparse_annotations, + dense_annotations=dense_annotations if dense_annotations else None, + video_key=video_key, + output_dir=Path(args.output_dir), + num_episodes=args.num_visualizations, + annotation_type=args.visualize_type, + episode_indices=args.episodes, + ) + return + + # Validate arguments for annotation mode + if args.dense_only and not args.dense_subtasks: + return print("Error: --dense-only requires --dense-subtasks") + if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only: + return print("Error: --dense-subtasks requires --sparse-subtasks or --dense-only") + + sparse_subtask_list = ( + [s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None + ) + dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None + auto_sparse = sparse_subtask_list is None + dense_mode = dense_subtask_list is not None + torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype] + + # Determine episodes + episode_indices = args.episodes or list(range(dataset.meta.total_episodes)) + + existing_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse") + if args.skip_existing: + episode_indices = [ep for ep in episode_indices if ep not in existing_annotations] + + if not episode_indices: + return print("All episodes already annotated!") + print(f"Annotating {len(episode_indices)} episodes") + + # GPU setup + gpu_ids = args.gpu_ids or list( + range(min(args.num_workers, torch.cuda.device_count() if torch.cuda.is_available() else 1)) + ) + args.num_workers = len(gpu_ids) + + sparse_annotations = existing_annotations.copy() + dense_annotations = {} if dense_mode else None + + # Auto-sparse mode + if auto_sparse: + sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key)) + save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse") + print(f"Auto-generated {len(episode_indices)} sparse 'task' annotations") + + # VLM annotation (for sparse if not auto, and for dense) + need_vlm = (not auto_sparse) or dense_mode + + if need_vlm: + if args.num_workers > 1 and not auto_sparse: + # Parallel processing + print(f"Parallel processing with {args.num_workers} workers") + episodes_per_worker = [[] for _ in range(args.num_workers)] + for i, ep_idx in enumerate(episode_indices): + episodes_per_worker[i % args.num_workers].append(ep_idx) + + with ProcessPoolExecutor( + max_workers=args.num_workers, mp_context=mp.get_context("spawn") + ) as executor: + futures = [ + executor.submit( + worker_process_episodes, + w, + gpu_ids[w], + episodes_per_worker[w], + args.repo_id, + video_key, + sparse_subtask_list, + dense_subtask_list, + args.model, + torch_dtype, + ) + for w in range(args.num_workers) + if episodes_per_worker[w] + ] + + for future in as_completed(futures): + try: + worker_sparse, worker_dense = future.result() + sparse_annotations.update(worker_sparse) + if dense_mode and worker_dense: + dense_annotations.update(worker_dense) + save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse") + if dense_mode: + save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense") + except Exception as e: + raise RuntimeError(f"Worker failed: {e}") from e + else: + # Sequential processing + sparse_annotator = ( + VideoAnnotator(sparse_subtask_list, args.model, args.device, torch_dtype) + if not auto_sparse and sparse_subtask_list + else None + ) + dense_annotator = ( + VideoAnnotator( + dense_subtask_list, + args.model, + args.device, + torch_dtype, + sparse_annotator.model if sparse_annotator else None, + sparse_annotator.processor if sparse_annotator else None, + ) + if dense_mode + else None + ) + + for i, ep_idx in enumerate(episode_indices): + print(f"Episode {ep_idx} ({i + 1}/{len(episode_indices)})") + + if sparse_annotator: + _, sparse_ann, err = process_single_episode( + ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator + ) + if sparse_ann: + sparse_annotations[ep_idx] = sparse_ann + save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse") + elif err: + print(f"Sparse failed: {err}") + + if dense_annotator: + _, dense_ann, err = process_single_episode( + ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator + ) + if dense_ann: + dense_annotations[ep_idx] = dense_ann + save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense") + elif err: + print(f"Dense failed: {err}") + + # Save temporal proportions + def save_proportions(annotations, prefix, subtask_list=None, is_auto=False): + props: dict[str, float] = ( + {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps, subtask_list) + ) + path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json" + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(props, f, indent=2) + print(f"Saved {prefix} temporal proportions") + + save_proportions(sparse_annotations, "sparse", sparse_subtask_list, auto_sparse) + if dense_mode and dense_annotations: + save_proportions(dense_annotations, "dense", dense_subtask_list) + + print(f"\nComplete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations") + + # Visualize annotations after generation + if args.num_visualizations > 0: + print(f"\nGenerating {args.num_visualizations} visualizations...") + visualize_type = "both" if dense_mode else "sparse" + visualize_annotations( + dataset=dataset, + sparse_annotations=sparse_annotations, + dense_annotations=dense_annotations, + video_key=video_key, + output_dir=Path(args.output_dir), + num_episodes=args.num_visualizations, + annotation_type=visualize_type, + ) + + if args.push_to_hub: + try: + dataset.push_to_hub(push_videos=True) + print(f"Pushed to {args.output_repo_id or args.repo_id}") + except Exception as e: + print(f"Push failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index f3ceb2b0c..31e939809 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, + tolerance_s=cfg.tolerance_s, ) else: dataset = StreamingLeRobotDataset( @@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=cfg.dataset.revision, max_num_shards=cfg.num_workers, + tolerance_s=cfg.tolerance_s, ) else: raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") diff --git a/src/lerobot/optim/factory.py b/src/lerobot/optim/factory.py index bab95d0ce..699289993 100644 --- a/src/lerobot/optim/factory.py +++ b/src/lerobot/optim/factory.py @@ -35,6 +35,8 @@ def make_optimizer_and_scheduler( tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`. """ params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() + if cfg.optimizer is None: + raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig") optimizer = cfg.optimizer.build(params) lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None return optimizer, lr_scheduler diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index 5120f828c..2b75353d9 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +from collections.abc import Iterable from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any @@ -29,6 +30,17 @@ from lerobot.utils.constants import ( ) from lerobot.utils.io_utils import deserialize_json_into_object +# Type alias for parameters accepted by optimizer build() methods. +# This matches PyTorch's optimizer signature while also supporting: +# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA) +# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC) +OptimizerParams = ( + Iterable[torch.nn.Parameter] # From model.parameters() + | Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides + | dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR + | dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups +) + @dataclass class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): @@ -45,13 +57,24 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): return "adam" @abc.abstractmethod - def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: """ Build the optimizer. It can be a single optimizer or a dictionary of optimizers. + NOTE: Multiple optimizers are useful when you have different models to optimize. For example, you can have one optimizer for the policy and another one for the value function in reinforcement learning settings. + Args: + params: Parameters to optimize. Accepts multiple formats depending on the optimizer: + - Iterable[Parameter]: From model.parameters() - standard PyTorch usage + - Iterable[dict]: List of param groups with 'params' key and optional + 'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies) + - dict[str, Parameter]: From dict(model.named_parameters()) for optimizers + that apply differential learning rates by parameter name (e.g., XVLA) + - dict[str, Iterable]: For multi-optimizer configs where each key maps to + a separate optimizer's parameters (e.g., SAC with actor/critic/temperature) + Returns: The optimizer or a dictionary of optimizers. """ @@ -67,7 +90,7 @@ class AdamConfig(OptimizerConfig): weight_decay: float = 0.0 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.Adam(params, **kwargs) @@ -82,7 +105,7 @@ class AdamWConfig(OptimizerConfig): weight_decay: float = 1e-2 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.AdamW(params, **kwargs) @@ -98,7 +121,7 @@ class SGDConfig(OptimizerConfig): weight_decay: float = 0.0 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.SGD(params, **kwargs) @@ -139,21 +162,19 @@ class XVLAAdamWConfig(OptimizerConfig): soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR) soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01) - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: """ Build AdamW optimizer with differential learning rates. - Expects `named_parameters()` as input (dict of name -> param). - Applies: - - lr * 0.1 for all VLM-related parameters - - lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup) - - full lr for all other parameters - Args: - params: Dictionary of parameter names to parameters (from named_parameters()) + params: Must be a dict[str, Parameter] from dict(model.named_parameters()) + or equivalent. Returns: AdamW optimizer with parameter groups for VLM, soft-prompts, and other components + + Raises: + AssertionError: If params is not a dict (e.g., from model.parameters()) """ assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs." @@ -174,7 +195,7 @@ class XVLAAdamWConfig(OptimizerConfig): # Start at warmup scale, scheduler will warm up to soft_prompt_lr soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale - param_groups = [ + param_groups: list[dict[str, Any]] = [ { "params": vlm_group, "lr": self.lr * 0.1, @@ -224,19 +245,25 @@ class MultiAdamConfig(OptimizerConfig): grad_clip_norm: float = 10.0 optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict) - def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]: + def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]: """Build multiple Adam optimizers. Args: - params_dict: Dictionary mapping parameter group names to lists of parameters - The keys should match the keys in optimizer_groups + params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names + to iterables of parameters. The keys should match the keys in optimizer_groups. + Typically from policies that need separate optimizers (e.g., SAC with + actor/critic/temperature). Returns: Dictionary mapping parameter group names to their optimizers + + Raises: + AssertionError: If params is not a dict """ + assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs." optimizers = {} - for name, params in params_dict.items(): + for name, group_params in params.items(): # Get group-specific hyperparameters or use defaults group_config = self.optimizer_groups.get(name, {}) @@ -248,7 +275,7 @@ class MultiAdamConfig(OptimizerConfig): "weight_decay": group_config.get("weight_decay", self.weight_decay), } - optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs) + optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs) return optimizers diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index b5d54b396..4af7f0802 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object @dataclass class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): - num_warmup_steps: int + num_warmup_steps: int | None @property def type(self) -> str: diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 788542d49..99275e787 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig +from .wall_x.configuration_wall_x import WallXConfig as WallXConfig from .xvla.configuration_xvla import XVLAConfig as XVLAConfig __all__ = [ @@ -29,8 +30,10 @@ __all__ = [ "PI0Config", "PI05Config", "SmolVLAConfig", + "SARMConfig", "TDMPCConfig", "VQBeTConfig", "GrootConfig", "XVLAConfig", + "WallXConfig", ] diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index b7cbcd061..a5c48eb3d 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy): def __init__( self, config: ACTConfig, + **kwargs, ): """ Args: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 3ab6719cb..1fdc76f10 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy): def __init__( self, config: DiffusionConfig, + **kwargs, ): """ Args: diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 3d17fa7dc..3e24656fc 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -37,10 +37,12 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import validate_visual_features_consistency from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor.converters import ( @@ -61,7 +63,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -105,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "sarm": + from lerobot.policies.sarm.modeling_sarm import SARMRewardModel + + return SARMRewardModel elif name == "groot": from lerobot.policies.groot.modeling_groot import GrootPolicy @@ -113,6 +119,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.xvla.modeling_xvla import XVLAPolicy return XVLAPolicy + elif name == "wall_x": + from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy + + return WallXPolicy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -130,7 +140,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", - "reward_classifier". + "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -161,6 +171,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return GrootConfig(**kwargs) elif policy_type == "xvla": return XVLAConfig(**kwargs) + elif policy_type == "wall_x": + return WallXConfig(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -337,6 +349,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, SARMConfig): + from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors + + processors = make_sarm_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), + ) elif isinstance(policy_cfg, GrootConfig): from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors @@ -344,6 +364,7 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, XVLAConfig): from lerobot.policies.xvla.processor_xvla import ( make_xvla_pre_post_processors, @@ -354,6 +375,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, WallXConfig): + from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors + + processors = make_wall_x_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + else: try: processors = _make_processors_from_policy_config( @@ -435,6 +464,13 @@ def make_policy( cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg + # Pass dataset_stats to the policy if available (needed for some policies like SARM) + if ds_meta is not None and hasattr(ds_meta, "stats"): + kwargs["dataset_stats"] = ds_meta.stats + + if ds_meta is not None: + kwargs["dataset_meta"] = ds_meta + if cfg.pretrained_path: # Load a pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 605f7a097..bdaef37b9 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy): name = "groot" config_class = GrootConfig - def __init__(self, config: GrootConfig): + def __init__(self, config: GrootConfig, **kwargs): """Initialize Groot policy wrapper.""" super().__init__(config) config.validate_features() diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index 9e267fb48..33753e0b2 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.utils.constants import OBS_IMAGES +DEFAULT_IMAGE_SIZE = 224 + @PreTrainedConfig.register_subclass("pi0") @dataclass @@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig): # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None - image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` + image_resolution: tuple[int, int] = ( + DEFAULT_IMAGE_SIZE, + DEFAULT_IMAGE_SIZE, + ) # see openpi `preprocessing_pytorch.py` # Add empty images. Used to add empty cameras when no image features are present. empty_cameras: int = 0 diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 9b6f38ad4..0d9c77e00 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -41,7 +41,7 @@ else: PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( @@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel( action_expert_config, use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", + image_size: int = DEFAULT_IMAGE_SIZE, ): if use_adarms is None: use_adarms = [False, False] @@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.image_size = image_size vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" @@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) + if config.image_resolution[0] != config.image_resolution[1]: + raise ValueError( + f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" + ) + self.paligemma_with_expert = PaliGemmaWithExpertModel( paligemma_config, action_expert_config, use_adarms=[False, False], precision=config.dtype, + image_size=config.image_resolution[0], ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) @@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` ) dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( state=state, prefix_pad_masks=prefix_pad_masks, @@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step( @@ -906,6 +907,7 @@ class PI0Policy(PreTrainedPolicy): def __init__( self, config: PI0Config, + **kwargs, ): """ Args: @@ -1234,9 +1236,15 @@ class PI0Policy(PreTrainedPolicy): return actions - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """Run the batch through the model and compute the loss for training.""" + def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training. + Args: + batch: Training batch containing observations and actions. + reduction: How to reduce the loss. Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ # Prepare inputs images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] @@ -1250,11 +1258,17 @@ class PI0Policy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] - loss = losses.mean() - loss_dict = { - "loss": loss.item(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } - return loss, loss_dict + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 2edd625af..7bdce70dd 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig +DEFAULT_IMAGE_SIZE = 224 + @PreTrainedConfig.register_subclass("pi05") @dataclass @@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig): # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None - image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` + image_resolution: tuple[int, int] = ( + DEFAULT_IMAGE_SIZE, + DEFAULT_IMAGE_SIZE, + ) # see openpi `preprocessing_pytorch.py` # Add empty images. Used to add empty cameras when no image features are present. empty_cameras: int = 0 diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 6500ada20..2cd142042 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -41,7 +41,7 @@ else: PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( @@ -336,6 +336,7 @@ class PaliGemmaWithExpertModel( action_expert_config, use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", + image_size: int = DEFAULT_IMAGE_SIZE, ): if use_adarms is None: use_adarms = [False, False] @@ -355,6 +356,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.image_size = image_size vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" @@ -518,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) + if config.image_resolution[0] != config.image_resolution[1]: + raise ValueError( + f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" + ) + self.paligemma_with_expert = PaliGemmaWithExpertModel( paligemma_config, action_expert_config, use_adarms=[False, True], precision=config.dtype, + image_size=config.image_resolution[0], ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) @@ -787,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, @@ -820,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step( @@ -879,6 +880,7 @@ class PI05Policy(PreTrainedPolicy): def __init__( self, config: PI05Config, + **kwargs, ): """ Args: @@ -1208,9 +1210,15 @@ class PI05Policy(PreTrainedPolicy): return actions - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """Run the batch through the model and compute the loss for training.""" + def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training. + Args: + batch: Training batch containing observations and actions. + reduction: How to reduce the loss. Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] @@ -1224,11 +1232,17 @@ class PI05Policy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] - loss = losses.mean() - loss_dict = { - "loss": loss.item(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } - return loss, loss_dict + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict diff --git a/src/lerobot/policies/sarm/README.md b/src/lerobot/policies/sarm/README.md new file mode 100644 index 000000000..e0e49834b --- /dev/null +++ b/src/lerobot/policies/sarm/README.md @@ -0,0 +1,14 @@ +## Paper + +https://arxiv.org/abs/2509.25358 + +## Citation + +```bibtex +@article{chen2025sarm, + title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation}, + author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp}, + journal={arXiv preprint arXiv:2509.25358}, + year={2025} +} +``` diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py new file mode 100644 index 000000000..5b6ea6e9b --- /dev/null +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -0,0 +1,870 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting. + +This script processes all frames in a dataset with SARM to compute progress values [0, 1]. +The results are saved as a parquet file that can be loaded during training for RA-BC weighting. + +Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only +need ~num_frames/30 queries instead of one per frame (~30x speedup). + +Usage: + # Full RA-BC computation with visualizations + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 + + # Faster computation with stride (compute every 5 frames, interpolate the rest) + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 \\ + --stride 5 + + # Visualize predictions only (no RA-BC computation) + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 \\ + --visualize-only \\ + --num-visualizations 5 + +The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'. +""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from tqdm import tqdm + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.sarm.modeling_sarm import SARMRewardModel +from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors +from lerobot.policies.sarm.sarm_utils import normalize_stage_tau + + +def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None: + """Read reward_model_path from parquet metadata if available.""" + if not parquet_path.exists(): + return None + try: + metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata + if metadata and b"reward_model_path" in metadata: + return metadata[b"reward_model_path"].decode() + except Exception: # nosec B110 + return None + return None + + +def load_sarm_resources( + dataset_repo_id: str, + reward_model_path: str, + device: str = "cuda", +) -> tuple[LeRobotDataset, SARMRewardModel, any]: + """ + Load SARM model, dataset, and preprocessor. + + Returns: + Tuple of (dataset, reward_model, preprocessor) + """ + logging.info(f"Loading model: {reward_model_path}") + reward_model = SARMRewardModel.from_pretrained(reward_model_path) + reward_model.config.device = device + reward_model.to(device).eval() + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + delta_indices = reward_model.config.observation_delta_indices + + logging.info(f"Loading dataset: {dataset_repo_id}") + temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True) + fps = temp_dataset.fps + + delta_timestamps = { + image_key: [idx / fps for idx in delta_indices], + state_key: [idx / fps for idx in delta_indices], + } + dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps) + logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames") + + preprocess, _ = make_sarm_pre_post_processors( + config=reward_model.config, + dataset_stats=dataset.meta.stats, + dataset_meta=dataset.meta, + ) + + return dataset, reward_model, preprocess + + +def to_numpy_image(img) -> np.ndarray: + """Convert image tensor to numpy uint8 (H, W, C).""" + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if img.ndim == 4: + # Take center frame for bidirectional sampling + img = img[img.shape[0] // 2] + if img.shape[0] in [1, 3]: + img = np.transpose(img, (1, 2, 0)) + if img.dtype != np.uint8: + # Handle normalized images (may have negative values or values > 1) + img = img.astype(np.float32) + img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1] + img = (img * 255).astype(np.uint8) + return img + + +def visualize_episode( + frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None +): + """Create visualization with progress plot, stage probabilities, and sample frames. + + Same as sarm_inference_visualization.py + """ + num_stages = stage_preds.shape[1] + colors = plt.cm.tab10(np.linspace(0, 1, num_stages)) + frame_indices = np.arange(len(progress_preds)) + + fig = plt.figure(figsize=(14, 12)) + gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3) + ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2]) + + # Progress plot + ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted") + ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB") + if gt_progress is not None: + ax_progress.plot( + frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth" + ) + ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5) + ax_progress.set_ylabel("Progress") + ax_progress.set_title(f'Task: "{title}"', fontweight="bold") + ax_progress.set_ylim(-0.05, 1.1) + ax_progress.legend(loc="upper left") + ax_progress.grid(True, alpha=0.3) + + # Stage predictions + ax_stages.stackplot( + frame_indices, + *[stage_preds[:, i] for i in range(num_stages)], + colors=colors, + alpha=0.8, + labels=stage_labels, + ) + if gt_stages is not None: + for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1: + ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5) + ax_stages.set_xlabel("Frame") + ax_stages.set_ylabel("Stage Probability") + ax_stages.set_ylim(0, 1) + ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8) + ax_stages.grid(True, alpha=0.3) + + # Sample frames + ax_frames.axis("off") + num_sample = 8 + sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int) + h, w = frames[0].shape[:2] + combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8) + for i, idx in enumerate(sample_indices): + frame = frames[idx] + if frame.shape[-1] == 1: + frame = np.repeat(frame, 3, axis=-1) + combined[:, i * w : (i + 1) * w] = frame + stage_name = stage_labels[np.argmax(stage_preds[idx])][:12] + ax_frames.text( + i * w + w / 2, + -10, + f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}", + ha="center", + va="top", + fontsize=7, + ) + ax_frames.imshow(combined) + ax_frames.set_title("Sample Frames", pad=20) + + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved: {output_path}") + + +def visualize_sarm_predictions( + dataset: LeRobotDataset, + reward_model: SARMRewardModel, + preprocess, + episode_indices: list[int], + head_mode: str, + output_dir: Path, + num_display_frames: int = 5, + stride: int = 1, +): + """ + Visualize SARM predictions for multiple episodes. + + Computes predictions for every frame by default. With stride > 1, computes predictions + every N frames and interpolates (progress + stage probabilities) for visualization. + + Args: + dataset: LeRobotDataset with delta_timestamps configured + reward_model: Loaded SARM model + preprocess: Preprocessor from make_sarm_pre_post_processors + episode_indices: List of episode indices to visualize + head_mode: "sparse", "dense", or "both" + output_dir: Directory to save visualizations + num_display_frames: Number of frames to display in thumbnail strip (default: 5) + stride: Compute predictions every N frames, interpolate the rest (default: 1) + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + dual_mode = reward_model.config.uses_dual_heads + device = reward_model.device + + # Center frame index for bidirectional sampling + target_idx = reward_model.config.n_obs_steps // 2 + + # Determine which heads to visualize + schemes_to_viz = [] + if head_mode in ("sparse", "both") or not dual_mode: + schemes_to_viz.append("sparse") + if head_mode in ("dense", "both") and dual_mode: + schemes_to_viz.append("dense") + + # Set preprocessor to eval mode to disable augmentations + if hasattr(preprocess, "eval"): + preprocess.eval() + for step in preprocess.steps: + if hasattr(step, "eval"): + step.eval() + + for episode_idx in episode_indices: + ep = dataset.meta.episodes[episode_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + task = dataset[ep_start].get("task", "perform the task") + num_frames = ep_end - ep_start + + # Select frames for display thumbnails (evenly sampled from begin to end) + display_indices = set( + [ + ep_start + int(i * (num_frames - 1) / (num_display_frames - 1)) + for i in range(num_display_frames) + ] + if num_frames >= num_display_frames + else list(range(ep_start, ep_end)) + ) + viz_frames = {} + + # Load display frames up-front (stride mode might skip them otherwise). + for frame_idx in display_indices: + sample = dataset[frame_idx] + viz_frames[frame_idx] = to_numpy_image(sample[image_key]) + + # Initialize storage for each scheme + scheme_data = {} + for scheme in schemes_to_viz: + num_stages = getattr(reward_model.config, f"num_{scheme}_stages") + scheme_data[scheme] = { + "viz_progress": np.full(num_frames, np.nan), + "viz_stages": np.full((num_frames, num_stages), np.nan), + "viz_gt_progress": np.full(num_frames, np.nan), + "viz_gt_stages": np.full(num_frames, np.nan), + "target_key": f"{scheme}_targets", + "num_stages": num_stages, + "temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"), + "subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"), + } + + if stride > 1: + logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating") + + # Process frames one at a time to avoid memory buildup + frame_indices = list(range(ep_start, ep_end, stride)) + if (ep_end - 1) not in frame_indices: + frame_indices.append(ep_end - 1) + frame_indices = sorted(set(frame_indices)) + + for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False): + local_idx = frame_idx - ep_start + sample = dataset[frame_idx] + + batch = { + image_key: sample[image_key], + "task": task, + "index": frame_idx, + "episode_index": episode_idx, + } + if state_key in sample: + batch[state_key] = sample[state_key] + + with torch.no_grad(): + processed = preprocess(batch) + video_features = processed["video_features"].to(device) + text_features = processed["text_features"].to(device) + state_features = processed.get("state_features") + if state_features is not None: + state_features = state_features.to(device) + lengths = processed.get("lengths") + + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + + # Ground truth + # In stride visualization mode, ground-truth plots can be misleading + # (only sparse points are available), so we skip GT. + if stride == 1 and sd["target_key"] in processed: + gt_target = processed[sd["target_key"]][0, target_idx].cpu().item() + sd["viz_gt_stages"][local_idx] = int(gt_target) + sd["viz_gt_progress"][local_idx] = normalize_stage_tau( + gt_target, + num_stages=sd["num_stages"], + temporal_proportions=sd["temporal_props"], + subtask_names=sd["subtask_names"], + ) + + # Predictions + reward, stage_probs = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + return_stages=True, + head_mode=scheme, + ) + + # Handle both tensor and numpy outputs + if isinstance(reward, torch.Tensor): + reward = reward.cpu().numpy() + stage_probs = stage_probs.cpu().numpy() + + if reward.ndim == 2: + sd["viz_progress"][local_idx] = reward[0, target_idx] + sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :] + else: + sd["viz_progress"][local_idx] = reward[target_idx] + sd["viz_stages"][local_idx] = stage_probs[target_idx, :] + + # Clear GPU memory after each frame + del processed, video_features, text_features + if state_features is not None: + del state_features + + torch.cuda.empty_cache() + + # Interpolate predictions back to per-frame arrays for smooth visualization. + if stride > 1: + all_local = np.arange(num_frames) + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + + valid = np.isfinite(sd["viz_progress"]) + valid_idx = np.where(valid)[0] + if valid_idx.size >= 1: + sd["viz_progress"] = interpolate_progress( + valid_idx, sd["viz_progress"][valid_idx], all_local + ) + + stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32) + for s in range(sd["num_stages"]): + stage_interp[:, s] = interpolate_progress( + valid_idx, sd["viz_stages"][valid_idx, s], all_local + ) + + stage_interp = np.clip(stage_interp, 0.0, 1.0) + row_sums = stage_interp.sum(axis=1, keepdims=True) + nz = row_sums.squeeze(-1) > 0 + stage_interp[nz] = stage_interp[nz] / row_sums[nz] + sd["viz_stages"] = stage_interp + else: + # No valid points: keep NaNs/zeros; visualization will be empty. + sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0) + + # Generate visualization for each head + ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)] + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])] + viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png" + + visualize_episode( + frames=np.array(ordered_viz_frames), + progress_preds=sd["viz_progress"], + stage_preds=sd["viz_stages"], + title=f"{task} (Episode {episode_idx})", + output_path=viz_path, + stage_labels=stage_labels, + gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None, + gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None, + ) + + # Clear memory between episodes + torch.cuda.empty_cache() + + logging.info(f"Visualizations saved to: {output_dir.absolute()}") + + +def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]: + """Generate all frame indices, ordered by offset for cache-friendly access. + + Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...] + This groups frames that share similar temporal windows together. + """ + num_frames = ep_end - ep_start + indices = [] + for offset in range(frame_gap): + for frame_rel in range(offset, num_frames, frame_gap): + indices.append(ep_start + frame_rel) + return indices + + +def interpolate_progress( + computed_indices: np.ndarray, + computed_values: np.ndarray, + all_indices: np.ndarray, +) -> np.ndarray: + """Linearly interpolate values to fill in gaps (robust to NaNs / edge cases).""" + computed_indices = np.asarray(computed_indices) + computed_values = np.asarray(computed_values) + all_indices = np.asarray(all_indices) + + mask = np.isfinite(computed_values) + if mask.sum() == 0: + return np.full(all_indices.shape, np.nan, dtype=np.float32) + if mask.sum() == 1: + return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32) + + out = np.interp(all_indices, computed_indices[mask], computed_values[mask]) + return out.astype(np.float32) + + +def compute_sarm_progress( + dataset_repo_id: str, + reward_model_path: str, + output_path: str | None = None, + head_mode: str = "sparse", + device: str = "cuda", + num_visualizations: int = 5, + output_dir: str = "./sarm_viz", + stride: int = 1, +): + """ + Compute SARM progress predictions for all frames in a dataset. + + Args: + dataset_repo_id: HuggingFace dataset repo ID or local path + reward_model_path: Path to pretrained SARM model + output_path: Path to save results. If None, saves to dataset's cache directory + head_mode: SARM head to use ("sparse", "dense", or "both") + device: Device to use for inference + num_visualizations: Number of episodes to visualize (0 to skip) + output_dir: Directory to save visualizations + stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame) + """ + dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device) + + # Set preprocessor to eval mode to disable augmentations + if hasattr(preprocess, "eval"): + preprocess.eval() + for step in preprocess.steps: + if hasattr(step, "eval"): + step.eval() + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + frame_gap = reward_model.config.frame_gap + num_episodes = dataset.num_episodes + total_frames = dataset.num_frames + logging.info(f"Processing {total_frames} frames across {num_episodes} episodes") + + # Determine which heads to compute + dual_mode = reward_model.config.uses_dual_heads + compute_sparse = head_mode in ("sparse", "both") or not dual_mode + compute_dense = head_mode in ("dense", "both") and dual_mode + + # Storage arrays + all_indices = [] + all_episode_indices = [] + all_frame_indices = [] + all_progress_sparse = [] if compute_sparse else None + all_progress_dense = [] if compute_dense else None + + if stride > 1: + logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest") + + # Process all episodes + for episode_idx in tqdm(range(num_episodes), desc="Episodes"): + ep = dataset.meta.episodes[episode_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + + # Get task description + task = dataset[ep_start].get("task", "perform the task") + + # Generate frames to compute (with stride applied) + all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap) + if stride > 1: + # Only compute every stride-th frame (relative to episode start) + compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0] + # Always include last frame for better interpolation at episode end + last_frame = ep_end - 1 + if last_frame not in compute_indices: + compute_indices.append(last_frame) + compute_indices = sorted(set(compute_indices)) + else: + compute_indices = all_ep_indices + + center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window + + # Dictionary to collect results + frame_results = {} + + for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False): + try: + sample = dataset[query_idx] + + batch = { + image_key: sample[image_key], + "task": task, + "index": query_idx, + "episode_index": episode_idx, + } + if state_key in sample: + batch[state_key] = sample[state_key] + + with torch.no_grad(): + processed = preprocess(batch) + video_features = processed["video_features"].to(device) + text_features = processed["text_features"].to(device) + state_features = processed.get("state_features") + if state_features is not None: + state_features = state_features.to(device) + lengths = processed.get("lengths") + + sparse_val = np.nan + dense_val = np.nan + + # Compute sparse prediction for center frame + if compute_sparse: + sparse_progress = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + head_mode="sparse", + ) + sparse_val = float( + sparse_progress[0, center_idx] + if sparse_progress.ndim == 2 + else sparse_progress[center_idx] + ) + + # Compute dense prediction for center frame + if compute_dense: + dense_progress = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + head_mode="dense", + ) + dense_val = float( + dense_progress[0, center_idx] + if dense_progress.ndim == 2 + else dense_progress[center_idx] + ) + + frame_results[query_idx] = (sparse_val, dense_val) + + except Exception as e: + logging.warning(f"Failed to process frame {query_idx}: {e}") + + # Interpolate to get values for all frames + computed_indices = np.array(sorted(frame_results.keys())) + computed_sparse = ( + np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None + ) + computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None + + # All frame indices for this episode + all_frame_idx_array = np.arange(ep_start, ep_end) + + if stride > 1 and len(computed_indices) > 1: + # Interpolate progress values + if compute_sparse: + interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array) + if compute_dense: + interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array) + else: + # No interpolation needed + interp_sparse = computed_sparse if compute_sparse else None + interp_dense = computed_dense if compute_dense else None + + # Store results for all frames + for i, frame_idx in enumerate(all_frame_idx_array): + local_idx = frame_idx - ep_start + all_indices.append(frame_idx) + all_episode_indices.append(episode_idx) + all_frame_indices.append(local_idx) + if compute_sparse: + if stride > 1 and len(computed_indices) > 1: + all_progress_sparse.append(float(interp_sparse[i])) + elif frame_idx in frame_results: + all_progress_sparse.append(frame_results[frame_idx][0]) + else: + all_progress_sparse.append(np.nan) + if compute_dense: + if stride > 1 and len(computed_indices) > 1: + all_progress_dense.append(float(interp_dense[i])) + elif frame_idx in frame_results: + all_progress_dense.append(frame_results[frame_idx][1]) + else: + all_progress_dense.append(np.nan) + + # Create output table + table_data = { + "index": np.array(all_indices, dtype=np.int64), + "episode_index": np.array(all_episode_indices, dtype=np.int64), + "frame_index": np.array(all_frame_indices, dtype=np.int64), + } + if compute_sparse: + table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32) + if compute_dense: + table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32) + + # Sort by index + df = pa.table(table_data).to_pandas() + df = df.sort_values("index").reset_index(drop=True) + final_table = pa.Table.from_pandas(df, preserve_index=False) + + # Add metadata with reward model path + metadata = {b"reward_model_path": reward_model_path.encode()} + final_table = final_table.replace_schema_metadata(metadata) + + # Determine output path + output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path) + + # Save + output_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(final_table, output_path) + logging.info(f"Saved {len(final_table)} frame progress values to {output_path}") + + # Print statistics + if "progress_sparse" in df.columns: + valid = df["progress_sparse"].dropna() + logging.info( + f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, " + f"min={valid.min():.4f}, max={valid.max():.4f}" + ) + + if "progress_dense" in df.columns: + valid = df["progress_dense"].dropna() + logging.info( + f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, " + f"min={valid.min():.4f}, max={valid.max():.4f}" + ) + + # Visualize episodes after processing + if num_visualizations > 0: + viz_episodes = list(range(min(num_visualizations, num_episodes))) + logging.info(f"Generating {len(viz_episodes)} visualizations...") + visualize_sarm_predictions( + dataset=dataset, + reward_model=reward_model, + preprocess=preprocess, + episode_indices=viz_episodes, + head_mode=head_mode, + output_dir=Path(output_dir), + stride=stride, + ) + + return output_path + + +def main(): + parser = argparse.ArgumentParser( + description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Full RA-BC computation with visualizations + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 + + # Visualize predictions only (no RA-BC computation) + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 \\ + --visualize-only \\ + --num-visualizations 10 + """, + ) + parser.add_argument( + "--dataset-repo-id", + type=str, + required=True, + help="HuggingFace dataset repo ID or local path", + ) + parser.add_argument( + "--reward-model-path", + type=str, + default=None, + help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)", + ) + parser.add_argument( + "--output-path", + type=str, + default=None, + help="Output path for parquet. If not set, saves to dataset's cache directory", + ) + parser.add_argument( + "--head-mode", + type=str, + default="sparse", + choices=["sparse", "dense", "both"], + help="SARM head to use (default: sparse)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use (default: cuda)", + ) + # Visualization options + parser.add_argument( + "--visualize-only", + action="store_true", + help="Only visualize SARM predictions (no RA-BC computation)", + ) + parser.add_argument( + "--num-visualizations", + type=int, + default=5, + help="Number of episodes to visualize (default: 5, set to 0 to skip)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./sarm_viz", + help="Output directory for visualizations (default: ./sarm_viz)", + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Upload progress file to the dataset repo on HuggingFace Hub", + default=True, + ) + parser.add_argument( + "--stride", + type=int, + default=1, + help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)", + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + # Try to get reward_model_path from parquet metadata if not provided + reward_model_path = args.reward_model_path + if reward_model_path is None: + # Load dataset to find parquet path + temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False) + parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet" + reward_model_path = get_reward_model_path_from_parquet(parquet_path) + if reward_model_path: + logging.info(f"Using reward model from parquet metadata: {reward_model_path}") + else: + raise ValueError( + "--reward-model-path is required (no existing parquet with model metadata found)" + ) + + # Handle visualize-only mode + if args.visualize_only: + dataset, reward_model, preprocess = load_sarm_resources( + args.dataset_repo_id, reward_model_path, args.device + ) + logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes") + viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes))) + visualize_sarm_predictions( + dataset=dataset, + reward_model=reward_model, + preprocess=preprocess, + episode_indices=viz_episodes, + head_mode=args.head_mode, + output_dir=Path(args.output_dir), + stride=args.stride, + ) + print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}") + return + + # Full RABC computation (compute_sarm_progress loads model/dataset itself) + output_path = compute_sarm_progress( + dataset_repo_id=args.dataset_repo_id, + reward_model_path=reward_model_path, + output_path=args.output_path, + head_mode=args.head_mode, + device=args.device, + num_visualizations=args.num_visualizations, + output_dir=args.output_dir, + stride=args.stride, + ) + + print(f"\nSARM progress values saved to: {output_path}") + + # Upload to Hub if requested + if args.push_to_hub: + from huggingface_hub import HfApi + + api = HfApi() + hub_path = "sarm_progress.parquet" + + print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}") + api.upload_file( + path_or_fileobj=str(output_path), + path_in_repo=hub_path, + repo_id=args.dataset_repo_id, + repo_type="dataset", + ) + print( + f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}" + ) + + print("\nTo use in training, add to your config:") + print(" use_rabc: true") + print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}") + print(" rabc_head_mode: sparse # or dense") + else: + print("\nTo use in training, add to your config:") + print(" use_rabc: true") + print(f" rabc_progress_path: {output_path}") + print(" rabc_head_mode: sparse # or dense") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py new file mode 100644 index 000000000..59cb352d5 --- /dev/null +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation. +Paper: https://arxiv.org/abs/2509.25358 +""" + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("sarm") +@dataclass +class SARMConfig(PreTrainedConfig): + """Configuration class for SARM (Stage-Aware Reward Modeling). + + Supports three annotation modes: + + 1. single_stage (default): No annotations needed. Uses the episode's task description + as a single stage covering the entire episode. + + 2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated + single sparse "task" stage covering the full episode. The dense head learns detailed + subtask progression while sparse provides overall task completion. + + 3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained) + annotations from VLM. Both heads are trained on their respective annotations. + + The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions + are loaded/generated during model initialization. + """ + + annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual" + n_obs_steps: int = 8 # Number of observation history steps + frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second) + max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation + + # Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property) + # During training with rewind: [obs_frames] + [rewind_frames] + # During inference: [obs_frames] only + + # Architecture params + image_dim: int = 512 + text_dim: int = 512 + hidden_dim: int = 768 + num_heads: int = 12 + num_layers: int = 8 + max_state_dim: int = 32 + drop_n_last_frames: int = 1 + batch_size: int = 64 + clip_batch_size: int = 64 + dropout: float = 0.1 + stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations + + rewind_probability: float = 0.8 + language_perturbation_probability: float = 0.2 + + # Sparse annotations (high-level stages) + num_sparse_stages: int = 1 + sparse_subtask_names: list | None = None + sparse_temporal_proportions: list | None = None + + # Dense annotations (fine-grained stages) + num_dense_stages: int | None = None + dense_subtask_names: list | None = None + dense_temporal_proportions: list | None = None + + pretrained_model_path: str | None = None + device: str | None = None + image_key: str = "observation.images.top" # Key for image used from the dataset + state_key: str = "observation.state" + + # Populated by the processor (video_features, state_features, text_features) + input_features: dict = field(default_factory=lambda: {}) + + # Output features (updated in __post_init__) + output_features: dict = field( + default_factory=lambda: { + "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD), + "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD), + } + ) + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "LANGUAGE": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + ) + + def __post_init__(self): + super().__post_init__() + + if self.annotation_mode not in ["single_stage", "dense_only", "dual"]: + raise ValueError( + f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}" + ) + + if self.annotation_mode == "single_stage": + # Use task description as stage name, full episode as one stage + self.num_sparse_stages = 1 + self.sparse_subtask_names = ["task"] + self.sparse_temporal_proportions = [1.0] + self.num_dense_stages = None + self.dense_subtask_names = None + self.dense_temporal_proportions = None + + elif self.annotation_mode == "dense_only": + self.num_sparse_stages = 1 + self.sparse_subtask_names = ["task"] + self.sparse_temporal_proportions = [1.0] + + self.input_features = {} + self.output_features = {} + + if self.image_key: + self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL) + + self.input_features[self.state_key] = PolicyFeature( + shape=(self.max_state_dim,), + type=FeatureType.STATE, + ) + + # Update output features based on annotation_mode + if self.annotation_mode in ["dense_only", "dual"]: + self.output_features["sparse_stage"] = PolicyFeature( + shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD + ) + self.output_features["sparse_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + dense_stages = self.num_dense_stages or self.num_sparse_stages + self.output_features["dense_stage"] = PolicyFeature( + shape=(self.num_frames, dense_stages), type=FeatureType.REWARD + ) + self.output_features["dense_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + else: + self.output_features["sparse_stage"] = PolicyFeature( + shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD + ) + self.output_features["sparse_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + + if self.max_rewind_steps >= self.n_obs_steps: + raise ValueError( + f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})" + ) + if self.num_sparse_stages < 1: + raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}") + if ( + self.annotation_mode in ["dense_only", "dual"] + and self.num_dense_stages is not None + and self.num_dense_stages < 2 + ): + raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}") + + def get_optimizer_preset(self) -> AdamWConfig: + """Get default optimizer configuration for SARM training.""" + return AdamWConfig( + lr=5e-5, + weight_decay=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + """Get default learning rate scheduler configuration.""" + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=5e-5, + decay_lr=5e-6, + num_warmup_steps=500, + num_decay_steps=50000, + ) + + def validate_features(self) -> None: + pass + + @property + def uses_dual_heads(self) -> bool: + """Whether the model uses dual heads (dense_only or dual annotation modes).""" + return self.annotation_mode in ["dense_only", "dual"] + + @property + def num_frames(self) -> int: + """Total number of frames in sequence. + + For training: 1 + n_obs_steps + max_rewind_steps + The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)] + """ + return 1 + self.n_obs_steps + self.max_rewind_steps + + @property + def max_length(self) -> int: + return self.num_frames + + @property + def observation_delta_indices(self) -> list[int]: + """Bidirectional frame sampling centered on target frame. + + Example with n_obs_steps=8, gap=30: + Before: [-120, -90, -60, -30] (4 frames) + Current: [0] (1 frame) + After: [30, 60, 90, 120] (4 frames) + Total: 9 frames + """ + half_steps = self.n_obs_steps // 2 + + past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)] + future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)] + obs_deltas = past_deltas + [0] + future_deltas + + # Rewind placeholders + rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)] + + return obs_deltas + rewind_deltas + + @property + def action_delta_indices(self) -> None: + """SARM is a reward model, not an action policy.""" + return None + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py new file mode 100644 index 000000000..a88b2ad64 --- /dev/null +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -0,0 +1,793 @@ +#!/usr/bin/env python + +# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation. + +Paper: https://arxiv.org/abs/2509.25358 + +- StageTransformer: Predicts stage classification (sparse/dense) +- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage +""" + +import json +import logging +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import ( + normalize_stage_tau, + pad_state_to_max_dim, +) + + +class StageTransformer(nn.Module): + """ + Stage classification transformer for SARM. + + Predicts which stage/subtask the current frame belongs to. + Supports both sparse (high-level) and dense (fine-grained) annotation schemes. + + Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D) + Output: stage logits (B, T, num_classes) + """ + + def __init__( + self, + d_model: int = 512, + vis_emb_dim: int = 512, + text_emb_dim: int = 512, + state_dim: int = 32, + n_layers: int = 6, + n_heads: int = 8, + dropout: float = 0.1, + num_cameras: int = 1, + num_classes_sparse: int = 4, + num_classes_dense: int = 8, + ): + super().__init__() + self.d_model = d_model + self.num_cameras = num_cameras + + # Projections + self.lang_proj = nn.Linear(text_emb_dim, d_model) + self.visual_proj = nn.Linear(vis_emb_dim, d_model) + self.state_proj = nn.Linear(state_dim, d_model) + + # Encoder + enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True) + self.transformer = nn.TransformerEncoder(enc_layer, n_layers) + + # Positional bias on first visual frame + self.first_pos = nn.Parameter(torch.zeros(1, d_model)) + + # Shared fusion MLP + # Fuses (num_cameras + 2) streams: cameras + lang + state + fused_in = d_model * (num_cameras + 2) + self.fusion_backbone = nn.Sequential( + nn.LayerNorm(fused_in), + nn.Linear(fused_in, d_model), + nn.ReLU(), + ) + + # Scheme-specific heads + self.heads = nn.ModuleDict( + { + "sparse": nn.Linear(d_model, num_classes_sparse), + "dense": nn.Linear(d_model, num_classes_dense), + } + ) + + def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803 + """ + Prepare language embeddings for fusion. + + Accepts lang_emb of shape: + - (B, text_emb_dim) -> broadcast across time + - (B, T, text_emb_dim) -> per-timestep (dense annotation mode) + + Returns: (B, 1, T, D) + """ + if lang_emb.dim() == 3: + # (B, T, E) -> (B, T, D) -> (B, 1, T, D) + lang_proj = self.lang_proj(lang_emb).unsqueeze(1) + else: + # (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D) + lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D) + return lang_proj + + def forward( + self, + img_seq: torch.Tensor, # (B, N, T, vis_emb_dim) + lang_emb: torch.Tensor, # (B, E) or (B, T, E) + state: torch.Tensor, # (B, T, state_dim) + lengths: torch.Tensor, # (B,) - valid sequence lengths + scheme: str = "sparse", # "sparse" or "dense" + ) -> torch.Tensor: + """ + Forward pass for stage classification. + + Args: + img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras + lang_emb: Language embeddings (B, E) or (B, T, E) for dense + state: State features (B, T, state_dim) + lengths: Valid sequence lengths (B,) for masking + scheme: "sparse" or "dense" for head selection + + Returns: + Stage logits (B, T, num_classes) + """ + assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}." + + B, N, T, _ = img_seq.shape # noqa: N806 + D = self.d_model # noqa: N806 + device = img_seq.device + + # Project inputs + vis_proj = self.visual_proj(img_seq) # (B, N, T, D) + state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D) + lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D) + + # Concatenate streams + # cameras + lang + state -> (B, N+2, T, D) + x = torch.cat([vis_proj, lang_proj, state_proj], dim=1) + + # Add positional bias to first visual frame + x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos + + # Flatten to tokens for Transformer + x_tokens = x.view(B, (N + 2) * T, D) + L = x_tokens.size(1) # noqa: N806 + + # Create padding mask + base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T) + mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T) + + # Create causal mask + causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1) + + # Encode + h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True) + + # Reshape and fuse + h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D) + fused = self.fusion_backbone(h) # (B, T, D) + + # Scheme-specific logits + logits = self.heads[scheme](fused) # (B, T, num_classes) + return logits + + +class SubtaskTransformer(nn.Module): + """ + Subtask progress regression transformer for SARM. + + Predicts within-stage normalized progress (tau) conditioned on stage prior. + The stage prior is a one-hot encoding passed from StageTransformer predictions. + + Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D) + Output: tau predictions (B, T) in [0, 1] + """ + + def __init__( + self, + d_model: int = 512, + vis_emb_dim: int = 512, + text_emb_dim: int = 512, + state_dim: int = 32, + n_layers: int = 6, + n_heads: int = 8, + dropout: float = 0.1, + num_cameras: int = 1, + ): + super().__init__() + self.d_model = d_model + self.num_cameras = num_cameras + + # Projections + self.lang_proj = nn.Linear(text_emb_dim, d_model) + self.visual_proj = nn.Linear(vis_emb_dim, d_model) + self.state_proj = nn.Linear(state_dim, d_model) + + # Encoder + enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True) + self.transformer = nn.TransformerEncoder(enc, n_layers) + + # Learned bias on first visual frame + self.first_pos = nn.Parameter(torch.zeros(1, d_model)) + + # Shared fusion backbone + # Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb + fused_in = d_model * (num_cameras + 3) + self.fusion_backbone = nn.Sequential( + nn.LayerNorm(fused_in), + nn.Linear(fused_in, d_model), + nn.ReLU(), + ) + + # Scheme-specific regression heads + self.heads = nn.ModuleDict( + { + "sparse": nn.Linear(d_model, 1), + "dense": nn.Linear(d_model, 1), + } + ) + + def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803 + """ + Prepare language embeddings for fusion. + """ + if lang_emb.dim() == 3: + # (B, T, E) -> (B, T, D) -> (B, 1, T, D) + return self.lang_proj(lang_emb).unsqueeze(1) + else: + # (B, E) -> (B, 1, 1, D) -> (B, 1, T, D) + return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D) + + def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor: + """ + Deterministic projection of one-hot stage to d_model by pad/truncate. + + Args: + stage_prior: One-hot stage embedding (B, 1, T, C) + + Returns: + Projected stage embedding (B, 1, T, d_model) + """ + B, one, T, C = stage_prior.shape # noqa: N806 + D = self.d_model # noqa: N806 + if D == C: + return stage_prior + elif D > C: + pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype) + return torch.cat([stage_prior, pad], dim=-1) + else: + return stage_prior[..., :D] + + def forward( + self, + img_seq: torch.Tensor, # (B, N, T, vis_emb_dim) + lang_emb: torch.Tensor, # (B, E) or (B, T, E) + state: torch.Tensor, # (B, T, state_dim) + lengths: torch.Tensor, # (B,) - valid sequence lengths + stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb + scheme: str = "sparse", # "sparse" or "dense" + ) -> torch.Tensor: + """ + Forward pass for subtask progress regression. + + Args: + img_seq: Image embeddings (B, N, T, vis_emb_dim) + lang_emb: Language embeddings (B, E) or (B, T, E) + state: State features (B, T, state_dim) + lengths: Valid sequence lengths (B,) for masking + stage_prior: One-hot stage prior (B, 1, T, num_classes) + scheme: "sparse" or "dense" for head selection + + Returns: + Tau predictions (B, T) in [0, 1] via sigmoid + """ + assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}." + + B, N, T, _ = img_seq.shape # noqa: N806 + D = self.d_model # noqa: N806 + device = img_seq.device + + # Project inputs + vis_proj = self.visual_proj(img_seq) # (B, N, T, D) + state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D) + lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D) + stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D) + + # Concatenate all streams + # cameras + lang + state + stage_emb -> (B, N+3, T, D) + x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1) + + # Add positional bias to first visual frame + x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos + + # Flatten to tokens + x_tokens = x.view(B, (N + 3) * T, D) + L = x_tokens.size(1) # noqa: N806 + + # Create padding mask + base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) + mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T) + + # Create causal mask + causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1) + + # Encode + h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True) + + # Reshape and fuse + h = h.view(B, N + 3, T, D) + h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D) + fused = self.fusion_backbone(h_flat) # (B, T, D) + + # Scheme-specific regression head -> sigmoid + r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T) + return r + + +def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor: + """ + Generate one-hot stage embeddings from targets. + + Args: + num_classes: Number of stage classes + targets: Target values (B, T) where integer part is stage index + + Returns: + One-hot stage embedding (B, 1, T, num_classes) + """ + # Integer part of float targets -> [0, C-1] + idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T) + C = num_classes # noqa: N806 + # Identity-lookup one-hot + stage_onehot = torch.eye(C, device=targets.device)[idx] # (B, T, C) + stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C) + return stage_onehot + + +class SARMRewardModel(PreTrainedPolicy): + """ + SARM Reward Model for stage-aware task completion rewards. + + Uses two separate transformer models: + - StageTransformer: Classifies which stage/subtask + - SubtaskTransformer: Predicts within-stage progress (tau) + + Training uses 75%/25% GT/predicted stage conditioning (teacher forcing). + """ + + name = "sarm" + config_class = SARMConfig + + def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None): + super().__init__(config, dataset_stats) + config.validate_features() + self.config = config + self.dataset_stats = dataset_stats + self.device = torch.device( + config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu" + ) + + # Load temporal proportions based on annotation_mode + if config.annotation_mode == "single_stage": + logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}") + elif dataset_meta is not None: + self._load_temporal_proportions(dataset_meta) + + # Create two separate models + self.stage_model = StageTransformer( + d_model=config.hidden_dim, + vis_emb_dim=config.image_dim, + text_emb_dim=config.text_dim, + state_dim=config.max_state_dim, + n_layers=config.num_layers, + n_heads=config.num_heads, + dropout=config.dropout, + num_cameras=1, # Single camera for now + num_classes_sparse=config.num_sparse_stages, + num_classes_dense=config.num_dense_stages or config.num_sparse_stages, + ) + + self.subtask_model = SubtaskTransformer( + d_model=config.hidden_dim, + vis_emb_dim=config.image_dim, + text_emb_dim=config.text_dim, + state_dim=config.max_state_dim, + n_layers=config.num_layers, + n_heads=config.num_heads, + dropout=config.dropout, + num_cameras=1, + ) + + self.stage_model.to(self.device) + self.subtask_model.to(self.device) + + # GT/predicted stage ratio for teacher forcing + self.gt_stage_ratio = 0.75 + + if config.uses_dual_heads: + logging.info( + f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, " + f"{config.num_dense_stages} dense stages" + ) + else: + logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages") + + logging.info(f"SARM initialized on {self.device}") + + def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]: + """Load temporal proportions from a JSON file (preserving order).""" + if not path.exists(): + raise ValueError( + f"{annotation_type.capitalize()} temporal proportions not found at {path}. " + f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations." + ) + with open(path) as f: + proportions_dict = json.load(f) + names = list(proportions_dict.keys()) + logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}") + logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}") + return names, [proportions_dict[name] for name in names] + + def _load_temporal_proportions(self, dataset_meta) -> None: + """Load temporal proportions based on annotation_mode.""" + meta_path = dataset_meta.root / "meta" + + if self.config.annotation_mode == "dual": + names, props = self._load_proportions_from_json( + meta_path / "temporal_proportions_sparse.json", "sparse" + ) + ( + self.config.num_sparse_stages, + self.config.sparse_subtask_names, + self.config.sparse_temporal_proportions, + ) = len(names), names, props + + if self.config.annotation_mode in ["dense_only", "dual"]: + names, props = self._load_proportions_from_json( + meta_path / "temporal_proportions_dense.json", "dense" + ) + ( + self.config.num_dense_stages, + self.config.dense_subtask_names, + self.config.dense_temporal_proportions, + ) = len(names), names, props + if self.config.annotation_mode == "dense_only": + logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}") + + def to(self, device): + """Override to method to ensure all components move together.""" + super().to(device) + self.device = device if isinstance(device, torch.device) else torch.device(device) + self.stage_model.to(device) + self.subtask_model.to(device) + return self + + @torch.no_grad() + def calculate_rewards( + self, + text_embeddings: np.ndarray | torch.Tensor, + video_embeddings: np.ndarray | torch.Tensor, + state_features: np.ndarray | torch.Tensor | None = None, + lengths: np.ndarray | torch.Tensor | None = None, + return_all_frames: bool = False, + return_stages: bool = False, + return_confidence: bool = False, + head_mode: str | None = "sparse", + frame_index: int | None = None, + ) -> np.ndarray | tuple: + """ + Calculate rewards for given text, video, and state representations. + + This is the canonical method for SARM reward computation, used for: + - Inference/visualization + - RA-BC weight computation + + Args: + text_embeddings: Encoded text representations (batch_size, 512) + video_embeddings: Encoded video representations (batch_size, num_frames, 512) + state_features: Joint state features (batch_size, num_frames, state_dim) + lengths: Valid sequence lengths (batch_size,) + return_all_frames: If True, return rewards for all frames + return_stages: If True, also return stage predictions + return_confidence: If True, also return stage confidence + head_mode: Which head to use ("sparse" or "dense") + frame_index: Index of the target frame to extract (default: n_obs_steps). + + Returns: + Rewards and optionally stage probs/confidence. + """ + if isinstance(text_embeddings, np.ndarray): + text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32) + if isinstance(video_embeddings, np.ndarray): + video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) + if state_features is not None and isinstance(state_features, np.ndarray): + state_features = torch.tensor(state_features, dtype=torch.float32) + + # Handle single sample case + if text_embeddings.dim() == 1: + text_embeddings = text_embeddings.unsqueeze(0) + video_embeddings = video_embeddings.unsqueeze(0) + if state_features is not None: + state_features = state_features.unsqueeze(0) + single_sample = True + else: + single_sample = False + + batch_size = video_embeddings.shape[0] + seq_len = video_embeddings.shape[1] + + scheme = head_mode + + # Default lengths if not provided + if lengths is None: + lengths = torch.full((batch_size,), seq_len, dtype=torch.int32) + elif isinstance(lengths, np.ndarray): + lengths = torch.tensor(lengths, dtype=torch.int32) + + # Reshape video to (B, N, T, D) for multi-camera format + # Currently single camera: (B, T, D) -> (B, 1, T, D) + img_seq = video_embeddings.unsqueeze(1).to(self.device) + lang_emb = text_embeddings.to(self.device) + state = ( + state_features.to(self.device) + if state_features is not None + else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device) + ) + lens = lengths.to(self.device) + + # Pad state to max_state_dim + state = pad_state_to_max_dim(state, self.config.max_state_dim) + + # Get num_classes for this scheme + num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages + + # Run stage model + stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme) + stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes) + stage_idx = stage_probs.argmax(dim=-1) # (B, T) + stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T) + + # Create one-hot stage prior + stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C) + stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C) + + # Run subtask model + tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme) + + # Compute final reward: stage + tau + raw_reward = stage_idx.float() + tau_pred # (B, T) + + # Normalize to [0, 1] using temporal proportions for proper weighting + if scheme == "sparse": + normalized_reward = normalize_stage_tau( + raw_reward, + num_stages=num_classes, + temporal_proportions=self.config.sparse_temporal_proportions, + subtask_names=self.config.sparse_subtask_names, + ) + else: + normalized_reward = normalize_stage_tau( + raw_reward, + num_stages=num_classes, + temporal_proportions=self.config.dense_temporal_proportions, + subtask_names=self.config.dense_subtask_names, + ) + + # Default frame index is n_obs_steps (last observation frame) + if frame_index is None: + frame_index = self.config.n_obs_steps + + # Prepare outputs (batch mode or no smoothing) + if return_all_frames: + rewards = normalized_reward.cpu().numpy() + else: + rewards = normalized_reward[:, frame_index].cpu().numpy() + + if single_sample: + rewards = rewards[0] if not return_all_frames else rewards[0] + + outputs = [rewards] + if return_stages: + probs = stage_probs.cpu().numpy() + if single_sample: + probs = probs[0] + outputs.append(probs) + if return_confidence: + conf = stage_conf.cpu().numpy() + if single_sample: + conf = conf[0] + outputs.append(conf) + + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + def train(self, mode: bool = True): + """Set training mode for both models.""" + super().train(mode) + self.stage_model.train(mode) + self.subtask_model.train(mode) + return self + + def eval(self): + """Set evaluation mode for both models.""" + return self.train(False) + + def parameters(self): + """Override to return trainable parameters from both models.""" + from itertools import chain + + return chain(self.stage_model.parameters(), self.subtask_model.parameters()) + + def get_optim_params(self): + """Override to return optimizer parameters from both models.""" + return self.parameters() + + def reset(self): + """Required by PreTrainedPolicy but not used for reward models.""" + pass + + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Required by PreTrainedPolicy but not used for reward models.""" + raise NotImplementedError("SARM model does not predict action chunks") + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Required by PreTrainedPolicy but not used for SARM.""" + raise NotImplementedError("SARM model does not select actions") + + def _train_step( + self, + img_emb: torch.Tensor, # (B, N, T, D) + lang_emb: torch.Tensor, # (B, E) or (B, T, E) + state: torch.Tensor, # (B, T, state_dim) + lengths: torch.Tensor, # (B,) + targets: torch.Tensor, # (B, T) - format: stage.tau + scheme: str, + ) -> dict[str, torch.Tensor]: + """ + Single training step for one annotation scheme. + + Implements 75%/25% GT/predicted stage conditioning. + + Args: + img_emb: Image embeddings (B, N, T, D) + lang_emb: Language embeddings + state: State features + lengths: Valid sequence lengths + targets: Target values where floor=stage, remainder=tau + scheme: "sparse" or "dense" + + Returns: + Dict with stage_loss, subtask_loss, total_loss + """ + num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages + + # Ground truth: stage (integer) and tau (fractional) + # Clamp stage indices to valid range [0, num_classes-1] to handle edge cases + # where targets may exceed expected range (e.g., frames between subtasks) + gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T) + gt_tau = torch.remainder(targets, 1.0) # (B, T) + + # Run stage model + stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme) + + # 75%/25% GT/predicted stage conditioning + if random.random() < self.gt_stage_ratio: + # Mode 1: Use ground truth stage -> one-hot + stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C) + else: + # Mode 2: Use predicted stage argmax -> one-hot + stage_idx = stage_pred.argmax(dim=-1) # (B, T) + stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C) + stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C) + + # Run subtask model with stage prior + tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme) + + # Compute losses + stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean") + subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean") + + return { + "stage_loss": stage_loss, + "subtask_loss": subtask_loss, + "total_loss": stage_loss + subtask_loss, + } + + def forward(self, batch): + """ + Forward pass for SARM reward model training. + + Uses stage+tau target format where: + - Integer part = stage index + - Fractional part = within-stage progress (tau) + + Training uses 75%/25% GT/predicted stage conditioning. + + Args: + batch: Dictionary with 'observation' containing: + - 'video_features': (B, T, 512) pre-encoded video features + - 'text_features': (B, 512) or (B, T, 512) text features + - 'state_features': (B, T, state_dim) joint state features + - 'lengths': (B,) valid sequence lengths + - 'sparse_targets': (B, T) sparse targets (stage.tau format) + - 'dense_targets': (B, T) dense targets (optional, for dual mode) + + Returns: + Tuple of (total_loss, output_dict with loss components) + """ + observation = batch.get("observation", batch) + + # Extract features + video_features = observation["video_features"].to(self.device) + text_features = observation["text_features"].to(self.device) + state_features = observation.get("state_features") + if state_features is not None: + state_features = state_features.to(self.device) + + batch_size = video_features.shape[0] + seq_len = video_features.shape[1] + + # Get lengths (default to full sequence) + lengths = observation.get("lengths") + if lengths is None: + lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device) + else: + lengths = lengths.to(self.device) + + # Reshape video to (B, N, T, D) - single camera + img_emb = video_features.unsqueeze(1) + + # Pad state to max_state_dim + if state_features is None: + state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device) + else: + state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim) + + output_dict = {} + total_loss = torch.tensor(0.0, device=self.device) + + # Sparse training (always) + sparse_targets = observation.get("sparse_targets") + if sparse_targets is None: + # Try legacy format + sparse_targets = observation.get("targets") + if sparse_targets is None: + raise ValueError("sparse_targets (or targets) is required for SARM training") + sparse_targets = sparse_targets.to(self.device) + + sparse_result = self._train_step( + img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse" + ) + output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item() + output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item() + total_loss = total_loss + sparse_result["total_loss"] + + # Dense training (if dual mode) + if self.config.uses_dual_heads: + dense_targets = observation.get("dense_targets") + if dense_targets is not None: + dense_targets = dense_targets.to(self.device) + dense_result = self._train_step( + img_emb, text_features, state_features, lengths, dense_targets, scheme="dense" + ) + output_dict["dense_stage_loss"] = dense_result["stage_loss"].item() + output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item() + total_loss = total_loss + dense_result["total_loss"] + + output_dict["total_loss"] = total_loss.item() + return total_loss, output_dict + + +def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor: + """Compute cross-entropy loss for stage classification.""" + _, _, num_stages = stage_logits.shape + stage_logits_flat = stage_logits.reshape(-1, num_stages) + # Clamp target stage indices to valid range [0, num_stages-1] + target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1) + return F.cross_entropy(stage_logits_flat, target_stages_flat) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py new file mode 100644 index 000000000..5c617282a --- /dev/null +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SARM Processor for encoding images/text and generating stage+tau targets.""" + +import random +from typing import Any + +import numpy as np +import pandas as pd +import torch +from faker import Faker +from PIL import Image +from transformers import CLIPModel, CLIPProcessor + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import ( + apply_rewind_augmentation, + compute_absolute_indices, + find_stage_and_tau, + pad_state_to_max_dim, +) +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + RenameObservationsProcessorStep, +) +from lerobot.processor.converters import ( + from_tensor_to_numpy, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +class SARMEncodingProcessorStep(ProcessorStep): + """ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM.""" + + def __init__( + self, + config: SARMConfig, + image_key: str | None = None, + dataset_meta=None, + dataset_stats: dict | None = None, + ): + super().__init__() + self.config = config + self.image_key = image_key or config.image_key + self.dataset_meta = dataset_meta + self.dataset_stats = dataset_stats + self.annotation_mode = config.annotation_mode + + # Helper to create temporal proportions dict + def make_props_dict(names, props): + return dict(zip(names, props, strict=True)) if names and props else None + + # Sparse annotations (always needed) + self.sparse_temporal_proportions = make_props_dict( + config.sparse_subtask_names, config.sparse_temporal_proportions + ) + self.sparse_subtask_names = config.sparse_subtask_names + + # Dense annotations (only for dual mode) + self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None + self.dense_temporal_proportions = ( + make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions) + if config.uses_dual_heads + else None + ) + + self.device = torch.device( + self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu" + ) + + self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) + self.clip_model.to(self.device) + self.clip_model.eval() + + self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"] + self.fake = Faker() + + def _find_episode_for_frame(self, frame_idx: int) -> int: + """Find the episode index for a given frame index.""" + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + return ep_idx + return 0 + + def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray: + """Get episode indices for each frame index.""" + if episode_index is None: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index))) + + # If single episode but multiple frames, compute episode for each frame + if len(episode_indices) == 1 and len(frame_indices) > 1: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + return episode_indices + + def _generate_perturbed_task(self) -> str: + """Generate a random perturbed task string for language perturbation.""" + num_words = random.randint(1, 5) + verb = random.choice(self.verbs) + phrase = " ".join([verb] + self.fake.words(nb=num_words)) + return phrase + + def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]: + """Get global subtask names and temporal proportions for an annotation type.""" + if annotation_type == "dense": + return self.dense_subtask_names, self.dense_temporal_proportions + return self.sparse_subtask_names, self.sparse_temporal_proportions + + def _load_episode_annotations( + self, + ep_idx: int, + episodes_df: pd.DataFrame | None, + annotation_type: str, + global_names: list[str], + ) -> tuple[list | None, list | None, list | None]: + """Load subtask annotations for an episode from DataFrame.""" + # Single-stage mode: (linear progress 0β†’1) + if episodes_df is None or len(global_names) == 1: + return None, None, None + + # Resolve column name with fallback + def col(suffix): + prefixed = f"{annotation_type}_{suffix}" + return prefixed if prefixed in episodes_df.columns else suffix + + col_names = col("subtask_names") + if col_names not in episodes_df.columns or ep_idx >= len(episodes_df): + return None, None, None + + subtask_names = episodes_df.loc[ep_idx, col_names] + if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + return None, None, None + + return ( + subtask_names, + episodes_df.loc[ep_idx, col("subtask_start_frames")], + episodes_df.loc[ep_idx, col("subtask_end_frames")], + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Encode images, text, and normalize states in the transition. + + Implements SARM training data preparation: + - Applies language perturbation (20% probability) + - Applies rewind augmentation (80% probability) + - Generates stage+tau targets for all frames + - Outputs lengths tensor for valid sequence masking + """ + new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition) + observation = new_transition.get(TransitionKey.OBSERVATION) + comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + frame_index = comp_data.get("index") + episode_index = comp_data.get("episode_index") + + if frame_index is None: + raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA") + if episode_index is None: + raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA") + + frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) + episode_indices = self._get_episode_indices(frame_indices, episode_index) + + image = observation.get(self.image_key) + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + + # If 4D (T, C, H, W) from delta_timestamps, add batch dim + # If 3D (C, H, W) single frame, add batch and time dims + if image.ndim == 4: + image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W) + elif image.ndim == 3: + image = image[np.newaxis, np.newaxis, ...] # (C, H, W) -> (1, 1, C, H, W) + + batch_size = image.shape[0] + total_frames = image.shape[1] # Should be 13: 9 obs + 4 rewind placeholders + n_obs_steps = self.config.n_obs_steps + max_rewind_steps = self.config.max_rewind_steps + n_obs_frames = 1 + n_obs_steps # 9 observation frames (including current) + + # Rewind augmentation + rewind_steps = torch.zeros(batch_size, dtype=torch.int32) + apply_rewind = self.training and random.random() < self.config.rewind_probability + + if apply_rewind and self.dataset_meta is not None: + for b_idx, (ep_idx, frame_idx) in enumerate( + zip(episode_indices.tolist(), frame_indices.tolist(), strict=True) + ): + ep_idx, frame_idx = int(ep_idx), int(frame_idx) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + + rewind_step, _ = apply_rewind_augmentation( + frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap + ) + rewind_steps[b_idx] = rewind_step + + # Compute valid lengths: n_obs_frames + rewind_steps + lengths = n_obs_frames + rewind_steps # (B,) + + # Apply rewind masking to images + # For frames beyond valid length, we mask with zeros (or copy last valid frame) + for b_idx in range(batch_size): + valid_len = lengths[b_idx].item() + if valid_len < total_frames: + image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length + + # Encode images with CLIP + video_features = self._encode_images_batch(image) + observation["video_features"] = video_features + + state_key = self.config.state_key + state_data = observation.get(state_key) + + if isinstance(state_data, torch.Tensor): + state_tensor = state_data.float() + else: + state_tensor = torch.tensor(state_data, dtype=torch.float32) + + if state_tensor.ndim == 2: + state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D) + elif state_tensor.ndim == 1: + state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D) + + # Apply same rewind masking to state + for b_idx in range(batch_size): + valid_len = lengths[b_idx].item() + if valid_len < state_tensor.shape[1]: + state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length + + observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) + + task = comp_data.get("task") + if isinstance(task, list): + task = task[0] if task else "" + + # Apply language perturbation during training (20% probability) + # When perturbed, targets will be zeroed to train model to output low values for irrelevant text + apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability + if apply_perturbation: + task = self._generate_perturbed_task() + + # Encode text with CLIP + observation["text_features"] = self._encode_text_clip(task, batch_size) + + # Store lengths for model + observation["lengths"] = lengths + + # When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss + if self.dataset_meta is not None: + episodes_df = None + if self.sparse_subtask_names != ["task"]: + episodes_df = self.dataset_meta.episodes.to_pandas() + + # Generate sparse targets + if self.sparse_temporal_proportions is not None: + if apply_perturbation: + # Zero targets when language is perturbed + sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + else: + sparse_targets = self._compute_batch_targets( + frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse" + ) + observation["sparse_targets"] = sparse_targets + + # Generate dense targets (for dual mode) + if self.config.uses_dual_heads and self.dense_temporal_proportions is not None: + if apply_perturbation: + # Zero targets when language is perturbed + dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + else: + dense_targets = self._compute_batch_targets( + frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense" + ) + observation["dense_targets"] = dense_targets + + new_transition[TransitionKey.OBSERVATION] = observation + return new_transition + + def _compute_batch_targets( + self, + frame_indices: np.ndarray, + episode_indices: np.ndarray, + lengths: torch.Tensor, + rewind_steps: torch.Tensor, + episodes_df: pd.DataFrame | None, + annotation_type: str, + ) -> torch.Tensor: + """Compute stage+tau targets for a batch of samples.""" + batch_size = len(frame_indices) + n_obs_steps = self.config.n_obs_steps + max_rewind_steps = self.config.max_rewind_steps + total_frames = 1 + n_obs_steps + max_rewind_steps + frame_gap = self.config.frame_gap + + global_names, temporal_props = self._get_annotation_config(annotation_type) + targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + + for b_idx in range(batch_size): + ep_idx = int(episode_indices[b_idx]) + frame_idx = int(frame_indices[b_idx]) + + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + ep_length = ep_end - ep_start + + subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations( + ep_idx, episodes_df, annotation_type, global_names + ) + + # Compute observation frame indices + obs_indices, _ = compute_absolute_indices( + frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap + ) + obs_indices = obs_indices.tolist() + + # Compute targets for observation frames + for t_idx, abs_idx in enumerate(obs_indices): + rel_frame = abs_idx - ep_start + targets[b_idx, t_idx] = find_stage_and_tau( + rel_frame, + ep_length, + subtask_names, + subtask_start_frames, + subtask_end_frames, + global_names, + temporal_props, + return_combined=True, + ) + + # Compute targets for rewind frames (if any) + rewind_step = rewind_steps[b_idx].item() + if rewind_step > 0: + _, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps, + max_rewind_steps, + frame_gap=frame_gap, + rewind_step=rewind_step, + ) + + for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]): + rel_frame = max(0, abs_idx - ep_start) + targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau( + rel_frame, + ep_length, + subtask_names, + subtask_start_frames, + subtask_end_frames, + global_names, + temporal_props, + return_combined=True, + ) + + return targets + + @property + def training(self) -> bool: + return getattr(self, "_training_mode", True) + + def train(self, mode: bool = True): + """Set training mode for augmentation decisions.""" + self._training_mode = mode + return self + + def eval(self): + """Set evaluation mode (disable augmentations).""" + return self.train(False) + + @torch.no_grad() + def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor: + """Encode a batch of images using CLIP. + + Args: + images: Batched images with shape: (B, T, C, H, W) + + Returns: + Encoded feature vectors with shape (B, T, 512) + """ + + batch_size, seq_length = images.shape[0], images.shape[1] + images = images.reshape(batch_size * seq_length, *images.shape[2:]) + + num_frames = images.shape[0] + images_list = [] + for i in range(num_frames): + img = images[i] + if img.shape[0] in [1, 3]: # Channel first (C, H, W) + img = img.transpose(1, 2, 0) + + # Handle single channel + if img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + if img.dtype != np.uint8: + img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) + + images_list.append(Image.fromarray(img)) + + all_embeddings = [] + for i in range(0, num_frames, self.config.clip_batch_size): + batch_imgs = images_list[i : i + self.config.clip_batch_size] + + inputs = self.clip_processor(images=batch_imgs, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get image embeddings + embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() + + # Handle single frame case + if embeddings.dim() == 1: + embeddings = embeddings.unsqueeze(0) + + all_embeddings.append(embeddings) + + all_embeddings = torch.cat(all_embeddings) # (B*T, 512) + all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512) + + return all_embeddings + + @torch.no_grad() + def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor: + """Encode text using CLIP text encoder (per SARM paper A.4). + + Args: + text: Task description text to encode + batch_size: Batch size to replicate for + + Returns: + Encoded text features with shape (B, 512) + """ + inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() + text_embedding = text_embedding.expand(batch_size, -1) + + return text_embedding + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Add encoded features to the observation features.""" + features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim) + ) + features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.config.text_dim,) + ) + features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature( + type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim) + ) + return features + + +def make_sarm_pre_post_processors( + config: SARMConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta=None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Create pre-processor and post-processor pipelines for SARM.""" + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[ + AddBatchDimensionProcessorStep(), + RenameObservationsProcessorStep(rename_map={}), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + SARMEncodingProcessorStep( + config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ], + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=[DeviceProcessorStep(device="cpu")], + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py new file mode 100644 index 000000000..5b6955d38 --- /dev/null +++ b/src/lerobot/policies/sarm/sarm_utils.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 + + +def find_stage_and_tau( + current_frame: int, + episode_length: int, + subtask_names: list | None, + subtask_start_frames: list | None, + subtask_end_frames: list | None, + global_subtask_names: list, + temporal_proportions: dict, + return_combined: bool = False, +) -> tuple[int, float] | float: + """Find stage and within-stage progress (tau) for a frame. + + Args: + current_frame: Frame index relative to episode start + episode_length: Total frames in episode + subtask_names: Subtask names for this episode (None for single_stage) + subtask_start_frames: Subtask start frames + subtask_end_frames: Subtask end frames + global_subtask_names: Global list of all subtask names + temporal_proportions: Dict of temporal proportions + return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple + + Returns: + Float (stage.tau) if return_combined, else (stage_idx, tau) tuple + """ + stage_idx, tau = 0, 0.0 + num_stages = len(global_subtask_names) + + # Single-stage mode: linear progress from 0 to 1 + if num_stages == 1: + tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1))) + elif subtask_names is None: + pass # stage_idx=0, tau=0.0 + elif current_frame < subtask_start_frames[0]: + pass # Before first subtask: stage_idx=0, tau=0.0 + elif current_frame > subtask_end_frames[-1]: + stage_idx, tau = num_stages - 1, 0.999 # After last subtask + else: + # Find which subtask this frame belongs to + found = False + for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True): + if start <= current_frame <= end: + stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0 + tau = compute_tau(current_frame, start, end) + found = True + break + # Frame between subtasks - use previous subtask's end state + if not found: + for j in range(len(subtask_names) - 1): + if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]: + name = subtask_names[j] + stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j + tau = 1.0 + break + + if return_combined: + # Clamp to avoid overflow at end + if stage_idx >= num_stages - 1 and tau >= 1.0: + return num_stages - 1 + 0.999 + return stage_idx + tau + return stage_idx, tau + + +def compute_absolute_indices( + frame_idx: int, + ep_start: int, + ep_end: int, + n_obs_steps: int, + frame_gap: int = 30, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute absolute frame indices with clamping for bidirectional observation sequence. + + Bidirectional sampling centered on target frame: + - Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames) + - Current: [0] (1 frame) + - After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames) + - Total: n_obs_steps + 1 frames + + Out-of-bounds frames are clamped (duplicated from boundary). + + Args: + frame_idx: Target frame index (center frame of sequence) + ep_start: Episode start index + ep_end: Episode end index (exclusive) + n_obs_steps: Number of observation steps (must be even for symmetric sampling) + frame_gap: Gap between observation frames + + Returns: + Tuple of (indices, out_of_bounds_flags) + """ + half_steps = n_obs_steps // 2 + + # Bidirectional deltas: past + current + future + past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)] + future_deltas = [frame_gap * i for i in range(1, half_steps + 1)] + delta_indices = past_deltas + [0] + future_deltas + + frames = [] + out_of_bounds = [] + + for delta in delta_indices: + target_idx = frame_idx + delta + # Clamp to episode bounds (duplicate boundary frames for out-of-bounds) + clamped_idx = max(ep_start, min(ep_end - 1, target_idx)) + frames.append(clamped_idx) + # Flag as out-of-bounds if clamping occurred + out_of_bounds.append(1 if target_idx != clamped_idx else 0) + + return torch.tensor(frames), torch.tensor(out_of_bounds) + + +def apply_rewind_augmentation( + frame_idx: int, + ep_start: int, + n_obs_steps: int, + max_rewind_steps: int, + frame_gap: int = 30, + rewind_step: int | None = None, +) -> tuple[int, list[int]]: + """ + Generate rewind frame indices for temporal augmentation. + + Rewind simulates going backwards through previously seen frames, + starting from before the earliest observation frame (for bidirectional sampling). + Appends reversed frames after the observation sequence. + + Args: + frame_idx: Target frame index (center of bidirectional observation window) + ep_start: Episode start index + n_obs_steps: Number of observation steps + max_rewind_steps: Maximum rewind steps + frame_gap: Gap between frames + rewind_step: If provided, use this exact rewind step (for deterministic behavior). + If None, sample randomly. + + Returns: + Tuple of (rewind_step, rewind_indices) + """ + # For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap + half_steps = n_obs_steps // 2 + earliest_obs_frame = frame_idx - half_steps * frame_gap + + # Required history: frames before earliest observation frame + if earliest_obs_frame <= ep_start: + return 0, [] # No history before observation window + + # Max valid rewind steps based on available history before earliest obs frame + available_history = earliest_obs_frame - ep_start + max_valid_step = available_history // frame_gap + max_rewind = min(max_rewind_steps, max(0, max_valid_step)) + + if max_rewind <= 0: + return 0, [] + + # Sample rewind steps if not provided + rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind) + + if rewind_step == 0: + return 0, [] + + # Generate rewind indices going backwards from earliest obs frame + # rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back + rewind_indices = [] + for i in range(1, rewind_step + 1): + idx = earliest_obs_frame - i * frame_gap + idx = max(ep_start, idx) # Clamp to episode start + rewind_indices.append(idx) + + return rewind_step, rewind_indices + + +def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float: + """Compute Ο„_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks.""" + duration = subtask_end - subtask_start + if duration <= 0: + return 1.0 + return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0)) + + +def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor: + """Pad the state tensor's last dimension to max_state_dim with zeros.""" + current_dim = state.shape[-1] + if current_dim >= max_state_dim: + return state[..., :max_state_dim] # Truncate if larger + + # Pad with zeros on the right + padding = (0, max_state_dim - current_dim) # (left, right) for last dim + return F.pad(state, padding, mode="constant", value=0) + + +def temporal_proportions_to_breakpoints( + temporal_proportions: dict[str, float] | list[float] | None, + subtask_names: list[str] | None = None, +) -> list[float] | None: + """Convert temporal proportions to cumulative breakpoints for normalization.""" + if temporal_proportions is None: + return None + + if isinstance(temporal_proportions, dict): + if subtask_names is not None: + proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names] + else: + proportions = list(temporal_proportions.values()) + else: + proportions = list(temporal_proportions) + + total = sum(proportions) + if total > 0 and abs(total - 1.0) > 1e-6: + proportions = [p / total for p in proportions] + + breakpoints = [0.0] + cumsum = 0.0 + for prop in proportions: + cumsum += prop + breakpoints.append(cumsum) + breakpoints[-1] = 1.0 + + return breakpoints + + +def normalize_stage_tau( + x: float | torch.Tensor, + num_stages: int | None = None, + breakpoints: list[float] | None = None, + temporal_proportions: dict[str, float] | list[float] | None = None, + subtask_names: list[str] | None = None, +) -> float | torch.Tensor: + """ + Normalize stage+tau reward to [0, 1] with custom breakpoints. + + Maps stage index + within-stage tau to normalized progress [0, 1]. + The breakpoints are designed to give appropriate weight to each stage + based on their importance in the task (using temporal proportions). + + Priority: breakpoints > temporal_proportions > linear fallback + + Args: + x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1) + num_stages: Number of stages (required if breakpoints/proportions not provided) + breakpoints: Optional custom breakpoints list of length num_stages + 1. + temporal_proportions: Optional temporal proportions dict/list to compute breakpoints. + subtask_names: Optional ordered list of subtask names (for dict proportions) + + Returns: + Normalized progress value ∈ [0, 1] + """ + if breakpoints is not None: + num_stages = len(breakpoints) - 1 + elif temporal_proportions is not None: + breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names) + num_stages = len(breakpoints) - 1 + elif num_stages is not None: + breakpoints = [i / num_stages for i in range(num_stages + 1)] + else: + raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided") + + if isinstance(x, torch.Tensor): + result = torch.zeros_like(x) + for i in range(num_stages): + mask = (x >= i) & (x < i + 1) + tau_in_stage = x - i + result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i]) + result[x >= num_stages] = 1.0 + return result.clamp(0.0, 1.0) + else: + if x < 0: + return 0.0 + if x >= num_stages: + return 1.0 + stage = int(x) + tau = x - stage + return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage]) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index e442b14d5..f998661f9 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy): def __init__( self, config: SmolVLAConfig, + **kwargs, ): """ Args: @@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy): def _rtc_enabled(self) -> bool: return self.config.rtc_config is not None and self.config.rtc_config.enabled - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: - """Do a full training forward pass to compute the loss""" + def forward( + self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean" + ) -> dict[str, Tensor]: + """Do a full training forward pass to compute the loss. + + Args: + batch: Training batch containing observations and actions. + noise: Optional noise tensor for flow matching. + time: Optional time tensor for flow matching. + reduction: How to reduce the loss. Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) @@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy): losses = losses[:, :, : self.config.max_action_dim] loss_dict["losses_after_rm_padding"] = losses.clone() - # For backward pass - loss = losses.mean() - # For backward pass - loss_dict["loss"] = loss.item() - return loss, loss_dict + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict def prepare_images(self, batch): """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and @@ -527,6 +544,7 @@ class VLAFlowMatching(nn.Module): num_vlm_layers=self.config.num_vlm_layers, self_attn_every_n_layers=self.config.self_attn_every_n_layers, expert_width_multiplier=self.config.expert_width_multiplier, + device=self.config.device if self.config.device is not None else "auto", ) self.state_proj = nn.Linear( self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size @@ -783,18 +801,15 @@ class VLAFlowMatching(nn.Module): use_cache=self.config.use_cache, fill_kv_cache=True, ) - dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) + num_steps = self.config.num_steps + dt = -1.0 / num_steps x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( x_t=input_x_t, prefix_pad_masks=prefix_pad_masks, @@ -818,15 +833,11 @@ class VLAFlowMatching(nn.Module): else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step) if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step( diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 195cf6154..f83c82e21 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy): def __init__( self, config: TDMPCConfig, + **kwargs, ): """ Args: diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index c4ca35b72..bfbe2bf1d 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -231,11 +231,20 @@ def validate_visual_features_consistency( """ Validates visual feature consistency between a policy config and provided dataset/environment features. + Validation passes if EITHER: + - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more) + - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility) + Args: cfg (PreTrainedConfig): The model or policy configuration containing input_features and type. features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects. """ expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL} provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL} - if not provided_visuals.issubset(expected_visuals): + + # Accept if either direction is a subset + policy_subset_of_dataset = expected_visuals.issubset(provided_visuals) + dataset_subset_of_policy = provided_visuals.issubset(expected_visuals) + + if not (policy_subset_of_dataset or dataset_subset_of_policy): raise_feature_mismatch_error(provided_visuals, expected_visuals) diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 91d609701..359b4fdb1 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy): def __init__( self, config: VQBeTConfig | None = None, + **kwargs, ): """ Args: diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md new file mode 100644 index 000000000..78548bd8d --- /dev/null +++ b/src/lerobot/policies/wall_x/README.md @@ -0,0 +1,35 @@ +# WALL-OSS + +This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. + +--- + +## Model Overview + +| Feature | Description | +| ------------------ | ----------------------------------------------------- | --- | +| Base Model | Qwen2.5-VL (Vision-Language Model) | +| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | +| Architecture | Mixture of Experts (MoE) with action-specific routing | | +| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | + +--- + +## Citation + +If you use this work, please cite: + +```bibtex +@article{zhai2025igniting, + title = {Igniting VLMs Toward the Embodied Space}, + author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, + journal = {arXiv preprint arXiv:2509.11766}, + year = {2025} +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**. diff --git a/src/lerobot/policies/wall_x/__init__.py b/src/lerobot/policies/wall_x/__init__.py new file mode 100644 index 000000000..d80c27bda --- /dev/null +++ b/src/lerobot/policies/wall_x/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_wall_x import WallXConfig + +__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"] diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py new file mode 100644 index 000000000..0d10a8f98 --- /dev/null +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -0,0 +1,165 @@ +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("wall_x") +@dataclass +class WallXConfig(PreTrainedConfig): + """ + Configuration class for Wall-X policy. + + Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching. + It supports cross-embodiment robotic control through unified action representations. + + This config supports multi-modal learning with vision, language, and action data. + """ + + # ==================== Input / Output Structure ==================== + n_obs_steps: int = 1 + chunk_size: int = 32 # action_horizon in wall-x + n_action_steps: int = 32 + + # Action dimension - wall-x uses 20 + max_action_dim: int = 20 + max_state_dim: int = 20 # For proprioception + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # ==================== Action Prediction ==================== + # Pretrained model paths + pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" + + # Tokenizer settings + action_tokenizer_path: str | None = "physical-intelligence/fast" + + # Action prediction mode: "diffusion" or "fast" + prediction_mode: str = "diffusion" + + # Attention Implementation, options: "eager", "flash_attention_2", "sdpa" + # NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation + attn_implementation: str = "eager" + + # ==================== Optimizer Presets ==================== + optimizer_lr: float = 2e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.01 + optimizer_grad_clip_norm: float = 1.0 + + scheduler_warmup_steps: int = 1000 + scheduler_decay_steps: int = 100000 + scheduler_decay_lr: float = 1e-6 + + def __post_init__(self): + super().__post_init__() + + # Input validation + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + + if self.prediction_mode not in ["diffusion", "fast"]: + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") + + # Assign use_fast_tokenizer based on prediction_mode + if self.prediction_mode == "fast": + self.use_fast_tokenizer = True + elif self.prediction_mode == "diffusion": + self.use_fast_tokenizer = False + self.action_tokenizer_path = None # disable action tokenizer for diffusion mode + else: + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") + + def validate_features(self) -> None: + """Validate and set up input/output features.""" + image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL] + if not image_features: + raise ValueError( + "Wall-X policy requires at least one visual input feature. " + "No features of type FeatureType.VISUAL found in input_features." + ) + + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature + else: + state_shape = self.input_features["observation.state"].shape + state_dim = state_shape[0] if state_shape else 0 + if state_dim > self.max_state_dim: + raise ValueError( + f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. " + f"Either reduce state dimension or increase max_state_dim in config." + ) + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + else: + action_shape = self.output_features["action"].shape + action_dim = action_shape[0] if action_shape else 0 + if action_dim > self.max_action_dim: + raise ValueError( + f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. " + f"Either reduce action dimension or increase max_action_dim in config." + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py new file mode 100644 index 000000000..43e5e7fb6 --- /dev/null +++ b/src/lerobot/policies/wall_x/constant.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Constants and Configuration Data. +""" + +CAMERA_NAME_MAPPING = { + "face_view": "front view", + "left_wrist_view": "left wrist view", + "right_wrist_view": "right wrist view", + "move1_view": "move view", + "move2_view": "move view", + "wall_view": "wall view", + "top_view": "top view", +} + +RESOLUTION = 256 + +# Parameters for preprocessing +MAX_PIXELS = 16384 * 28 * 28 +MIN_PIXELS = 4 * 28 * 28 +IMAGE_FACTOR = 28 +PRIORITY_ORDER = None +GENERATE_SUBTASK_RATIO = 0.0 +MODEL_TYPE = "qwen2_5" + +TOKENIZER_MAX_LENGTH = 768 diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py new file mode 100644 index 000000000..c401c8d60 --- /dev/null +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -0,0 +1,2008 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X: Cross-embodiment robotic control using Qwen2.5-VL with flow matching. + +[Paper](https://github.com/x2-robot/wall-x) + +Install wall-x extra dependencies: +```bash +pip install -e ".[wall_x]" +``` + +Example of finetuning a wall-x model: +```bash +lerobot-train \ +--policy.type=wall_x \ +--dataset.repo_id=your/dataset \ +--batch_size=32 \ +--steps=100000 +``` +""" + +import math +from collections import deque +from os import PathLike +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from peft import LoraConfig, get_peft_model +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize +from torch import Tensor +from torch.distributions import Beta +from torch.nn import CrossEntropyLoss +from torchdiffeq import odeint +from transformers import AutoProcessor, BatchFeature +from transformers.cache_utils import ( + StaticCache, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, +) +from transformers.utils import is_torchdynamo_compiling, logging + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import populate_queues +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig +from lerobot.policies.wall_x.constant import ( + GENERATE_SUBTASK_RATIO, + IMAGE_FACTOR, + MAX_PIXELS, + MIN_PIXELS, + MODEL_TYPE, + PRIORITY_ORDER, + RESOLUTION, + TOKENIZER_MAX_LENGTH, +) +from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLACausalLMOutputWithPast, + Qwen2_5_VLMoEModel, +) +from lerobot.policies.wall_x.utils import ( + get_wallx_normal_text, + preprocesser_call, + process_grounding_points, + replace_action_token, +) +from lerobot.utils.constants import ACTION, OBS_STATE + +logger = logging.get_logger(__name__) + + +class SinusoidalPosEmb(nn.Module): + """Sinusoidal positional embedding for diffusion timesteps.""" + + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ActionHead(nn.Module): + """ + Action prediction head with flow matching. + + Implements Beta-distributed noise scheduling and temporal embeddings + for action sequence prediction. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.action_dim = sum(config.dof_config.values()) + self.propri_dim = sum(config.agent_pos_config.values()) + self.hidden_size = config.hidden_size + + # Beta distribution for noise scheduling + self.beta_alpha = 1.5 + self.beta_beta = 1.0 + self.s = 0.999 + + # Sinusoidal timestep embedding + self.time_embed = SinusoidalPosEmb(config.hidden_size) + + # Action embedding network + # *2 for action + DOF mask concatenation + self.w1 = nn.Linear(self.action_dim * 2, self.hidden_size, bias=False) + self.w2 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=False) # *2 for action + time + self.w3 = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.act_fn = nn.SiLU() + + # Project back to action space + self.action_proj_back = nn.Linear(self.hidden_size, self.action_dim, bias=False) + + # Proprioception projection + self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False) + + def sample_time(self, batch_size, device): + """Sample timesteps using Beta distribution (always in float32 for numerical stability).""" + beta_dist = Beta( + torch.tensor(self.beta_alpha, dtype=torch.float32, device=device), + torch.tensor(self.beta_beta, dtype=torch.float32, device=device), + ) + sample = beta_dist.sample([batch_size]) + time = (1 - sample) * self.s + return time + + def forward(self, action_chunk, dof_mask=None): + """ + Process action sequences with noise injection for training. + + Args: + action_chunk: Action sequences [batch, seq_len, action_dim] + dof_mask: DOF mask [batch, seq_len, action_dim] + + Returns: + tuple: (action_embeddings, flow_target) + """ + batch_size = action_chunk.shape[0] + device = action_chunk.device + weight_dtype = self.w1.weight.dtype + + # Sample time outside of autocast (Beta distribution needs float32) + time = self.sample_time(batch_size, device) + t = time.unsqueeze(-1).unsqueeze(-1) + + # Noise and flow computation in float32 + noise = torch.randn_like(action_chunk, dtype=torch.float32) + action_chunk_f32 = action_chunk.to(torch.float32) + noisy_action = (1 - t) * noise + t * action_chunk_f32 + flow = action_chunk_f32 - noise + + # Project noisy actions + if dof_mask is not None: + noisy_action = torch.cat([noisy_action, dof_mask.to(torch.float32)], dim=-1) + + # Convert to weight dtype for linear layers + noisy_action = noisy_action.to(dtype=weight_dtype) + action_embed = self.w1(noisy_action) + + # Generate time embeddings and combine + time_embed = self.time_embed(time) + time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) + time_embed = time_embed.to(dtype=weight_dtype) + + concat_embed = torch.cat([action_embed, time_embed], dim=-1) + concat_embed = self.w2(concat_embed) + embed = self.w3(self.act_fn(concat_embed)) + + return embed, flow + + def step(self, timestep, noisy_action, dof_mask=None): + """Single denoising step for inference.""" + weight_dtype = self.w1.weight.dtype + + if dof_mask is not None: + noisy_action = torch.cat([noisy_action, dof_mask], dim=-1) + noisy_action = noisy_action.to(dtype=weight_dtype) + + time_embed = self.time_embed(timestep) + action_embed = self.w1(noisy_action) + + time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) + time_embed = time_embed.to(device=noisy_action.device, dtype=weight_dtype) + + concat_embed = torch.cat([action_embed, time_embed], dim=-1) + concat_embed = self.w2(concat_embed) + embed = self.w3(self.act_fn(concat_embed)) + + return embed + + def flow_loss(self, action_hidden_states, flow, dof_mask=None): + """Compute flow matching loss (all computations in float32 for stability).""" + # Ensure all inputs are float32 + action_hidden_states = action_hidden_states.to(torch.float32) + flow = flow.to(torch.float32) + + action_pred = self.action_proj_back(action_hidden_states) + loss = F.mse_loss(action_pred, flow, reduction="none") + + if dof_mask is not None: + dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1]).to(torch.float32) + loss = loss * dof_mask + + return loss + + def proprioception_proj(self, proprioception, dof_mask=None, use_history=False): + """Project proprioceptive data to hidden space.""" + # Ensure proper device and dtype alignment + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( + dtype=self.propri_proj.weight.dtype + ) + + if dof_mask is not None: + # Concatenate proprioception with DOF mask + # TODO: Use variable-based dimension checking for better flexibility + if use_history: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + else: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( + dtype=self.propri_proj.weight.dtype + ) + return self.propri_proj(proprioception) + + +class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): + """ + Qwen2.5 Vision-Language Mixture of Experts model for action processing. + + This model extends the base Qwen2.5 VL model with action token processing capabilities + and optional LoRA fine-tuning support. + """ + + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"] + + @classmethod + def from_pretrained( + cls, + pretrained_name_or_path, + config=None, + action_tokenizer_path=None, + attn_implementation: str = "eager", + cache_dir: str | PathLike | None = None, + force_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + revision: str = "main", + strict: bool = False, + **kwargs: Any, + ): + """ + Load model from pretrained model path. + + Args: + pretrained_model_path (str): Model directory path containing model.safetensors file + config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path + action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config + attn_implementation (str, optional): Attention implementation, if None will load from default config + **kwargs: Additional arguments + + Returns: + Qwen2_5_VLMoEForAction: Loaded model instance + """ + if config is None: + config = cls.config_class.from_pretrained( + pretrained_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + strict=strict, + **kwargs, + ) + if attn_implementation is not None: + config._attn_implementation = attn_implementation + processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) + if action_tokenizer_path is not None: + action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True) + processor.action_processor = action_tokenizer + else: + action_tokenizer = None + # Initialize model with configuration and processor + model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs) + + # Resize token embeddings to match processor tokenizer vocabulary size + model.resize_token_embeddings(len(processor.tokenizer)) + + # Try to load the model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + sd = load_file(resolved_file) + print("βœ“ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + state_dict = {} + # filter normalizer statistic params + del_keys = [] + for key in sd.keys(): + if "action_preprocessor.normalizer" in key: + del_keys.append(key) + for key in del_keys: + del sd[key] + state_dict.update(sd) + + model.load_state_dict(state_dict, strict=False) + + return model + + def __init__( + self, + config, + use_fast_tokenizer=False, + processor=None, + action_tokenizer=None, + action_mapper=None, + flow_loss_weight=1.0, + ): + """ + Initialize the Qwen2.5 VLMoE model for action processing. + + Args: + config: Model configuration + use_fast_tokenizer (bool): Whether to use fast tokenizer + processor: Text and image processor + action_tokenizer: Action-specific tokenizer + action_mapper: Action mapping utility + flow_loss_weight (float): Weight for flow loss computation + """ + super().__init__(config) + + # Initialize vision transformer and language model components + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize loss function without reduction for channel-wise loss computation + self.loss_fct = CrossEntropyLoss(reduction="none") + self.flow_loss_weight = flow_loss_weight + self.use_fast_tokenizer = use_fast_tokenizer + self.processor = processor + self.action_tokenizer = action_tokenizer + + # Define action token IDs + self.define_action_token_id() + + # Cache for rope deltas + self.rope_deltas = None + + # Initialize action preprocessor + self.action_preprocessor = ActionHead(config) + + # Apply LoRA if specified in configuration + if hasattr(config, "use_lora") and config.use_lora: + self.add_lora( + r=config.lora_r, + lora_alpha=config.lora_alpha, + target_modules=config.lora_target_modules, + lora_dropout=config.lora_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + def to_bfloat16_for_selected_params(self): + self.to(dtype=torch.bfloat16) + + params_to_keep_float32 = [] + + for name, param in self.named_parameters(): + if "input_layernorm" in name or "post_attention_layernorm" in name or "model.norm" in name: + params_to_keep_float32.append(name) + if "action_preprocessor" in name: + params_to_keep_float32.append(name) + + for name, param in self.named_parameters(): + if name in params_to_keep_float32: + param.data = param.data.to(torch.float32) + + def define_action_token_id(self): + """ + Define action token IDs based on tokenizer configuration. + + Creates mappings for fast action tokens, proprioception tokens, and general action tokens. + """ + # Create list of fast action token IDs + fast_action_token_list = [] + if self.use_fast_tokenizer: + for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]): + action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>") + fast_action_token_list.append(action_token_id) + + # Get special action token IDs + action_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|action|>") + propri_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|propri|>") + + # Store action token ID mappings + self.action_token_id_set = { + "fast_action_token_list": fast_action_token_list, + "propri_token_id": propri_token_id, + "action_token_id": action_token_id, + } + + def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1): + """ + Add LoRA (Low-Rank Adaptation) adapters to the model. + + Args: + r (int): Rank of adaptation + lora_alpha (int): LoRA scaling parameter + target_modules (list): List of module names to apply LoRA to + lora_dropout (float): Dropout probability for LoRA layers + """ + config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + self.model = get_peft_model(self.model, config) + + # Print information about trainable parameters + self.model.print_trainable_parameters() + + def get_input_embeddings(self): + """Get input embeddings layer.""" + return self.model.embed_tokens + + def set_input_embeddings(self, value): + """Set input embeddings layer.""" + self.model.embed_tokens = value + + def get_output_embeddings(self): + """Get output embeddings layer.""" + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Set output embeddings layer.""" + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + """Set the decoder model.""" + self.model = decoder + + def get_decoder(self): + """Get the decoder model.""" + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens. + + This method computes position embeddings that account for the temporal, height, and width + dimensions of vision tokens (images/videos) while maintaining standard 1D position embeddings + for text tokens. + + For vision tokens, 3D position embeddings are calculated based on: + - Temporal dimension: Time patches in videos + - Height dimension: Vertical patches in images/video frames + - Width dimension: Horizontal patches in images/video frames + + For text tokens, standard 1D position embeddings are used, continuing from the maximum + vision position ID plus 1. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs of shape (batch_size, sequence_length) + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (num_images, 3) for [temporal, height, width] + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (num_videos, 3) for [temporal, height, width] + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid (num_videos,) + attention_mask (torch.Tensor, optional): Attention mask (batch_size, sequence_length) + + Returns: + tuple: + - position_ids (torch.LongTensor): 3D position IDs of shape (3, batch_size, sequence_length) + - mrope_position_deltas (torch.Tensor): Position deltas for mRoPE of shape (batch_size, 1) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + # Initialize 3D position IDs tensor + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + # Process each sequence in the batch + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + # Find vision tokens and count images/videos + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + # Process each vision token (image or video) + for _ in range(image_nums + video_nums): + # Find next image or video token + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + # Determine if processing image or video token + if ed_image < ed_video: + # Process image token + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + # Process video token + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + + # Calculate grid dimensions after spatial merging + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + # Add position IDs for text tokens before vision token + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # Calculate 3D position embeddings for vision tokens + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + # Calculate temporal position IDs with time scaling + time_tensor = ( + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + ) + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + # Calculate spatial position IDs + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) + + # Add 3D position IDs for vision tokens + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + # Add position IDs for remaining text tokens + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # Concatenate all position IDs for this sequence + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # Handle case without vision tokens - use standard 1D position embeddings + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def train_step_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, # MoE token type assignments + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks + proprioception: torch.FloatTensor | None = None, # Joint position/orientation data + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, + **kwargs, + ) -> tuple | Qwen2_5_VLACausalLMOutputWithPast: + """ + Forward pass for training with multi-modal inputs including vision, text, and action data. + + This method handles the complete forward pass during training, processing various input modalities + including images, videos, text, proprioceptive data, and action sequences. It computes losses + for both language modeling and action prediction using flow matching. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs for generation + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for loss computation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (temporal, height, width) + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (temporal, height, width) + action_chunk (torch.FloatTensor, optional): Action trajectory data chunks + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data (joint positions, etc.) + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens + agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data + **kwargs: Additional keyword arguments + + Returns: + Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: Model outputs including losses, logits, + and auxiliary information, or tuple if return_dict=False + """ + batch_size, seq_length = input_ids.shape + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + delta = ( + (cache_position[0] + self.rope_deltas).to(self.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=self.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data (joint positions, orientations, etc.) + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + proprioception = self.action_preprocessor.proprioception_proj( + proprioception, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + mask = input_ids == self.action_token_id_set["propri_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + proprioception_mask = mask_expanded.to(inputs_embeds.device) + + proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception) + elif self.training: + # Dummy forward pass to ensure gradient registration in DDP + # This handles cases where one process has proprioception data while another doesn't + # Without this, DDP would hang waiting for a gradient that will never be computed + dummy_input = torch.randn( + 2, + self.action_preprocessor.propri_dim * 2, + device=inputs_embeds.device, + ) + dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input) + dummy_loss = sum(p.sum() for p in dummy_forward) + inputs_embeds = inputs_embeds + 0 * dummy_loss + + # Process action chunk data + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype) + dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask) + mask = input_ids == self.action_token_id_set["action_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + action_mask = mask_expanded.to(inputs_embeds.device) + + noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Forward pass through the main model + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + moe_token_types=moe_token_types, # Pass token types for MoE routing + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states) + + # Initialize loss computation variables + loss = None + cross_entropy_loss, flow_loss = None, None + channel_loss_dict = None + channel_loss_count_dict = None + + # Compute losses if labels are provided + if labels is not None: + loss = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32) + + # Compute standard cross-entropy loss for language modeling + shift_logits = logits[..., :-1, :].contiguous().to(torch.float32) + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + + # Enable model parallelism by moving labels to correct device + shift_labels = shift_labels.to(shift_logits.device) + non_ignored_mask = shift_labels != -100 + _cross_entropy_loss = self.loss_fct(shift_logits, shift_labels) + cross_entropy_loss = ( + _cross_entropy_loss[non_ignored_mask].mean() + if non_ignored_mask.any() + else torch.tensor(0.0, device=shift_logits.device, dtype=torch.float32) + ) + + # Add cross-entropy loss to total loss if valid + if not torch.isnan(cross_entropy_loss): + loss = loss + cross_entropy_loss.to(torch.float32) + else: + with torch.no_grad(): + cross_entropy_loss.detach() + + if action_chunk is not None: + action_mask = input_ids == self.action_token_id_set["action_token_id"] + if action_mask.any(): + action_hidden_states = hidden_states[action_mask].to(torch.float32) + flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32) + _flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask) + if isinstance(_flow_loss, torch.Tensor): + flow_loss = _flow_loss.mean() + if loss is not None: + loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32) + else: + loss = self.flow_loss_weight * flow_loss.to(torch.float32) + _flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]) + + # Return outputs based on return_dict setting + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLACausalLMOutputWithPast( + loss=loss, + cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None), + flow_loss=flow_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + channel_loss_dict=channel_loss_dict, + channel_loss_count_dict=channel_loss_count_dict, + ) + + def predict_action(self, predict_mode: str, **kwargs): + """ + Predict actions using specified prediction mode. + + Args: + predict_mode (str): Prediction mode, either "fast" or "diffusion" + **kwargs: Additional arguments passed to the predict method + + Returns: + tuple: (predicted_action, ground_truth_action) where ground_truth_action may be None + """ + assert predict_mode in ["fast", "diffusion"] + + output = self.predict(predict_mode=predict_mode, **kwargs) + + return output["predict_action"], output.get("gt_action", None) + + @torch.no_grad() + def predict( + self, + predict_mode: str, + pred_horizon: int | None = None, + action_dim: int | None = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, + proprioception: torch.FloatTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + num_inference_timesteps: int | None = 10, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, + re_generate: bool = False, + **kwargs, + ): + """ + Multi-modal prediction method supporting text generation, fast action prediction, and diffusion-based action prediction. + + This method handles three prediction modes: + 1. "text": Pure text generation using autoregressive decoding + 2. "fast": Fast action prediction using discrete action tokens + 3. "diffusion": Continuous action prediction using diffusion/flow matching + + Args: + predict_mode (str): Prediction mode ("text", "fast", or "diffusion") + pred_horizon (int, optional): Prediction horizon for action sequences + action_dim (int, optional): Dimensionality of action space + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for evaluation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions + video_grid_thw (torch.LongTensor, optional): Video grid dimensions + action_chunk (torch.FloatTensor, optional): Ground truth action sequences + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + num_inference_timesteps (int, optional): Number of diffusion inference steps + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask + agent_pos_mask (torch.FloatTensor, optional): Agent position mask + re_generate (bool, optional): Whether to use sampling for regeneration + **kwargs: Additional keyword arguments + + Returns: + dict: Dictionary containing prediction results with keys like: + - 'predict_action': Predicted action sequences + - 'gt_action': Ground truth actions (if available) + - 'input_text': Input text (for text/fast modes) + - 'predict_output_text': Generated text (for text/fast modes) + - 'gt_output_text': Ground truth text (for text/fast modes) + """ + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + + # Text and fast modes require batch size 1 for autoregressive generation + if predict_mode in ["text", "fast"]: + assert batch_size == 1, "predict only support batch size 1 for ar generation" + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + + # Validate image token and feature count match + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + proprio_embed = self.action_preprocessor.proprioception_proj( + proprioception, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"] + proprio_embed = proprio_embed.to(torch.bfloat16) + inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1]) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Prepare action chunk data if provided + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to(torch.float32) + + output = {} + + # Split input sequence for text and fast modes (not needed for diffusion) + if predict_mode == "text" or predict_mode == "fast": + # Look for generation prompt tokens: <|im_start|>assistant + generation_prompt_ids = torch.tensor( + [151644, 77091], device=input_ids.device, dtype=input_ids.dtype + ) + matches = (input_ids[0, :-1] == generation_prompt_ids[0]) & ( + input_ids[0, 1:] == generation_prompt_ids[1] + ) + + if matches.any(): + split_pos = torch.nonzero(matches, as_tuple=True)[0][0].item() + # Extract ground truth output tokens (including newline) + gt_output_ids = input_ids[:, split_pos + 3 :] + # Remove output part from input, keeping prompt + input_ids = input_ids[:, : split_pos + 3] + inputs_embeds = inputs_embeds[:, : split_pos + 3, :] + if attention_mask is not None: + attention_mask = attention_mask[:, : split_pos + 3] + if labels is not None: + labels = labels[:, split_pos + 3 :] + else: + raise ValueError( + "input_ids does not contain the generation prompt tokens <|im_start|>assistant" + ) + + # Decode input text for output + input_text = self.processor.batch_decode( + input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True + ) + output["input_text"] = input_text + + # Handle text and fast prediction modes using autoregressive generation + if predict_mode == "text" or predict_mode == "fast": + # Initialize MoE token types for generation + moe_token_types = torch.zeros_like(input_ids) + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "moe_token_types": moe_token_types, + "image_grid_thw": image_grid_thw, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + "proprioception": proprioception, + } + + # Generate output tokens + predict_output_ids = self.generate( + **batch, + max_new_tokens=100, + eos_token_id=[self.processor.tokenizer.eos_token_id], + use_cache=True, + pad_token_id=self.processor.tokenizer.pad_token_id, + temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration + do_sample=(False if not re_generate else True), # Enable sampling for regeneration + ) + + # Decode generated and ground truth text + gt_output_text = self.processor.batch_decode( + gt_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + predict_output_text = self.processor.batch_decode( + predict_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + output["gt_output_text"] = gt_output_text + output["predict_output_text"] = predict_output_text + + # Convert tokens to actions for fast prediction mode + if predict_mode == "fast": + action_id = [] + # Extract action tokens from generated sequence + for token_id_i in predict_output_ids[0]: + if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]: + action_id.append( + token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"] + ) + + predict_action = self.processor.action_processor.decode( + [action_id], time_horizon=pred_horizon, action_dim=action_dim + ) + # Handle action decoding errors + if np.sum(predict_action) == 0: + print("Error in decoding action, predict_action is None") + output["predict_action"] = None + else: + # Convert discrete tokens to continuous actions + predict_action = torch.tensor(predict_action, device=self.device) + dof_mask = dof_mask.to(self.device).to(pixel_values.dtype) + # removed unnormalization step for now + predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()] + output["predict_action"] = predict_action + + # Process ground truth actions if available + if action_chunk is not None: + # Apply DOF mask to get ground truth actions + # removed unnormalization step for now + action_chunk = action_chunk[:, :, dof_mask[0, 0, :].bool()] + output["gt_action"] = action_chunk + else: + output["gt_action"] = None + + # Handle diffusion-based action prediction + if predict_mode == "diffusion": + # Initialize with random noise + noisy_action = torch.randn( + size=(batch_size, pred_horizon, action_dim), + dtype=torch.float32, + device=inputs_embeds.device, + ) + dof_mask = dof_mask.to(inputs_embeds.device).to(torch.float32) + + def step(timestep, noisy_action): + """ + Single denoising step for diffusion process. + + Args: + timestep: Current diffusion timestep + noisy_action: Current noisy action estimate + + Returns: + torch.Tensor: Predicted clean action + """ + action_mask = input_ids == self.action_token_id_set["action_token_id"] + assert action_mask.any(), "No action token found in input_ids" + + # Prepare timestep for batch processing + timestep = timestep.unsqueeze(0).repeat(noisy_action.shape[0]) + action_embed = self.action_preprocessor.step( + timestep=timestep, noisy_action=noisy_action, dof_mask=dof_mask + ) + action_embed = action_embed.reshape(-1, inputs_embeds.shape[-1]) + + # Ensure action_embed has the correct dtype and device before assignment + action_embed = action_embed.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + # Create temporary copy of embeddings (clone preserves dtype) + temp_inputs_embeds = inputs_embeds.clone() + temp_inputs_embeds[action_mask] = action_embed + + # Forward pass through transformer + transformer_outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=temp_inputs_embeds, + moe_token_types=moe_token_types, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + + # Extract action predictions from hidden states + hidden_states = transformer_outputs.last_hidden_state + action_mask = input_ids == self.action_token_id_set["action_token_id"] + action_hidden_states = hidden_states[action_mask].to(torch.float32) + pred = self.action_preprocessor.action_proj_back(action_hidden_states) + return pred.reshape(batch_size, pred_horizon, action_dim) + + # Perform ODE integration for diffusion sampling + times = torch.linspace( + 0, + 1, + num_inference_timesteps + 1, + device=inputs_embeds.device, + dtype=torch.float32, + ) + action_trajectory = odeint(step, noisy_action, times, method="euler") + + # Extract final predicted action + # Removed unnormalization step for now + predict_action = action_trajectory[-1] + output["predict_action"] = predict_action + + # Process ground truth actions if available + # removed unnormalization step for now + if action_chunk is not None: + output["gt_action"] = action_chunk[:, :, dof_mask[0, 0, :].bool()] + + return output + + def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs): + """ + Main forward pass dispatcher for different execution modes. + + This method routes execution to appropriate forward functions based on the specified mode: + - No mode (None): Training step with gradient disabled + - 'predict': Prediction/inference mode + - 'train': Training mode with gradients enabled + - 'validate': Validation mode with gradients disabled + + Args: + mode (str, optional): Execution mode. If None, defaults to training step without gradients + predict_mode (str, optional): Prediction mode for 'predict' mode ("text", "fast", or "diffusion") + **kwargs: Additional arguments passed to the selected forward function + + Returns: + Model outputs appropriate for the selected mode + + Todo: + - Add support for distinguishing multi-modal data types in prediction mode + """ + if not mode: + with torch.no_grad(): + return self.train_step_forward(**kwargs) + elif mode == "predict": + return self.predict(predict_mode=predict_mode, **kwargs) + elif mode == "train": + return self.train_step_forward(use_cache=False, **kwargs) + elif mode == "validate": + with torch.no_grad(): + return self.train_step_forward(use_cache=False, **kwargs) + else: + raise NotImplementedError("invalid key") + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + moe_token_types=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + proprioception=None, + dof_mask=None, + agent_pos_mask=None, + **kwargs, + ): + """ + Prepare inputs for autoregressive generation with multi-modal support. + + This method handles input preparation for generation, including proper slicing of inputs + based on cache position, MoE token type management, and multi-modal data handling. + Vision inputs are selectively forwarded only when needed during generation. + + Args: + input_ids: Input token IDs + past_key_values: Cached key-value pairs from previous generation steps + attention_mask: Attention mask for input tokens + inputs_embeds: Pre-computed input embeddings + moe_token_types: Token type assignments for MoE routing + cache_position: Current cache position for generation + position_ids: Position IDs for tokens + use_cache: Whether to use key-value caching + pixel_values: Image pixel values + pixel_values_videos: Video pixel values + image_grid_thw: Image grid dimensions + video_grid_thw: Video grid dimensions + second_per_grid_ts: Time interval per temporal grid + proprioception: Proprioceptive sensor data + dof_mask: Degrees of freedom mask + agent_pos_mask: Agent position mask + **kwargs: Additional arguments + + Returns: + dict: Prepared model inputs for generation step + + Todo: + - Test this function thoroughly with various input configurations + + Note: + This is an overridden method that handles specific cases for multi-modal generation: + - Slices input_ids through cache_position to keep only unprocessed tokens + - Handles special cases for input_embeds, generation methods, and GPU synchronization + - Manages vision inputs to avoid unnecessary forward passes + """ + # Initialize MoE token types if not provided + if moe_token_types is None: + moe_token_types = torch.zeros_like( + input_ids + ) # FIXME: Handle case when input_embeds is used instead + else: + # Ensure moe_token_types length matches input_ids + if moe_token_types.shape[1] < input_ids.shape[1]: + # Calculate required padding length + pad_length = input_ids.shape[1] - moe_token_types.shape[1] + # Create padding tensor with default token type (0) + pad_tensor = torch.zeros( + (moe_token_types.shape[0], pad_length), + dtype=moe_token_types.dtype, + device=moe_token_types.device, + ) + # Concatenate padding to existing moe_token_types + moe_token_types = torch.cat([moe_token_types, pad_tensor], dim=1) + + # Handle input slicing based on cache state and special cases + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1: input_embeds provided + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3: GPU sync edge case + input_ids = input_ids[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op) + cache_pos = cache_position.clone() + input_ids = input_ids[:, cache_pos] + moe_token_types = moe_token_types[:, cache_pos] + + # Skip vision inputs for continuation steps (not initial generation) + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # Determine whether to use inputs_embeds or input_ids for this generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + # Prepare 4D causal attention mask for static cache + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + # Assemble all model inputs for generation + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "moe_token_types": moe_token_types, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + "proprioception": proprioception, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate tensor separation lengths. + + These parameters are computed directly from input_ids rather than being passed through + the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (torch.LongTensor): Input token IDs of shape (batch_size, sequence_length) + + Returns: + tuple: + - image_nums (torch.LongTensor): Number of images per sample + - video_nums (torch.LongTensor): Number of videos per sample + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Find vision start tokens and their following tokens + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + # Count images and videos following vision start tokens + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + """ + Expand inputs for generation with support for multi-modal tensors. + + This is an overridden method that supports expanding tensors without a standard batch + size dimension, specifically for vision-related tensors: + - pixel_values.shape[0] = sum(sequence_lengths for all image samples) + - image_grid_thw.shape[0] = sum(num_images for all samples) + - Similar patterns for video tensors + + Args: + expand_size (int): Factor by which to expand inputs (for beam search, etc.) + is_encoder_decoder (bool): Whether using encoder-decoder architecture + input_ids (torch.LongTensor, optional): Input token IDs + **model_kwargs: Additional model arguments to expand + + Returns: + tuple: (expanded_input_ids, expanded_model_kwargs) + """ + if expand_size == 1: + return input_ids, model_kwargs + + # Define keys for vision-related tensors that need special handling + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + """Expand vision-related tensors based on image/video counts per sample.""" + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + """Split tensor by lengths and repeat each sample.""" + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # Split images into samples and compute sequence lengths + samples = torch.split(image_grid_thw, list(image_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # Expand based on number of images per sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + # Split videos into samples and compute sequence lengths + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + # Expand based on number of videos per sample + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + # Handle list-type temporal grid data + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + """Expand standard tensors using repeat_interleave.""" + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # Expand visual inputs only if input_ids is available for counting images/videos + # If input_ids is unavailable, visual inputs won't be used, so no expansion needed + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + # Expand input_ids using standard repeat_interleave + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + # Expand all other model arguments + model_kwargs = _expand_dict_for_generation(model_kwargs) + + # Handle encoder-decoder specific expansion + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +class WallXPolicy(PreTrainedPolicy): + """ + Wall-X policy for cross-embodiment robotic control. + + Integrates Qwen2.5-VL vision-language model with action prediction + using flow matching for continuous action spaces. + """ + + config_class = WallXConfig + name = "wall_x" + + def __init__(self, config: WallXConfig): + super().__init__(config) + config.validate_features() + self.config = config + + # Initialize the wall-x model + self.model = Qwen2_5_VLMoEForAction.from_pretrained( + pretrained_name_or_path=config.pretrained_name_or_path, + action_tokenizer_path=config.action_tokenizer_path, + attn_implementation=config.attn_implementation, + ) + self.model.to(config.device) + self.model.to_bfloat16_for_selected_params() + + self.reset() + + def reset(self): + """Reset action queue.""" + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def get_optim_params(self): + """Get parameters for optimization.""" + return self.parameters() + + def preprocess_inputs( + self, + batch: dict[str, Any], + ) -> BatchFeature: + """ + Convert a batch of LeRobot dataset items to Wall-X model input format. + + This processes a batched dictionary where tensors have batch dimension first. + + Args: + batch: Dictionary with batched tensors: + - "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim) + - "action": (batch_size, chunk_size, action_dim) + - "observation.images.": (batch_size, C, H, W) + - "task": List[str] of length batch_size + + Returns: + BatchFeature containing batched model inputs + """ + use_fast_tokenizer = self.config.use_fast_tokenizer + + # Get batch size from state tensor + batch_size = batch[OBS_STATE].shape[0] + + # ==================== PROCESS ALL SAMPLES ==================== + all_image_inputs = [] + all_texts = [] + + # Find image keys in batch + img_keys = [key for key in self.config.image_features if key in batch] + + for i in range(batch_size): + # Vision preprocessing per sample + processed_frames = [] + orig_height, orig_width = None, None + resized_height, resized_width = None, None + + for key in img_keys: + current_obs = batch[key][i].clone() # (C, H, W) + if current_obs.dim() == 3: + current_obs = current_obs.permute(1, 2, 0) # (H, W, C) + + img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) + orig_width, orig_height = img_pil.size + + target_size = RESOLUTION + if target_size != -1: + if orig_width > orig_height: + new_width = target_size + new_height = int(target_size * orig_height / orig_width) + else: + new_height = target_size + new_width = int(target_size * orig_width / orig_height) + img_pil = img_pil.resize((new_width, new_height)) + + current_width, current_height = img_pil.size + resized_height, resized_width = smart_resize( + current_height, + current_width, + factor=IMAGE_FACTOR, + min_pixels=MIN_PIXELS, + max_pixels=MAX_PIXELS, + ) + resized_img = img_pil.resize((resized_width, resized_height)) + processed_frames.append(resized_img) + + all_image_inputs.append(processed_frames) + + # Text preprocessing + task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"] + instruction_info = {"instruction": task_text} + + frame_index = batch["frame_index"][i] if "frame_index" in batch else 0 + complete_text, _ = get_wallx_normal_text( + instruction_info, + self.config.chunk_size, + frame_index, + PRIORITY_ORDER, + img_keys, + generate_subtask_ratio=GENERATE_SUBTASK_RATIO, + ) + + text = process_grounding_points( + complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE + ) + all_texts.append(text) + + # ==================== PROCESS AGENT POS ==================== + agent_pos = batch[OBS_STATE] # (batch_size, state_dim) + if agent_pos.dim() == 2: + agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim) + agent_pos_mask = (~torch.isnan(agent_pos)).float() + agent_pos = agent_pos.nan_to_num(nan=0.0) + + if agent_pos.shape[-1] != 20: + pad_size = 20 - agent_pos.shape[-1] + agent_pos = torch.cat( + [ + agent_pos, + torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device), + ], + dim=-1, + ) + agent_pos_mask = torch.cat( + [ + agent_pos_mask, + torch.zeros( + agent_pos_mask.shape[0], + agent_pos_mask.shape[1], + pad_size, + device=agent_pos_mask.device, + ), + ], + dim=-1, + ) + + # ==================== PROCESS ACTIONS ==================== + action = batch.get(ACTION) # (batch_size, chunk_size, action_dim) + if action is not None: + if action.dim() == 2: + action = action.unsqueeze(1) + dof_mask = (~torch.isnan(action)).float() + action = action.nan_to_num(nan=0.0) + + if action.shape[-1] != 20: + pad_size = 20 - action.shape[-1] + action = torch.cat( + [action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)], + dim=-1, + ) + dof_mask = torch.cat( + [ + dof_mask, + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device), + ], + dim=-1, + ) + else: + action_dim = self.config.output_features["action"].shape[0] + dof_mask = torch.cat( + [ + torch.ones( + batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device + ), + torch.zeros( + batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device + ), + ], + dim=-1, + ) + + # ==================== ACTION TOKEN REPLACEMENT ==================== + all_texts = replace_action_token( + all_texts, + action, + self.model.action_tokenizer if use_fast_tokenizer else None, + dof_mask, + ) + + # ==================== TOKENIZATION ==================== + inputs = preprocesser_call( + processor=self.model.processor, + text=all_texts, + images=all_image_inputs, + videos=None, + padding=True, + truncation=True, + return_tensors="pt", + max_length=TOKENIZER_MAX_LENGTH, + ) + + # ==================== ADDITIONAL INPUTS ==================== + action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>") + moe_token_types = inputs.input_ids == action_token_id + + inputs["proprioception"] = agent_pos + inputs["agent_pos_mask"] = agent_pos_mask + inputs["action_chunk"] = action + inputs["dof_mask"] = dof_mask + inputs["moe_token_types"] = moe_token_types + inputs["frame_index"] = ( + batch["frame_index"] + if "frame_index" in batch + else torch.zeros(batch_size, device=batch[OBS_STATE].device) + ) + + # Move all tensors to the correct device + device = self.config.device + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.to(device) + + return inputs + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """ + Training forward pass using Qwen2_5_VLMoEForAction. + + Args: + batch: Dictionary containing preprocessed inputs from preprocess_inputs() + Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw, + proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types, + etc. + + Returns: + tuple: (loss, loss_dict) + """ + batch = self.preprocess_inputs( + batch, + ) + + # Call the underlying model's forward with mode="train" + outputs = self.model(**batch, mode="train") + + # Extract losses from output + loss = outputs.loss + loss_dict = { + "loss": loss.item() if loss is not None else 0.0, + } + + if outputs.flow_loss is not None: + loss_dict["flow_loss"] = outputs.flow_loss.item() + if outputs.cross_entropy_loss is not None: + loss_dict["cross_entropy_loss"] = outputs.cross_entropy_loss.item() + + # Add channel losses if available + if outputs.channel_loss_dict is not None: + for key, value in outputs.channel_loss_dict.items(): + if isinstance(value, torch.Tensor): + loss_dict[f"channel_{key}"] = value.item() + + return loss, loss_dict + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict action chunk for evaluation.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + batch = self.preprocess_inputs( + batch, + ) + + if self.config.prediction_mode == "diffusion": + output = self.model( + **batch, + action_dim=self.config.max_action_dim, + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="diffusion", + ) + elif self.config.prediction_mode == "fast": + output = self.model( + **batch, + action_dim=self.config.output_features["action"].shape[0], + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="fast", + ) + else: + raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") + + # Extract action tensor from output dictionary + actions = output["predict_action"] + + # Unpad actions to actual action dimension + action_dim = self.config.output_features["action"].shape[0] + actions = actions[:, :, :action_dim] + + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select single action for environment execution.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + # Use action queue + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + + return self._queues[ACTION].popleft() diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py new file mode 100644 index 000000000..e4e281541 --- /dev/null +++ b/src/lerobot/policies/wall_x/processor_wall_x.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_wall_x_pre_post_processors( + config: WallXConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the Wall-X policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations + 2. Adding a batch dimension + 4. Normalizing input and output features based on dataset statistics + 5. Moving all data to the specified device + + The post-processing pipeline handles the model's output by: + 1. Unnormalizing the output actions to their original scale + 2. Moving data to the CPU + + Args: + config: The configuration object for the Wall-X policy + dataset_stats: A dictionary of statistics for normalization + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + WallXTaskProcessor(), # Process task description + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +@ProcessorStepRegistry.register(name="wall_x_task_processor") +class WallXTaskProcessor(ComplementaryDataProcessorStep): + """ + A processor step that ensures the task description is properly formatted for Wall-X. + + This step handles task preprocessing similar to Qwen-VL requirements. + """ + + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + # Provide default task if none specified + complementary_data["task"] = "Execute the robot action." + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: ensure proper formatting + if not task.endswith("."): + new_complementary_data["task"] = f"{task}." + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: format each + new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task] + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py new file mode 100644 index 000000000..731ef3b3e --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py @@ -0,0 +1,248 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + num_experts=4, + experts=None, + dof_config=None, + noise_scheduler=None, + dim_inputs=(1536, 1536), + attention_moe=False, + mlp_moe=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.layer_types = ["dense"] * num_hidden_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + self.num_experts = num_experts + self.experts = experts + self.dof_config = dof_config + self.noise_scheduler = noise_scheduler + self.dim_inputs = tuple(dim_inputs) + self.attention_moe = attention_moe + self.mlp_moe = mlp_moe + + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + @property + def text_config(self): + return self + + +__all__ = ["Qwen2_5_VLConfig"] diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py new file mode 100644 index 000000000..490e25095 --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -0,0 +1,2788 @@ +import math +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import ( + Cache, + DynamicCache, + SlidingWindowCache, + StaticCache, +) +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) + +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + flash_attn_func = None + + +if is_flash_attn_2_available(): + pass +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + if max_seqlen is None: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], + torch.finfo(q.dtype).min, + device=q.device, + dtype=q.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = ( + False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + window_index = window_index.to(hidden_states.device) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen_full = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen_window = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, + hidden_states, + cu_seqlens_now, + None, + position_embeddings, + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + position_embeddings=position_embeddings, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_rate, + softmax_scale=None, + causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = ( + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + ) + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +@dataclass +class Qwen2_5_VLACausalLMOutputWithPast(ModelOutput): + loss: torch.FloatTensor | None = None + flow_loss: torch.FloatTensor | None = None + cross_entropy_loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + channel_loss_dict: dict[torch.FloatTensor] | None = None + channel_loss_count_dict: dict[torch.FloatTensor] | None = None + + +class BlockSparseMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.hidden_size = config["hidden_size"] + self.intermediate_size = config["intermediate_size"] + self.hidden_act = config["hidden_act"] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[self.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class SparseMoeBlock(nn.Module): + def __init__(self, config, num_experts: int): + super().__init__() + self.num_experts = num_experts + self.experts = nn.ModuleList([BlockSparseMLP(config.experts[i]) for i in range(num_experts)]) + + if not hasattr(config, "dim_inputs") or not config.dim_inputs: + raise ValueError("Config must contain valid dim_inputs") + + self.dim_inputs = config.dim_inputs + + def forward(self, hidden_states: torch.Tensor, experts_indices: torch.Tensor) -> torch.Tensor: + """ + Route different hidden_states to corresponding experts for processing. + + Args: + hidden_states (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + experts_indices (torch.Tensor): Tensor of shape (batch_size, seq_length), + indicating the expert index assigned to each token. + + Returns: + output (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + """ + batch_size, seq_length, hidden_dim = hidden_states.size() + output = torch.zeros_like(hidden_states) + + for expert_idx, expert in enumerate(self.experts): + mask = experts_indices == expert_idx + if mask.sum() == 0: + continue + dim_input = self.dim_inputs[expert_idx] + + selected_hidden = hidden_states[mask] + processed_hidden = expert(selected_hidden[:, :dim_input]) + + batch_indices, seq_indices = torch.where(mask) + output[batch_indices, seq_indices, :dim_input] = processed_hidden + + return output + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int, num_experts: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if config.mlp_moe: + self.moe = SparseMoeBlock(config, num_experts=num_experts) + self.mlp = None + else: + self.mlp = Qwen2_5_VLMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + token_types=None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = hidden_states.to(self.input_layernorm.weight.dtype) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states.to(self.self_attn.q_proj.weight.dtype) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = hidden_states.to(self.post_attention_layernorm.weight.dtype) + hidden_states = self.post_attention_layernorm(hidden_states) + if self.mlp is None: # using moe mlp + hidden_states = hidden_states.to(self.moe.experts[0].down_proj.weight.dtype) + hidden_states = self.moe(hidden_states, token_types) + else: + hidden_states = hidden_states.to(self.mlp.down_proj.weight.dtype) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs + + +class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): + """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. + + This model extends the base Qwen2.5-VL model by incorporating MoE layers + for improved scalability and specialization across different token types. + """ + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + num_experts: int | None = None, + *args, + **kwargs, + ): + """Load a pretrained model with optional MoE configuration. + + Args: + pretrained_model_name_or_path: Path or name of the pretrained model + num_experts: Number of experts for MoE layers (if not in config) + *args: Additional arguments passed to parent class + **kwargs: Additional keyword arguments passed to parent class + + Returns: + Initialized model instance with MoE configuration + """ + config = kwargs.get("config") + if config is None: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + + # Override number of experts if specified + if num_experts is not None: + config.num_experts = num_experts + + kwargs["config"] = config + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + def __init__(self, config: Qwen2_5_VLConfig): + """Initialize the Qwen2.5-VL MoE model. + + Args: + config: Model configuration containing architecture parameters + """ + super().__init__(config) + + # Basic model parameters + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Model components + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + # Decoder layers with MoE support + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # Model configuration + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Get the input embedding layer. + + Returns: + The token embedding layer + """ + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the input embedding layer. + + Args: + value: New embedding layer to use + """ + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPast: + # Set default output options + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Validate inputs + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if moe_token_types is None: + raise ValueError("moe_token_types must be provided for MoE routing") + + # Handle gradient checkpointing compatibility + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Initialize cache if needed + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + # Get input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Set up cache position + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # Set up position IDs (hardcoded 3 dimensions for temporal, height, width) + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # Create causal attention mask + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + moe_token_types, + ) + + hidden_states = inputs_embeds + + # Create position embeddings to be shared across decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Process through decoder layers + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + # Use gradient checkpointing during training + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + moe_token_types, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + # Regular forward pass + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + token_types=moe_token_types, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + # Update cache if using it + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + # Collect attention weights if requested + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Apply final layer normalization + hidden_states = self.norm(hidden_states) + + # Add final hidden states if collecting all states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + # Return outputs in requested format + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + moe_token_types: torch.LongTensor | None = None, + ): + """Update causal attention mask with support for bidirectional attention for specific token types. + + This method creates and modifies attention masks to support different attention patterns: + - Standard causal (unidirectional) attention for most tokens + - Bidirectional attention for specific token types (e.g., MoE routing tokens) + + Args: + attention_mask: Input attention mask to avoid attending to padding tokens + input_tensor: Input embeddings tensor for shape and device information + cache_position: Position indices for caching mechanisms + past_key_values: Cached key-value pairs from previous forward passes + output_attentions: Whether attention weights will be returned + moe_token_types: Optional tensor indicating token types for MoE routing + (type 1 tokens will use bidirectional attention) + + Returns: + Updated causal attention mask, or None if using Flash Attention 2 + """ + # Flash Attention 2 handles masking internally + if self.config._attn_implementation == "flash_attention_2": + return None + + # Calculate sequence lengths for cache management + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible + # instead of explicit attention mask to enable Flash Attention 2 dispatch + # Note: This optimization is not compatible with static cache + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + # Check if we can ignore the causal mask and rely on SDPA's internal handling + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + # Extract tensor properties for mask creation + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + # Determine target length based on cache type + if using_sliding_window_cache or using_static_cache: + # Use maximum cache shape for sliding window or static caches + target_length = past_key_values.get_max_cache_shape() + else: + # For dynamic cache or no cache, calculate based on attention mask or sequence length + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # Generate 4D causal attention mask from 2D input mask if provided + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + # Modify mask to support bidirectional attention for specific token types + if moe_token_types is not None: + # Identify positions of type 1 tokens (MoE routing tokens) + type1_tokens = (moe_token_types == 1).unsqueeze(1).unsqueeze(2) # Shape: [B, 1, 1, S] + + # Create bidirectional attention region for type 1 tokens + # This allows type 1 tokens to attend to each other bidirectionally + type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S] + type1_region = type1_tokens & type1_tokens.transpose(-1, -2) # Shape: [B, 1, S, S] + type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool) + + # Apply bidirectional attention: zero out causal constraints in type 1 regions + causal_mask = torch.where( + type1_mask, # Where type 1 tokens interact with each other + torch.zeros_like(causal_mask), # Remove causal masking (allow bidirectional) + causal_mask, # Keep original causal masking for other regions + ) + + # Handle special case for SDPA with CUDA/XPU devices + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Ensure attention to all tokens in fully masked rows for memory-efficient attention + # This is required for F.scaled_dot_product_attention's memory-efficient path + # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +__all__ = [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + "Qwen2_5_VLDecoderLayer_with_MoE", + "Qwen2_5_VLMoEModel", +] diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py new file mode 100644 index 000000000..2ea40b377 --- /dev/null +++ b/src/lerobot/policies/wall_x/utils.py @@ -0,0 +1,631 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Utility Functions. + +Contains data processing utilities, text formatting functions, and helper classes +for the Wall-X cross-embodiment robotic control model. +""" + +import random +import re +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any + +import torch +from transformers import BatchFeature + +from lerobot.policies.wall_x.constant import ( + CAMERA_NAME_MAPPING, +) +from lerobot.utils.constants import OBS_IMAGES + + +@dataclass +class X2RDataProcessingConfig: + """Configuration class for X2R data processing pipeline. + + This class contains all the necessary parameters for processing robotic data + including camera mappings, tactile sensor configurations, action predictions, + and various processing options. + """ + + # Action prediction configuration + predict_action_keys: list[str] = field(default_factory=list) + obs_action_keys: list[str] = field(default_factory=list) + + # Image resolution settings for different views + resolution: dict[str, int] = field( + default_factory=lambda: { + "face_view": -1, + "left_wrist_view": 128, + "right_wrist_view": 128, + } + ) + + # Dataset splitting + train_test_split: float = 0.9 + split_seed: int = 42 + + # Instruction handling + priority_order: dict[str, float] | None = None + + # Vision model parameters + model_type: str = "qwen2_5" + max_pixels: int = 16384 * 28 * 28 + min_pixels: int = 4 * 28 * 28 + image_factor: int = 28 + + generate_subtask_ratio: float = 0.0 + + def __post_init__(self): + """Post-initialization validation and setup.""" + # Validate train/test split + if not 0 < self.train_test_split < 1: + raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}") + + def as_dict(self) -> dict: + """Convert configuration to dictionary format. + + Returns: + Dict: Configuration as dictionary + """ + return self.__dict__ + + def update(self, **kwargs) -> "X2RDataProcessingConfig": + """Update configuration parameters. + + Args: + **kwargs: Key-value pairs to update + + Returns: + X2RDataProcessingConfig: Updated configuration instance + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown configuration parameter: {key}") + return self + + +def preprocesser_call( + processor, + images: list | Any | None = None, + text: str | list[str] | None = None, + videos: list | Any | None = None, + padding: bool | str = False, + truncation: bool | None = None, + max_length: int | None = None, + return_tensors: str = "pt", +) -> BatchFeature: + """Unified preprocessing function for Wall-X model handling text, image and video inputs. + + Processes inputs into format suitable for multimodal transformer models, including: + - Text tokenization and special token handling + - Image/video processing through image processor + - Attention mask and label generation + - Padding and truncation handling + + Args: + processor: Multimodal processor containing tokenizer and image processor + images: Input images (PIL, numpy arrays, or torch tensors) + text: Text or list of texts to tokenize + videos: Input videos (numpy arrays or torch tensors) + padding: Whether to pad sequences to same length + truncation: Whether to truncate sequences longer than max_length + max_length: Maximum length for truncation/padding + return_tensors: Format for returned tensors ('pt', 'np', etc.) + + Returns: + BatchFeature containing processed inputs with keys: + - input_ids: Tokenized text + - attention_mask: Attention mask for text + - pixel_values: Processed image pixels + - pixel_values_videos: Processed video frames + - image_grid_thw: Image grid dimensions for LLM + - video_grid_thw: Video grid dimensions for LLM + - labels: Training labels with masking + """ + # Process image inputs + if images is not None and len(images) > 0: + image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + # Process video inputs + if videos is not None: + videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors) + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + # Ensure text input is in list format + if not isinstance(text, list): + text = [text] + + # Process image placeholder tokens in text + if image_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|image_pad|>" in text[i]: + # Add bounds checking to avoid index overflow + if index >= len(image_grid_thw): + print( + f"Warning: Number of image placeholders ({index + 1}) " + f"exceeds actual images ({len(image_grid_thw)}), " + f"skipping remaining placeholder processing" + ) + break + # Replace image placeholder with actual token count + token_count = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") + + # Process video placeholder tokens in text + if video_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|video_pad|>" in text[i]: + # Replace video placeholder with actual token count + token_count = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") + + # Tokenize complete input text + text_inputs = processor.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + # Get pad token ID for label generation + pad_token_id = processor.tokenizer.pad_token_id + if pad_token_id is None: + pad_token_id = processor.tokenizer.eos_token_id + + # Generate labels for multi-turn dialogue, keeping only assistant response loss + labels = torch.full_like(text_inputs.input_ids, -100) + assistant_marker = "<|im_start|>assistant\n" + im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") + assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids + + for i in range(len(text)): + assistant_regions = [] + parts = text[i].split(assistant_marker) + + # Process each part to determine which tokens belong to assistant responses + # Count left padding tokens + num_left_pads = 0 + for token_id in text_inputs.input_ids[i]: + if token_id == pad_token_id: + num_left_pads += 1 + else: + break + current_pos = num_left_pads + + for j, part in enumerate(parts): + part_tokens = processor.tokenizer(part, add_special_tokens=False).input_ids + if j == 0: + # First part is system prompt or user question, all labels are -100 + current_pos += len(part_tokens) + continue + + # From second part onwards, each part starts with assistant response + for k in range(current_pos + 1, len(text_inputs.input_ids[i])): + if text_inputs.input_ids[i][k] == im_end_token_id: + assistant_regions.append((current_pos + len(assistant_tokens), k + 2)) + break + current_pos += len(part_tokens) + 3 + + # Set labels for assistant response regions + for start, end in assistant_regions: + labels[i][start:end] = text_inputs.input_ids[i][start:end] + + # Mask special action tokens in labels + action_token_id = processor.tokenizer.encode("<|action|>")[0] + propri_token_id = processor.tokenizer.encode("<|propri|>")[0] + labels[labels == action_token_id] = -100 + labels[labels == propri_token_id] = -100 + labels[labels == processor.tokenizer.pad_token_id] = -100 + + # Set labels to None if all are invalid to skip cross entropy loss + if (labels != -100).any().item(): + text_inputs["labels"] = labels + else: + text_inputs["labels"] = None + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + +def process_grounding_points( + text: str, + orig_height: int, + orig_width: int, + resized_height: int, + resized_width: int, + model_type: str, +) -> str: + """Process grounding point coordinates in text based on image resizing. + + Adjusts coordinate values in tags to match resized image dimensions + for different model types (qwen2, qwen2_5). + + Args: + text: Input text containing tags with coordinates + orig_height: Original image height + orig_width: Original image width + resized_height: Resized image height + resized_width: Resized image width + model_type: Model type for coordinate processing ('qwen2' or 'qwen2_5') + + Returns: + Text with adjusted coordinate values + """ + # Regex pattern to match tags and their contents + point_pattern = re.compile(r"(.*?)") + + def process_match(match): + """Process a single point match and adjust coordinates.""" + coords_str = match.group(1) + try: + # Extract coordinates from string + coords = list(map(int, re.findall(r"\d+", coords_str))) + + # Calculate resize scale factors + scale_w = resized_width / orig_width + scale_h = resized_height / orig_height + + if len(coords) == 2: + x, y = coords + if model_type == "qwen2_5": + # Qwen2.5 uses pixel coordinates + new_x = max(0, min(round(x * scale_w), resized_width - 1)) + new_y = max(0, min(round(y * scale_h), resized_height - 1)) + elif model_type == "qwen2": + # Qwen2 normalizes to [0, 1000) range + new_x = max(0, min(999.999, (x / orig_width) * 1000)) + new_y = max(0, min(999.999, (y / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x, new_y] + + elif len(coords) == 4: + x1, y1, x2, y2 = coords + if model_type == "qwen2_5": + new_x1 = max(0, min(round(x1 * scale_w), resized_width - 1)) + new_y1 = max(0, min(round(y1 * scale_h), resized_height - 1)) + new_x2 = max(0, min(round(x2 * scale_w), resized_width - 1)) + new_y2 = max(0, min(round(y2 * scale_h), resized_height - 1)) + elif model_type == "qwen2": + new_x1 = max(0, min(999.999, (x1 / orig_width) * 1000)) + new_y1 = max(0, min(999.999, (y1 / orig_height) * 1000)) + new_x2 = max(0, min(999.999, (x2 / orig_width) * 1000)) + new_y2 = max(0, min(999.999, (y2 / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x1, new_y1, new_x2, new_y2] + + # Return processed point tag + return f"[{', '.join(map(str, coords))}]" + + except (ValueError, TypeError): + # Return original content if processing fails + return match.group(0) + + # Replace all matching point tags + processed_text = point_pattern.sub(process_match, text) + return processed_text + + +def get_frame_instruction( + instruction_info: dict[str, Any], + frame_idx: int | None = None, + truncate_keys: list[str] | None = None, +) -> tuple[dict[str, Any], int | None]: + """Extract frame-specific instruction from instruction dictionary. + + Args: + instruction_info: Dictionary containing instruction components + frame_idx: Current frame index + truncate_keys: Keys that trigger truncation when found + + Returns: + Tuple of (frame_instruction_dict, split_end_frame) + """ + if truncate_keys is None: + truncate_keys = [ + "subtask_generation", + "distribute", + "subtask_generation_zh", + "distribute_zh", + ] + + instruction_for_frame = {} + split_end = None + + for key, value in instruction_info.items(): + if isinstance(value, dict): + # Handle frame-range specific instructions + for frame_range, frame_instruction in value.items(): + start_frame, end_frame = map(int, frame_range.split(" ")) + if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): + instruction_for_frame[key] = frame_instruction + if truncate_keys is not None and split_end is None and key in truncate_keys: + split_end = end_frame + 1 + break + else: + instruction_for_frame[key] = value + + return instruction_for_frame, split_end + + +def get_task_instruction( + frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None +) -> str: + """Construct task instruction from available instruction fields using priority sampling. + + Args: + frame_instruction_info: Dictionary containing instruction fields + priority_order: OrderedDict specifying sampling probability for each field + + Returns: + Combined instruction string with priority components + """ + # Default priority settings + default_priority_order = OrderedDict( + { + "subtask_generation": 0.25, + "subtask_generation_zh": 0.25, + "distribute": 0.25, + "distribute_zh": 0.25, + } + ) + + if priority_order is not None: + priority_order = OrderedDict(priority_order) + else: + priority_order = default_priority_order + + got_instruction = False + task_instruction = "" + + # Sample instruction components based on priority probabilities + for key, prob in priority_order.items(): + if key in frame_instruction_info and frame_instruction_info[key] != "": + if got_instruction: + if random.random() >= prob: + continue + + task_instruction += f"\n{frame_instruction_info[key]}" + got_instruction = True + break + + # Fall back to base instruction if no priority components found + if not got_instruction: + task_instruction = frame_instruction_info.get("instruction", "") + + return task_instruction + + +def get_wallx_normal_text( + instruction_info: dict[str, Any], + action_chunk_size: int, + frame_idx: int, + priority_order: OrderedDict | None = None, + img_keys: list[str] | None = None, + generate_subtask_ratio: float = 0.0, +) -> tuple[str, bool]: + """Construct complete multimodal prompt text for Wall-X model. + + Formats input using special tokens including: + - System message + - User observations (with image placeholders) + - Task instructions + - Proprioception prompts + - Assistant responses (with action tokens) + + Args: + instruction_info: Dictionary containing instruction components + action_chunk_size: Number of action tokens to generate + frame_idx: Current frame index + priority_order: Priority order for instruction sampling + img_keys: List of image keys + generate_subtask_ratio: Probability of generating subtask instead of actions + + Returns: + Tuple of (formatted_prompt_text, is_subtask_generation) + """ + # Special tokens for formatting + role_start_symbol = "<|im_start|>" + role_end_symbol = "<|im_end|>" + vision_start_symbol = "<|vision_start|>" + vision_end_symbol = "<|vision_end|>" + image_pad_symbol = "<|image_pad|>" + propri_symbol = "<|propri|>" + action_symbol = "<|action|>" + action_fast_symbol = "<|action_fast|>" + + # System prologue + prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" + + # User request with observation + user_request = f"{role_start_symbol}user\nObservation:" + if img_keys: + img_keys = img_key_mapping(img_keys) + for key in img_keys: + user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" + user_request += "\nInstruction:" + + # Get frame-specific instruction + frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx) + + generate_subtask = False + priority_keys = ["subtask_generation", "distribute"] + + # Decide whether to generate subtask or actions + if ( + bool(set(frame_instruction_info.keys()) & set(priority_keys)) + and random.random() < generate_subtask_ratio + ): + # Generate subtask (equivalent to VQA task) + instruction = frame_instruction_info.get("instruction", "") + text_prompt = "\nPredict the next action in language.\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + + # Find output instruction from priority keys + for key in priority_keys: + if key in frame_instruction_info: + output_instruction = frame_instruction_info[key] + break + + assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}" + generate_subtask = True + else: + # Generate actions + instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order) + text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" + + complete_text = prologue + user_message + assistant_output + return complete_text, generate_subtask + + +def img_key_mapping(img_keys: list[str]) -> list[str]: + """Map image keys to camera names. + + Args: + img_keys: List of image keys + + Returns: + List of camera names + """ + processed_img_keys = [] + for key in img_keys: + key = key.replace(OBS_IMAGES + ".", "") + if key in CAMERA_NAME_MAPPING: + key = CAMERA_NAME_MAPPING[key] + else: + if "view" in key: + key = key.replace("_", " ") + else: + key = key + " view" + processed_img_keys.append(key) + return processed_img_keys + + +def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]: + """Convert normalized actions to action token strings. + + Args: + normalized_actions: Normalized action arrays/tensors + action_tokenizer: Tokenizer for converting actions to tokens + + Returns: + List of action token string lists for each sample + """ + if isinstance(normalized_actions, torch.Tensor): + normalized_actions = normalized_actions.cpu().numpy() + + all_action_tokens = [] + for i in range(len(normalized_actions)): + if isinstance(normalized_actions[i], torch.Tensor): + normalized_actions[i] = normalized_actions[i].cpu().numpy() + + token_id = action_tokenizer(normalized_actions[i]) + action_tokens = [f"<|action_token_{j}|>" for j in token_id[0]] + all_action_tokens.append(action_tokens) + + return all_action_tokens + + +def pad_action_token_strs( + actions_token_lists: list[list[str]], + pad_token: str = "<|endoftext|>", # nosec B107 +) -> list[str]: + """Pad action token lists to same length and join as strings. + + Args: + actions_token_lists: List of action token lists for each sample + pad_token: Token used for padding + + Returns: + List of padded action token strings + """ + max_len = max(len(tokens) for tokens in actions_token_lists) + padded_action_strs = [] + + for tokens in actions_token_lists: + padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens)) + padded_action_strs.append("".join(padded_tokens)) + + return padded_action_strs + + +def replace_action_token( + text: list[str], + norm_action: torch.Tensor | None, + action_tokenizer, + dof_masks: torch.Tensor | None = None, +) -> list[str]: + """Replace action placeholders in text with actual action tokens. + + Args: + text: List of text strings with action placeholders + norm_action: Normalized action tensors + action_tokenizer: Tokenizer for converting actions to tokens + dof_masks: Masks for degrees of freedom + + Returns: + List of text strings with action tokens replaced + """ + if action_tokenizer is not None and norm_action is not None: + # Extract actions based on chunk sizes and DOF masks + norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)] + + # Convert to action tokens and pad + actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) + actions_fast_token_strs = pad_action_token_strs(actions_fast_tokens) + + # Replace action placeholders with actual tokens + actions_fast_token_idx = 0 + for i in range(len(text)): + if "<|action_fast|>" in text[i]: + text[i] = text[i].replace( + "<|action_fast|><|im_end|>\n", + actions_fast_token_strs[actions_fast_token_idx], + ) + actions_fast_token_idx += 1 + + # Remove remaining action placeholders + text = [t.replace("<|action|>", "") for t in text] + else: + # Remove action placeholders when no tokenizer available + text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] + + return text diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index 27c7c6e1b..0436ae527 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy): config_class = XVLAConfig name = "xvla" - def __init__(self, config: XVLAConfig): + def __init__(self, config: XVLAConfig, **kwargs): super().__init__(config) config.validate_features() florence_config = config.get_florence_config() diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 6b0b67598..126be0e36 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: task_key = {"task": batch["task"]} if "task" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key} + return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} def create_transition( diff --git a/src/lerobot/robots/omx_follower/__init__.py b/src/lerobot/robots/omx_follower/__init__.py new file mode 100644 index 000000000..db48dffe9 --- /dev/null +++ b/src/lerobot/robots/omx_follower/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# OMX is a fully open-source robot from ROBOTIS. +# More information at: https://ai.robotis.com/omx/introduction_omx.html + +from .config_omx_follower import OmxFollowerConfig +from .omx_follower import OmxFollower diff --git a/src/lerobot/robots/omx_follower/config_omx_follower.py b/src/lerobot/robots/omx_follower/config_omx_follower.py new file mode 100644 index 000000000..db4179fdf --- /dev/null +++ b/src/lerobot/robots/omx_follower/config_omx_follower.py @@ -0,0 +1,39 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("omx_follower") +@dataclass +class OmxFollowerConfig(RobotConfig): + # Port to connect to the arm + port: str + + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Set to `True` for backward compatibility with previous policies/dataset + use_degrees: bool = False diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py new file mode 100644 index 000000000..2dd851377 --- /dev/null +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.dynamixel import ( + DriveMode, + DynamixelMotorsBus, + OperatingMode, +) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_omx_follower import OmxFollowerConfig + +logger = logging.getLogger(__name__) + + +class OmxFollower(Robot): + """ + - [OMX](https://github.com/ROBOTIS-GIT/open_manipulator), + expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/) + """ + + config_class = OmxFollowerConfig + name = "omx_follower" + + def __init__(self, config: OmxFollowerConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(11, "xl430-w250", norm_mode_body), + "shoulder_lift": Motor(12, "xl430-w250", norm_mode_body), + "elbow_flex": Motor(13, "xl430-w250", norm_mode_body), + "wrist_flex": Motor(14, "xl330-m288", norm_mode_body), + "wrist_roll": Motor(15, "xl330-m288", norm_mode_body), + "gripper": Motor(16, "xl330-m288", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + For OMX robots that come pre-calibrated: + - If default calibration from package doesn't match motors, read from motors and save + - This allows using pre-calibrated robots without manual calibration + - If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095) + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + self.bus.disable_torque() + logger.info(f"\nUsing factory default calibration values for {self}") + logger.info(f"\nWriting default configuration of {self} to the motors") + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + for motor in self.bus.motors: + self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value) + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=0, + range_min=0, + range_max=4095, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling + # the arm, you could end up with a servo with a position 0 or 4095 at a crucial point + for motor in self.bus.motors: + if motor != "gripper": + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + # Use 'position control current based' for gripper to be limited by the limit of the current. For + # the follower gripper, it means it can grasp an object without forcing too much even tho, its + # goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with + # our finger to make it move, and it will move back to its original target position when we + # release the force. + self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + + # Set better PID values to close the gap between recorded states and actions + # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor + self.bus.write("Position_P_Gain", "elbow_flex", 1500) + self.bus.write("Position_I_Gain", "elbow_flex", 0) + self.bus.write("Position_D_Gain", "elbow_flex", 600) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, float]) -> dict[str, float]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Args: + action (dict[str, float]): The goal positions for the motors. + + Returns: + dict[str, float]: The action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 575ad50bb..ac65f1a7b 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -51,5 +51,8 @@ class UnitreeG1Config(RobotConfig): control_dt: float = 1.0 / 250.0 # 250Hz + # launch mujoco simulation + is_simulation: bool = True + # socket config for ZMQ bridge robot_ip: str = "192.168.123.164" diff --git a/src/lerobot/robots/unitree_g1/run_g1_server.py b/src/lerobot/robots/unitree_g1/run_g1_server.py index ee3505ea4..70166b590 100644 --- a/src/lerobot/robots/unitree_g1/run_g1_server.py +++ b/src/lerobot/robots/unitree_g1/run_g1_server.py @@ -99,11 +99,12 @@ def state_forward_loop( lowstate_sub: ChannelSubscriber, lowstate_sock: zmq.Socket, state_period: float, + shutdown_event: threading.Event, ) -> None: """Read observation from DDS and forward to ZMQ clients.""" last_state_time = 0.0 - while True: + while not shutdown_event.is_set(): # read from DDS msg = lowstate_sub.Read() if msg is None: @@ -128,7 +129,10 @@ def cmd_forward_loop( ) -> None: """Receive commands from ZMQ and forward to DDS.""" while True: - payload = lowcmd_sock.recv() + try: + payload = lowcmd_sock.recv() + except zmq.ContextTerminated: + break msg_dict = json.loads(payload.decode("utf-8")) topic = msg_dict.get("topic", "") @@ -182,30 +186,26 @@ def main() -> None: lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}") state_period = 0.002 # ~500 hz + shutdown_event = threading.Event() - # start observation forwarding thread + # start observation forwarding in background thread t_state = threading.Thread( target=state_forward_loop, - args=(lowstate_sub, lowstate_sock, state_period), - daemon=True, + args=(lowstate_sub, lowstate_sock, state_period, shutdown_event), ) t_state.start() - # start action forwarding thread - t_cmd = threading.Thread( - target=cmd_forward_loop, - args=(lowcmd_sock, lowcmd_pub_debug, crc), - daemon=True, - ) - t_cmd.start() - print("bridge running (lowstate -> zmq, lowcmd -> dds)") - # keep main thread alive so daemon threads don't exit + + # run command forwarding in main thread try: - while True: - time.sleep(1.0) + cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc) except KeyboardInterrupt: print("shutting down bridge...") + finally: + shutdown_event.set() + ctx.term() # terminates blocking zmq.recv() calls + t_state.join(timeout=2.0) if __name__ == "__main__": diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 2e7196b57..cce9d1b1e 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -30,12 +30,8 @@ from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( ) from unitree_sdk2py.utils.crc import CRC +from lerobot.envs.factory import make_env from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex -from lerobot.robots.unitree_g1.unitree_sdk2_socket import ( - ChannelFactoryInitialize, - ChannelPublisher, - ChannelSubscriber, -) from ..robot import Robot from .config_unitree_g1 import UnitreeG1Config @@ -127,7 +123,21 @@ class UnitreeG1(Robot): self.control_dt = config.control_dt + if config.is_simulation: + from unitree_sdk2py.core.channel import ( + ChannelFactoryInitialize, + ChannelPublisher, + ChannelSubscriber, + ) + else: + from lerobot.robots.unitree_g1.unitree_sdk2_socket import ( + ChannelFactoryInitialize, + ChannelPublisher, + ChannelSubscriber, + ) + # connect robot + self.ChannelFactoryInitialize = ChannelFactoryInitialize self.connect() # initialize direct motor control interface @@ -138,8 +148,8 @@ class UnitreeG1(Robot): self.lowstate_buffer = DataBuffer() # initialize subscribe thread to read robot state + self._shutdown_event = threading.Event() self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) - self.subscribe_thread.daemon = True self.subscribe_thread.start() while not self.is_connected: @@ -174,7 +184,7 @@ class UnitreeG1(Robot): self.remote_controller = self.RemoteController() def _subscribe_motor_state(self): # polls robot state @ 250Hz - while True: + while not self._shutdown_event.is_set(): start_time = time.time() msg = self.lowstate_subscriber.Read() if msg is not None: @@ -218,10 +228,17 @@ class UnitreeG1(Robot): pass def connect(self, calibrate: bool = True) -> None: # connect to DDS - ChannelFactoryInitialize(0) + if self.config.is_simulation: + self.ChannelFactoryInitialize(0, "lo") + self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) + else: + self.ChannelFactoryInitialize(0) def disconnect(self): - pass + self._shutdown_event.set() + self.subscribe_thread.join(timeout=2.0) + if self.config.is_simulation: + self.mujoco_env["hub_env"][0].envs[0].kill_sim() def get_observation(self) -> dict[str, Any]: return self.lowstate_buffer.get_data() diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 4e8001538..9c5043335 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -28,6 +28,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .koch_follower import KochFollower return KochFollower(config) + elif config.type == "omx_follower": + from .omx_follower import OmxFollower + + return OmxFollower(config) elif config.type == "so100_follower": from .so100_follower import SO100Follower diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 8247ec053..910a9a1b5 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -40,6 +40,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, lekiwi, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) @@ -49,6 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401 homunculus, koch_leader, make_teleoperator_from_config, + omx_leader, so100_leader, so101_leader, ) diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 83ba027bc..e835b1de6 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,8 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -and remove features. When new_repo_id is specified, creates a new dataset. +remove features, and convert image datasets to video format. +When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -65,6 +66,25 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Convert image dataset to video format (saves locally): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_to_video \ + --operation.output_dir /path/to/output/pusht_video + +Convert image dataset and save with new repo_id: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --new_repo_id lerobot/pusht_video \ + --operation.type convert_to_video + +Convert and push to hub: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --new_repo_id lerobot/pusht_video \ + --operation.type convert_to_video \ + --push_to_hub true + Using JSON config file: python -m lerobot.scripts.lerobot_edit_dataset \ --config_path path/to/edit_config.json @@ -72,9 +92,13 @@ Using JSON config file: import logging import shutil +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path +import pandas as pd +from tqdm import tqdm + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( delete_episodes, @@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import ( remove_feature, split_dataset, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import write_stats, write_tasks +from lerobot.datasets.video_utils import encode_video_frames, get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE from lerobot.utils.utils import init_logging @@ -111,10 +137,23 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class ConvertToVideoConfig: + type: str = "convert_to_video" + output_dir: str | None = None + vcodec: str = "libsvtav1" + pix_fmt: str = "yuv420p" + g: int = 2 + crf: int = 30 + fast_decode: int = 0 + episode_indices: list[int] | None = None + num_workers: int = 4 + + @dataclass class EditDatasetConfig: repo_id: str - operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def save_episode_images_for_video( + dataset: LeRobotDataset, + imgs_dir: Path, + img_key: str, + episode_index: int, + num_workers: int = 4, +) -> None: + """Save images from a specific episode and camera to disk for video encoding. + + Args: + dataset: The LeRobot dataset to extract images from + imgs_dir: Directory to save images to + img_key: The image key (camera) to extract + episode_index: Index of the episode to save + num_workers: Number of threads for parallel image saving + """ + # Create directory + imgs_dir.mkdir(parents=True, exist_ok=True) + + # Get dataset without torch format for PIL image access + hf_dataset = dataset.hf_dataset.with_format(None) + + # Select only this camera's images + imgs_dataset = hf_dataset.select_columns(img_key) + + # Get episode start and end indices + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] + + # Get all items for this episode + episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) + + # Define function to save a single image + def save_single_image(i_item_tuple): + i, item = i_item_tuple + img = item[img_key] + # Use frame-XXXXXX.png format to match encode_video_frames expectations + img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100) + return i + + # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) + items = list(enumerate(episode_dataset)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(save_single_image, item) for item in items] + for future in as_completed(futures): + future.result() # This will raise any exceptions that occurred + + +def encode_episode_videos( + dataset: LeRobotDataset, + new_meta: LeRobotDatasetMetadata, + episode_index: int, + vcodec: str, + pix_fmt: str, + g: int, + crf: int, + fast_decode: int, + temp_dir: Path, + num_image_workers: int = 4, +) -> dict[str, dict]: + """Encode videos for a single episode and return video metadata. + + Args: + dataset: Source dataset with images + new_meta: Metadata object for the new video dataset + episode_index: Episode index to process + vcodec: Video codec + pix_fmt: Pixel format + g: Group of pictures size + crf: Constant rate factor + fast_decode: Fast decode tuning + temp_dir: Temporary directory for images + num_image_workers: Number of workers for saving images + + Returns: + Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps) + """ + hf_dataset = dataset.hf_dataset.with_format(None) + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] + + video_metadata = {} + fps = int(dataset.fps) # Convert to int for PyAV compatibility + episode_length = dataset.meta.episodes["length"][episode_index] + episode_duration = episode_length / dataset.fps # Use original fps for duration calculation + + for img_key in img_keys: + # Save images temporarily + imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key + save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers) + + # Determine chunk and file indices + # For simplicity, we'll put each episode in its own file + chunk_idx = episode_index // new_meta.chunks_size + file_idx = episode_index % new_meta.chunks_size + + # Create video path in the new dataset structure + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=chunk_idx, file_index=file_idx + ) + video_path.parent.mkdir(parents=True, exist_ok=True) + + # Encode video + encode_video_frames( + imgs_dir=imgs_dir, + video_path=video_path, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + overwrite=True, + ) + + # Clean up temporary images + shutil.rmtree(imgs_dir) + + # Store video metadata + video_metadata[img_key] = { + f"videos/{img_key}/chunk_index": chunk_idx, + f"videos/{img_key}/file_index": file_idx, + f"videos/{img_key}/from_timestamp": 0.0, + f"videos/{img_key}/to_timestamp": episode_duration, + } + + return video_metadata + + +def convert_dataset_to_videos( + dataset: LeRobotDataset, + output_dir: Path, + repo_id: str | None = None, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int = 2, + crf: int = 30, + fast_decode: int = 0, + episode_indices: list[int] | None = None, + num_workers: int = 4, +) -> LeRobotDataset: + """Convert image-based dataset to video-based dataset. + + Creates a new LeRobotDataset with videos instead of images, following the proper + LeRobot dataset structure with videos stored in chunked MP4 files. + + Args: + dataset: The source LeRobot dataset with images + output_dir: Directory to save the new video dataset + repo_id: Repository ID for the new dataset (default: original_id + "_video") + vcodec: Video codec (default: libsvtav1) + pix_fmt: Pixel format (default: yuv420p) + g: Group of pictures size (default: 2) + crf: Constant rate factor (default: 30) + fast_decode: Fast decode tuning (default: 0) + episode_indices: List of episode indices to convert (None = all episodes) + num_workers: Number of threads for parallel processing (default: 4) + + Returns: + New LeRobotDataset with videos + """ + # Check that it's an image dataset + if len(dataset.meta.video_keys) > 0: + raise ValueError( + f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}" + ) + + # Get all image keys + hf_dataset = dataset.hf_dataset.with_format(None) + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] + + if len(img_keys) == 0: + raise ValueError(f"No image keys found in dataset {dataset.repo_id}") + + # Determine which episodes to process + if episode_indices is None: + episode_indices = list(range(dataset.meta.total_episodes)) + + if repo_id is None: + repo_id = f"{dataset.repo_id}_video" + + logging.info( + f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" + ) + logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") + + # Create new features dict, converting image features to video features + new_features = {} + for key, value in dataset.meta.features.items(): + if key not in img_keys: + new_features[key] = value + else: + # Convert image key to video format + new_features[key] = value.copy() + new_features[key]["dtype"] = "video" # Change dtype from "image" to "video" + # Video info will be updated after episodes are encoded + + # Create new metadata for video dataset + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=True, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + # Create temporary directory for image extraction + temp_dir = output_dir / "temp_images" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Process each episode + all_episode_metadata = [] + + try: + for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"): + # Get episode metadata from source + src_episode = dataset.meta.episodes[ep_idx] + + # Encode videos for this episode + video_metadata = encode_episode_videos( + dataset=dataset, + new_meta=new_meta, + episode_index=ep_idx, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + temp_dir=temp_dir, + num_image_workers=num_workers, + ) + + # Build episode metadata + episode_meta = { + "episode_index": ep_idx, + "length": src_episode["length"], + "dataset_from_index": ep_idx * src_episode["length"], + "dataset_to_index": (ep_idx + 1) * src_episode["length"], + } + + # Add video metadata + for img_key in img_keys: + episode_meta.update(video_metadata[img_key]) + + # Add data chunk/file info (using same structure as source) + if "data/chunk_index" in src_episode: + episode_meta["data/chunk_index"] = src_episode["data/chunk_index"] + episode_meta["data/file_index"] = src_episode["data/file_index"] + + all_episode_metadata.append(episode_meta) + + # Copy and transform data files (removing image columns) + _copy_data_without_images(dataset, new_meta, episode_indices, img_keys) + + # Save episode metadata + episodes_df = pd.DataFrame(all_episode_metadata) + episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet" + episodes_path.parent.mkdir(parents=True, exist_ok=True) + episodes_df.to_parquet(episodes_path, index=False) + + # Update metadata info + new_meta.info["total_episodes"] = len(episode_indices) + new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata) + new_meta.info["total_tasks"] = dataset.meta.total_tasks + new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} + + # Update video info for all image keys (now videos) + # We need to manually set video info since update_video_info() checks video_keys first + for img_key in img_keys: + if not new_meta.features[img_key].get("info", None): + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=0, file_index=0 + ) + new_meta.info["features"][img_key]["info"] = get_video_info(video_path) + + from lerobot.datasets.utils import write_info + + write_info(new_meta.info, new_meta.root) + + # Copy stats and tasks + if dataset.meta.stats is not None: + # Remove image stats + new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys} + write_stats(new_stats, new_meta.root) + + if dataset.meta.tasks is not None: + write_tasks(dataset.meta.tasks, new_meta.root) + + finally: + # Clean up temporary directory + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + logging.info(f"βœ“ Completed converting {dataset.repo_id} to video format") + logging.info(f"New dataset saved to: {output_dir}") + + # Return new dataset + return LeRobotDataset(repo_id=repo_id, root=output_dir) + + +def _copy_data_without_images( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_indices: list[int], + img_keys: list[str], +) -> None: + """Copy data files without image columns. + + Args: + src_dataset: Source dataset + dst_meta: Destination metadata + episode_indices: Episodes to include + img_keys: Image keys to remove + """ + from lerobot.datasets.utils import DATA_DIR + + data_dir = src_dataset.root / DATA_DIR + parquet_files = sorted(data_dir.glob("*/*.parquet")) + + if not parquet_files: + raise ValueError(f"No parquet files found in {data_dir}") + + episode_set = set(episode_indices) + + for src_path in tqdm(parquet_files, desc="Processing data files"): + df = pd.read_parquet(src_path).reset_index(drop=True) + + # Filter to only include selected episodes + df = df[df["episode_index"].isin(episode_set)].copy() + + if len(df) == 0: + continue + + # Remove image columns + columns_to_drop = [col for col in img_keys if col in df.columns] + if columns_to_drop: + df = df.drop(columns=columns_to_drop) + + # Get chunk and file indices from path + relative_path = src_path.relative_to(src_dataset.root) + chunk_dir = relative_path.parts[1] + file_name = relative_path.parts[2] + chunk_idx = int(chunk_dir.split("-")[1]) + file_idx = int(file_name.split("-")[1].split(".")[0]) + + # Write to destination without pandas index + dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet" + dst_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(dst_path, index=False) + + +def handle_convert_to_video(cfg: EditDatasetConfig) -> None: + # Note: Parser may create any config type with the right fields, so we access fields directly + # instead of checking isinstance() + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + + # Determine output directory and repo_id + # Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name + output_dir_config = getattr(cfg.operation, "output_dir", None) + + if cfg.new_repo_id: + # Use new_repo_id for both local storage and hub push + output_repo_id = cfg.new_repo_id + output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id + logging.info(f"Saving to new dataset: {cfg.new_repo_id}") + elif output_dir_config: + # Use custom output directory for local-only storage + output_dir = Path(output_dir_config) + # Extract repo name from output_dir for the dataset + output_repo_id = output_dir.name + logging.info(f"Saving to local directory: {output_dir}") + else: + # Auto-generate name: append "_video" to original repo_id + output_repo_id = f"{cfg.repo_id}_video" + output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id + logging.info(f"Saving to auto-generated location: {output_dir}") + + logging.info(f"Converting dataset {cfg.repo_id} to video format") + + new_dataset = convert_dataset_to_videos( + dataset=dataset, + output_dir=output_dir, + repo_id=output_repo_id, + vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"), + pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"), + g=getattr(cfg.operation, "g", 2), + crf=getattr(cfg.operation, "crf", 30), + fast_decode=getattr(cfg.operation, "fast_decode", 0), + episode_indices=getattr(cfg.operation, "episode_indices", None), + num_workers=getattr(cfg.operation, "num_workers", 4), + ) + + logging.info("Video dataset created successfully!") + logging.info(f"Location: {output_dir}") + logging.info(f"Episodes: {new_dataset.meta.total_episodes}") + logging.info(f"Frames: {new_dataset.meta.total_frames}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}...") + new_dataset.push_to_hub() + logging.info("βœ“ Successfully pushed to hub!") + else: + logging.info("Dataset saved locally (not pushed to hub)") + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: operation_type = cfg.operation.type @@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "convert_to_video": + handle_convert_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature" + f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" ) diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 95fbd0646..f97c0d820 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -46,6 +46,7 @@ from lerobot.robots import ( # noqa: F401 RobotConfig, koch_follower, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) @@ -54,6 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401 gamepad, koch_leader, make_teleoperator_from_config, + omx_leader, so100_leader, so101_leader, ) diff --git a/src/lerobot/scripts/lerobot_info.py b/src/lerobot/scripts/lerobot_info.py index 9b49cad18..879d392be 100644 --- a/src/lerobot/scripts/lerobot_info.py +++ b/src/lerobot/scripts/lerobot_info.py @@ -27,6 +27,25 @@ lerobot-info import importlib import platform +import shutil +import subprocess +from importlib.metadata import PackageNotFoundError, distribution + +PACKAGE_NAME = "lerobot" + + +def get_ffmpeg_version() -> str: + """Get the ffmpeg version if installed, otherwise return 'N/A'.""" + command_path = shutil.which("ffmpeg") + if command_path is None: + return "N/A" + try: + result = subprocess.run([command_path, "-version"], capture_output=True, text=True, check=True) + first_line = result.stdout.splitlines()[0] + version_info = first_line.split(" ")[2] + return version_info + except (subprocess.SubprocessError, IndexError): + return "Installed (version parsing failed)" def get_package_version(package_name: str) -> str: @@ -38,16 +57,17 @@ def get_package_version(package_name: str) -> str: return "N/A" -def get_sys_info() -> dict: +def get_sys_info() -> dict[str, str]: """Run this to get basic system info to help for tracking issues & bugs.""" # General package versions info = { - "lerobot version": get_package_version("lerobot"), + "LeRobot version": get_package_version(PACKAGE_NAME), "Platform": platform.platform(), "Python version": platform.python_version(), "Huggingface Hub version": get_package_version("huggingface_hub"), "Datasets version": get_package_version("datasets"), "Numpy version": get_package_version("numpy"), + "FFmpeg version": get_ffmpeg_version(), } # PyTorch and GPU specific information @@ -58,10 +78,10 @@ def get_sys_info() -> dict: try: import torch - torch_version = torch.__version__ + torch_version = str(torch.__version__) torch_cuda_available = torch.cuda.is_available() if torch_cuda_available: - cuda_version = torch.version.cuda + cuda_version = str(torch.version.cuda) # Gets the name of the first available GPU gpu_model = torch.cuda.get_device_name(0) except ImportError: @@ -71,24 +91,34 @@ def get_sys_info() -> dict: info.update( { "PyTorch version": torch_version, - "Is PyTorch built with CUDA support?": torch_cuda_available, + "Is PyTorch built with CUDA support?": str(torch_cuda_available), "Cuda version": cuda_version, "GPU model": gpu_model, "Using GPU in script?": "", } ) + scripts = "N/A" + try: + dist = distribution(PACKAGE_NAME) + scripts = [ep.name for ep in dist.entry_points if ep.group == "console_scripts"] + except PackageNotFoundError: + pass + + info.update({f"{PACKAGE_NAME} scripts": str(scripts)}) return info -def format_dict_for_markdown(d: dict) -> str: +def format_dict_for_markdown(d: dict[str, str]) -> str: """Formats a dictionary into a markdown-friendly bulleted list.""" return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) def main(): + """ + Main function to print system info in markdown format. + """ system_info = get_sys_info() - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") print(format_dict_for_markdown(system_info)) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 255e9681c..948e92bb8 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -97,6 +97,7 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) @@ -107,6 +108,7 @@ from lerobot.teleoperators import ( # noqa: F401 homunculus, koch_leader, make_teleoperator_from_config, + omx_leader, so100_leader, so101_leader, ) @@ -270,7 +272,12 @@ def record_loop( for t in teleop if isinstance( t, - (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader), + ( + so100_leader.SO100Leader + | so101_leader.SO101Leader + | koch_leader.KochLeader + | omx_leader.OmxLeader + ), ) ), None, @@ -397,82 +404,63 @@ def record(cfg: RecordConfig) -> LeRobotDataset: ), ) - if cfg.resume: - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) + dataset = None + listener = None - if hasattr(robot, "cameras") and len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), - ) - sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) - else: - # Create empty dataset or load existing saved episodes - sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy) - dataset = LeRobotDataset.create( - cfg.dataset.repo_id, - cfg.dataset.fps, - root=cfg.dataset.root, - robot_type=robot.name, - features=dataset_features, - use_videos=cfg.dataset.video, - image_writer_processes=cfg.dataset.num_image_writer_processes, - image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) - - # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - preprocessor = None - postprocessor = None - if cfg.policy is not None: - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.policy.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - - robot.connect() - if teleop is not None: - teleop.connect() - - listener, events = init_keyboard_listener() - - with VideoEncodingManager(dataset): - recorded_episodes = 0 - while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) - record_loop( - robot=robot, - events=events, - fps=cfg.dataset.fps, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - control_time_s=cfg.dataset.episode_time_s, - single_task=cfg.dataset.single_task, - display_data=cfg.display_data, + try: + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) - # Execute a few seconds without recording to give time to manually reset the environment - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) + if hasattr(robot, "cameras") and len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) + else: + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + preprocessor = None + postprocessor = None + if cfg.policy is not None: + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.policy.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) + + robot.connect() + if teleop is not None: + teleop.connect() + + listener, events = init_keyboard_listener() + + with VideoEncodingManager(dataset): + recorded_episodes = 0 + while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) record_loop( robot=robot, events=events, @@ -481,34 +469,61 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, teleop=teleop, - control_time_s=cfg.dataset.reset_time_s, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, ) - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Execute a few seconds without recording to give time to manually reset the environment + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + fps=cfg.dataset.fps, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + teleop=teleop, + control_time_s=cfg.dataset.reset_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - log_say("Stop recording", cfg.play_sounds, blocking=True) + dataset.save_episode() + recorded_episodes += 1 + finally: + log_say("Stop recording", cfg.play_sounds, blocking=True) - robot.disconnect() - if teleop is not None: - teleop.disconnect() + if dataset: + dataset.finalize() - if not is_headless() and listener is not None: - listener.stop() + if robot.is_connected: + robot.disconnect() + if teleop and teleop.is_connected: + teleop.disconnect() - if cfg.dataset.push_to_hub: - dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + if not is_headless() and listener: + listener.stop() - log_say("Exiting", cfg.play_sounds) + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + log_say("Exiting", cfg.play_sounds) return dataset diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 52cb1d73c..d5808c768 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -58,6 +58,7 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) diff --git a/src/lerobot/scripts/lerobot_setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py index c1d256c21..b721e55ca 100644 --- a/src/lerobot/scripts/lerobot_setup_motors.py +++ b/src/lerobot/scripts/lerobot_setup_motors.py @@ -33,6 +33,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, lekiwi, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) @@ -40,6 +41,7 @@ from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, koch_leader, make_teleoperator_from_config, + omx_leader, so100_leader, so101_leader, ) @@ -47,6 +49,8 @@ from lerobot.teleoperators import ( # noqa: F401 COMPATIBLE_DEVICES = [ "koch_follower", "koch_leader", + "omx_follower", + "omx_leader", "so100_follower", "so100_leader", "so101_follower", diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index d83766b67..bf722d6f1 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -75,6 +75,7 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, + omx_follower, so100_follower, so101_follower, ) @@ -87,6 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401 keyboard, koch_leader, make_teleoperator_from_config, + omx_leader, so100_leader, so101_leader, ) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 1ebdee600..6cf733442 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -62,6 +62,7 @@ def update_policy( accelerator: Accelerator, lr_scheduler=None, lock=None, + rabc_weights_provider=None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. @@ -78,6 +79,7 @@ def update_policy( accelerator: The Accelerator instance for distributed training and mixed precision. lr_scheduler: An optional learning rate scheduler. lock: An optional lock for thread-safe optimizer updates. + rabc_weights_provider: Optional RABCWeights instance for sample weighting. Returns: A tuple containing: @@ -87,9 +89,30 @@ def update_policy( start_time = time.perf_counter() policy.train() + # Get RA-BC weights if enabled + rabc_batch_weights = None + rabc_batch_stats = None + if rabc_weights_provider is not None: + rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) + # Let accelerator handle mixed precision with accelerator.autocast(): - loss, output_dict = policy.forward(batch) + # Use per-sample loss when RA-BC is enabled for proper weighting + if rabc_batch_weights is not None: + # Get per-sample losses + per_sample_loss, output_dict = policy.forward(batch, reduction="none") + + # Apply RA-BC weights: L_RA-BC = Ξ£(w_i * l_i) / (Ξ£w_i + Ξ΅) + # rabc_batch_weights is already normalized to sum to batch_size + epsilon = 1e-6 + loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon) + # Log raw mean weight (before normalization) - this is the meaningful metric + output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"] + output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"] + output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"] + else: + loss, output_dict = policy.forward(batch) + # TODO(rcadene): policy.unnormalize_outputs(out_dict) # Use accelerator's backward method @@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): cfg: A `TrainPipelineConfig` object containing all training configurations. accelerator: Optional Accelerator instance. If None, one will be created automatically. """ - cfg.validate() - # Create Accelerator if not provided # It will automatically detect if running in distributed mode or single-process mode # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes @@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # When using accelerate, only the main process should log to avoid duplicate outputs is_main_process = accelerator.is_main_process + cfg.validate() + # Only log on main process if is_main_process: logging.info(pformat(cfg.to_dict())) @@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats + # For SARM, always provide dataset_meta for progress normalization + if cfg.policy.type == "sarm": + processor_kwargs["dataset_meta"] = dataset.meta + if cfg.policy.pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, @@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) + # Load precomputed SARM progress for RA-BC if enabled + # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py + rabc_weights = None + if cfg.use_rabc: + from lerobot.utils.rabc import RABCWeights + + # Get chunk_size from policy config + chunk_size = getattr(policy.config, "chunk_size", None) + if chunk_size is None: + raise ValueError("Chunk size is not found in policy config") + + head_mode = getattr(cfg, "rabc_head_mode", "sparse") + logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}") + logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}") + rabc_weights = RABCWeights( + progress_path=cfg.rabc_progress_path, + chunk_size=chunk_size, + head_mode=head_mode, + kappa=getattr(cfg, "rabc_kappa", 0.01), + epsilon=getattr(cfg, "rabc_epsilon", 1e-6), + device=device, + ) + step = 0 # number of policy updates (forward + backward + optim) if cfg.resume: @@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ) if is_main_process: - logging.info("Start offline training on a fixed dataset") + logging.info( + f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" + ) for _ in range(step, cfg.steps): start_time = time.perf_counter() @@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): cfg.optimizer.grad_clip_norm, accelerator=accelerator, lr_scheduler=lr_scheduler, + rabc_weights_provider=rabc_weights, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we @@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): wandb_log_dict = train_tracker.to_dict() if output_dict: wandb_log_dict.update(output_dict) + # Log RA-BC statistics if enabled + if rabc_weights is not None: + rabc_stats = rabc_weights.get_stats() + wandb_log_dict.update( + { + "rabc_delta_mean": rabc_stats["delta_mean"], + "rabc_delta_std": rabc_stats["delta_std"], + "rabc_num_frames": rabc_stats["num_frames"], + } + ) wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() diff --git a/src/lerobot/teleoperators/omx_leader/__init__.py b/src/lerobot/teleoperators/omx_leader/__init__.py new file mode 100644 index 000000000..04d96d63e --- /dev/null +++ b/src/lerobot/teleoperators/omx_leader/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_omx_leader import OmxLeaderConfig +from .omx_leader import OmxLeader diff --git a/src/lerobot/teleoperators/omx_leader/config_omx_leader.py b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py new file mode 100644 index 000000000..3c0420ab2 --- /dev/null +++ b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("omx_leader") +@dataclass +class OmxLeaderConfig(TeleoperatorConfig): + # Port to connect to the arm + port: str + + # Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze + # the gripper and have it spring back to an open position on its own. + gripper_open_pos: float = 37.0 diff --git a/src/lerobot/teleoperators/omx_leader/omx_leader.py b/src/lerobot/teleoperators/omx_leader/omx_leader.py new file mode 100644 index 000000000..c0e49b558 --- /dev/null +++ b/src/lerobot/teleoperators/omx_leader/omx_leader.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.dynamixel import ( + DriveMode, + DynamixelMotorsBus, + OperatingMode, +) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..teleoperator import Teleoperator +from .config_omx_leader import OmxLeaderConfig + +logger = logging.getLogger(__name__) + + +class OmxLeader(Teleoperator): + """ + - [OMX](https://github.com/ROBOTIS-GIT/open_manipulator), + expansion, developed by Woojin Wie and Junha Cha from [ROBOTIS](https://ai.robotis.com/) + """ + + config_class = OmxLeaderConfig + name = "omx_leader" + + def __init__(self, config: OmxLeaderConfig): + super().__init__(config) + self.config = config + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "xl330-m288", MotorNormMode.RANGE_M100_100), + "shoulder_lift": Motor(2, "xl330-m288", MotorNormMode.RANGE_M100_100), + "elbow_flex": Motor(3, "xl330-m288", MotorNormMode.RANGE_M100_100), + "wrist_flex": Motor(4, "xl330-m288", MotorNormMode.RANGE_M100_100), + "wrist_roll": Motor(5, "xl330-m288", MotorNormMode.RANGE_M100_100), + "gripper": Motor(6, "xl330-m077", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + @property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + self.bus.disable_torque() + logger.info(f"\nUsing factory default calibration values for {self}") + logger.info(f"\nWriting default configuration of {self} to the motors") + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + for motor in self.bus.motors: + if motor == "gripper": + self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value) + else: + self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value) + drive_modes = {motor: 1 if motor == "gripper" else 0 for motor in self.bus.motors} + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=drive_modes[motor], + homing_offset=0, + range_min=0, + range_max=4095, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + self.bus.disable_torque() + self.bus.configure_motors() + for motor in self.bus.motors: + if motor != "gripper": + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while + # assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial + # point + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + # Use 'position control current based' for gripper to be limited by the limit of the current. + # For the follower gripper, it means it can grasp an object without forcing too much even tho, + # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger + # to make it move, and it will move back to its original target position when we release the force. + self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + # Set gripper's goal pos in current position mode so that we can use it as a trigger. + self.bus.enable_torque("gripper") + if self.is_calibrated: + self.bus.write("Goal_Position", "gripper", self.config.gripper_open_pos) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start = time.perf_counter() + action = self.bus.sync_read("Present_Position") + action = {f"{motor}.pos": val for motor, val in action.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO(rcadene, aliberts): Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 2103a1669..699d1253f 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -41,6 +41,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .koch_leader import KochLeader return KochLeader(config) + elif config.type == "omx_leader": + from .omx_leader import OmxLeader + + return OmxLeader(config) elif config.type == "so100_leader": from .so100_leader import SO100Leader diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py new file mode 100644 index 000000000..c529f3ccc --- /dev/null +++ b/src/lerobot/utils/rabc.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import torch + + +class RABCWeights: + """ + Load precomputed SARM progress values and compute RA-BC weights during training. + + Progress values are loaded from a parquet file (generated by compute_rabc_weights.py). + During training, computes: + - progress_delta = progress[t + chunk_size] - progress[t] + - rabc_weight based on the delta (paper Eq. 8-9) + + Args: + progress_path: Path to parquet file with precomputed progress values + chunk_size: Number of frames ahead for computing progress delta + head_mode: Which SARM head to use ("sparse" or "dense") + kappa: Hard threshold for high-quality samples (default: 0.01) + epsilon: Small constant for numerical stability (default: 1e-6) + fallback_weight: Weight to use for frames without valid delta (default: 1.0) + device: Device to return tensors on + """ + + def __init__( + self, + progress_path: str | Path, + chunk_size: int = 50, + head_mode: str = "sparse", + kappa: float = 0.01, + epsilon: float = 1e-6, + fallback_weight: float = 1.0, + device: torch.device = None, + ): + self.progress_path = Path(progress_path) + self.chunk_size = chunk_size + self.head_mode = head_mode + self.kappa = kappa + self.epsilon = epsilon + self.fallback_weight = fallback_weight + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Determine progress column name + self.progress_column = f"progress_{head_mode}" + + # Load progress values + logging.info(f"Loading SARM progress values from {self.progress_path}") + self.df = pd.read_parquet(self.progress_path) + + # Check if the requested head mode column exists + if self.progress_column not in self.df.columns: + available = [c for c in self.df.columns if c.startswith("progress")] + raise ValueError( + f"Column '{self.progress_column}' not found. Available progress columns: {available}" + ) + + logging.info(f"Using progress column: {self.progress_column}") + + self.progress_lookup = {} + self.episode_lookup = {} + + for _, row in self.df.iterrows(): + global_idx = int(row["index"]) + progress = row[self.progress_column] + episode_idx = int(row["episode_index"]) + + if not np.isnan(progress): + self.progress_lookup[global_idx] = float(progress) + self.episode_lookup[global_idx] = episode_idx + + # Build episode boundaries for delta computation + self.episode_boundaries = {} + for episode_idx in self.df["episode_index"].unique(): + ep_df = self.df[self.df["episode_index"] == episode_idx] + self.episode_boundaries[int(episode_idx)] = { + "start": int(ep_df["index"].min()), + "end": int(ep_df["index"].max()) + 1, + } + + logging.info(f"Loaded {len(self.progress_lookup)} frame progress values") + logging.info(f"Chunk size for delta computation: {chunk_size}") + + # Compute global statistics for weight computation + self._compute_global_stats() + + def _compute_global_stats(self): + """Compute global mean and std of progress deltas for weight calculation.""" + all_deltas = [] + + for global_idx, progress in self.progress_lookup.items(): + episode_idx = self.episode_lookup.get(global_idx) + if episode_idx is None: + continue + + bounds = self.episode_boundaries.get(episode_idx) + if bounds is None: + continue + + future_idx = global_idx + self.chunk_size + if future_idx >= bounds["end"]: + # Near end of episode: use last frame's progress + future_idx = bounds["end"] - 1 + + future_progress = self.progress_lookup.get(future_idx) + if future_progress is not None: + delta = future_progress - progress + all_deltas.append(delta) + + if all_deltas: + self.delta_mean = max(np.mean(all_deltas), 0.0) + self.delta_std = max(np.std(all_deltas), self.epsilon) + logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}") + else: + self.delta_mean = 0.0 + self.delta_std = self.epsilon + logging.warning("No valid progress deltas found, using default stats") + + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """ + Compute RA-BC weights for a batch. + + For each sample: + 1. Get progress at current frame + 2. Get progress at frame + chunk_size (within same episode) + 3. Compute delta = future_progress - current_progress + 4. Compute weight using paper Eq. 8-9 + + Args: + batch: Training batch containing "index" key with global frame indices + + Returns: + Tuple of: + - Weights tensor (batch_size,) normalized to sum to batch_size + - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight + """ + indices = batch.get("index") + if indices is None: + logging.warning("RA-BC: Batch missing 'index' key, using uniform weights") + batch_size = self._get_batch_size(batch) + return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0} + + # Convert to list of ints + if isinstance(indices, torch.Tensor): + indices = indices.cpu().numpy().tolist() + elif isinstance(indices, np.ndarray): + indices = indices.tolist() + + # Compute deltas and weights for each sample + deltas = [] + for idx in indices: + idx = int(idx) + delta = self._compute_delta(idx) + deltas.append(delta) + + deltas = np.array(deltas, dtype=np.float32) + + # Compute weights from deltas + weights = self._compute_weights(deltas) + + # Compute stats before normalization for logging + raw_mean_weight = float(np.nanmean(weights)) + num_zero_weight = int(np.sum(weights == 0)) + num_full_weight = int(np.sum(weights == 1.0)) + batch_stats = { + "raw_mean_weight": raw_mean_weight, + "num_zero_weight": num_zero_weight, + "num_full_weight": num_full_weight, + } + + weights = torch.tensor(weights, device=self.device, dtype=torch.float32) + + # Normalize to sum to batch_size + batch_size = len(weights) + weight_sum = weights.sum() + self.epsilon + weights = weights * batch_size / weight_sum + + return weights, batch_stats + + def _compute_delta(self, global_idx: int) -> float: + """Compute progress delta for a single frame.""" + current_progress = self.progress_lookup.get(global_idx) + if current_progress is None: + return np.nan + + episode_idx = self.episode_lookup.get(global_idx) + if episode_idx is None: + return np.nan + + bounds = self.episode_boundaries.get(episode_idx) + if bounds is None: + return np.nan + + future_idx = global_idx + self.chunk_size # Ξ” = chunk_size + if future_idx >= bounds["end"]: + # Near end of episode: use last frame's progress instead + future_idx = bounds["end"] - 1 + + future_progress = self.progress_lookup.get(future_idx) + if future_progress is None: + return np.nan + + return future_progress - current_progress + + def _compute_weights(self, deltas: np.ndarray) -> np.ndarray: + """ + Compute RA-BC weights from progress deltas. + + Following paper Eq. 8-9: + - Soft weight: ˜wi = clip((ri βˆ’ (Β΅ βˆ’ 2Οƒ)) / (4Οƒ + Ξ΅), 0, 1) + - Final weight: wi = 1{ri > ΞΊ} + 1{0 ≀ ri ≀ ΞΊ}˜wi + + Returns: + Array of weights + """ + valid_mask = ~np.isnan(deltas) + + # Compute soft weights using global statistics + lower_bound = self.delta_mean - 2 * self.delta_std + soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon) + soft_weights = np.clip(soft_weights, 0.0, 1.0) + + # Apply paper's Eq. 9 + weights = np.zeros_like(deltas, dtype=np.float32) + + # High quality: ri > kappa β†’ weight = 1 + high_quality_mask = deltas > self.kappa + weights[high_quality_mask] = 1.0 + + # Moderate quality: 0 <= ri <= kappa β†’ weight = soft_weight + moderate_mask = (deltas >= 0) & (deltas <= self.kappa) + weights[moderate_mask] = soft_weights[moderate_mask] + + # Negative progress: ri < 0 β†’ weight = 0 (already 0) + # Invalid (NaN): use fallback weight + weights[~valid_mask] = self.fallback_weight + + return weights + + def _get_batch_size(self, batch: dict) -> int: + """Determine batch size from batch.""" + for key in ["action", "index"]: + if key in batch: + val = batch[key] + if isinstance(val, (torch.Tensor, np.ndarray)): + return val.shape[0] + return 1 + + def get_stats(self) -> dict: + """Get statistics.""" + return { + "num_frames": len(self.progress_lookup), + "chunk_size": self.chunk_size, + "head_mode": self.head_mode, + "delta_mean": self.delta_mean, + "delta_std": self.delta_std, + "kappa": self.kappa, + } diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 8bc1dbf6b..3a4516fc8 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import ( remove_feature, split_dataset, ) +from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos @pytest.fixture @@ -1047,3 +1048,107 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved" assert new_file_indices == original_file_indices, "File indices should be preserved" assert "reward" in modified_dataset.meta.features + + +def test_convert_dataset_to_videos(tmp_path): + """Test converting lerobot/pusht_image dataset to video format.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + # Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed) + source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1]) + + output_dir = tmp_path / "pusht_video" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + # Verify source dataset has images, not videos + assert len(source_dataset.meta.video_keys) == 0 + assert "observation.image" in source_dataset.meta.features + + # Convert to video dataset (only first 2 episodes for speed) + video_dataset = convert_dataset_to_videos( + dataset=source_dataset, + output_dir=output_dir, + repo_id="lerobot/pusht_video", + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + episode_indices=[0, 1], + num_workers=2, + ) + + # Verify new dataset has videos + assert len(video_dataset.meta.video_keys) > 0 + assert "observation.image" in video_dataset.meta.video_keys + + # Verify correct number of episodes and frames (2 episodes) + assert video_dataset.meta.total_episodes == 2 + # Compare against the actual number of frames in the loaded episodes, not metadata total + assert len(video_dataset) == len(source_dataset) + + # Verify video files exist + for ep_idx in range(video_dataset.meta.total_episodes): + for video_key in video_dataset.meta.video_keys: + video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key) + assert video_path.exists(), f"Video file should exist: {video_path}" + + # Verify we can load the dataset and access it + assert len(video_dataset) == video_dataset.meta.total_frames + + # Test that we can actually get an item from the video dataset + item = video_dataset[0] + assert "observation.image" in item + assert "action" in item + + # Cleanup + import shutil + + if output_dir.exists(): + shutil.rmtree(output_dir) + + +def test_convert_dataset_to_videos_subset_episodes(tmp_path): + """Test converting only specific episodes from lerobot/pusht_image to video format.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + # Load the actual lerobot/pusht_image dataset (only first 3 episodes) + source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2]) + + output_dir = tmp_path / "pusht_video_subset" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + # Convert only episode 0 to video (subset of loaded episodes) + episode_indices = [0] + + video_dataset = convert_dataset_to_videos( + dataset=source_dataset, + output_dir=output_dir, + repo_id="lerobot/pusht_video_subset", + episode_indices=episode_indices, + num_workers=2, + ) + + # Verify correct number of episodes + assert video_dataset.meta.total_episodes == len(episode_indices) + + # Verify video files exist for selected episodes + assert len(video_dataset.meta.video_keys) > 0 + assert "observation.image" in video_dataset.meta.video_keys + + # Cleanup + import shutil + + if output_dir.exists(): + shutil.rmtree(output_dir) diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py new file mode 100644 index 000000000..66404f663 --- /dev/null +++ b/tests/policies/test_sarm_processor.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +pytest.importorskip("faker") + +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest +import torch + +from lerobot.processor.core import TransitionKey + + +class MockDatasetMeta: + """Mock dataset metadata for testing processor.""" + + def __init__(self, episodes: list[dict]): + self._episodes = episodes + + @property + def episodes(self): + """Return episodes as a mock object with to_pandas() method.""" + mock = MagicMock() + mock.__len__ = lambda s: len(self._episodes) + mock.__getitem__ = lambda s, idx: self._episodes[idx] + mock.to_pandas = lambda: pd.DataFrame(self._episodes) + return mock + + +class MockConfig: + """Mock SARMConfig for testing processor methods.""" + + def __init__( + self, + n_obs_steps: int = 8, + max_rewind_steps: int = 4, + frame_gap: int = 30, + sparse_subtask_names: list = None, + sparse_temporal_proportions: list = None, + dense_subtask_names: list = None, + dense_temporal_proportions: list = None, + image_key: str = "observation.images.top", + state_key: str = "observation.state", + max_state_dim: int = 32, + device: str = None, + rewind_probability: float = 0.8, + language_perturbation_probability: float = 0.2, + annotation_mode: str = "dual", + clip_batch_size: int = 64, + text_dim: int = 512, + ): + self.n_obs_steps = n_obs_steps + self.max_rewind_steps = max_rewind_steps + self.frame_gap = frame_gap + self.sparse_subtask_names = sparse_subtask_names or ["task"] + self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0] + self.dense_subtask_names = dense_subtask_names + self.dense_temporal_proportions = dense_temporal_proportions + self.uses_dual_heads = annotation_mode in ["dense_only", "dual"] + self.image_key = image_key + self.state_key = state_key + self.max_state_dim = max_state_dim + self.device = device + self.rewind_probability = rewind_probability + self.language_perturbation_probability = language_perturbation_probability + self.annotation_mode = annotation_mode + self.clip_batch_size = clip_batch_size + self.text_dim = text_dim + + # Compute observation delta indices (same as config: bidirectional) + half_steps = self.n_obs_steps // 2 + past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)] + future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)] + obs_deltas = past_deltas + [0] + future_deltas + rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)] + self.observation_delta_indices = obs_deltas + rewind_deltas + + @property + def num_frames(self) -> int: + return 1 + self.n_obs_steps + self.max_rewind_steps + + +class TestSARMEncodingProcessorStepEndToEnd: + """End-to-end test for SARMEncodingProcessorStep with dummy batch data.""" + + @pytest.fixture + def mock_clip_model(self): + """Mock CLIP model to avoid loading real weights.""" + with ( + patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls, + patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls, + ): + # Mock the CLIP model - return embeddings based on input batch size + mock_model = MagicMock() + + def get_image_features_side_effect(**kwargs): + pixel_values = kwargs.get("pixel_values") + batch_size = pixel_values.shape[0] if pixel_values is not None else 1 + return torch.randn(batch_size, 512) + + mock_model.get_image_features.side_effect = get_image_features_side_effect + mock_model.get_text_features.return_value = torch.randn(1, 512) + mock_model.to.return_value = mock_model + mock_model_cls.from_pretrained.return_value = mock_model + + # Mock the CLIP processor - return tensors based on input images + mock_processor = MagicMock() + + def processor_side_effect(images=None, **kwargs): + num_images = len(images) if images is not None else 1 + return { + "pixel_values": torch.randn(num_images, 3, 224, 224), + } + + mock_processor.side_effect = processor_side_effect + # Mock tokenizer for text encoding + mock_processor.tokenizer.return_value = { + "input_ids": torch.ones(1, 77, dtype=torch.long), + "attention_mask": torch.ones(1, 77, dtype=torch.long), + } + mock_processor_cls.from_pretrained.return_value = mock_processor + + yield mock_model, mock_processor + + @pytest.fixture + def processor_with_mocks(self, mock_clip_model): + """Create a processor with mocked CLIP and dataset metadata for dual mode.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + # Dual mode config with both sparse and dense annotations + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=30, + rewind_probability=0.0, # Disable for deterministic test + language_perturbation_probability=0.0, # Disable for deterministic test + annotation_mode="dual", + sparse_subtask_names=["reach", "grasp", "lift"], + sparse_temporal_proportions=[0.3, 0.4, 0.3], + dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"], + dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25], + ) + + # Create mock dataset metadata with one episode of 300 frames + # Include annotation columns for dual mode + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 300, + "task": "pick up the cube", + "sparse_subtask_names": ["reach", "grasp", "lift"], + "sparse_subtask_start_frames": [0, 90, 210], + "sparse_subtask_end_frames": [90, 210, 300], + "dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"], + "dense_subtask_start_frames": [0, 75, 150, 225], + "dense_subtask_end_frames": [75, 150, 225, 300], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep( + config=config, + dataset_meta=dataset_meta, + ) + processor.train(True) # Use train() method, not direct assignment + + return processor, config + + def test_call_with_single_frame_batch(self, processor_with_mocks): + """Test processor __call__ with a single-frame batch.""" + processor, config = processor_with_mocks + + # Create dummy input transition + batch_size = 1 + num_frames = config.num_frames # 13 frames (9 obs + 4 rewind) + + # Image: (T, C, H, W) format as expected by processor + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + + # State: (T, D) format + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 150, # Middle of episode + "episode_index": 0, + "task": "pick up the cube", + }, + } + + # Run processor + result = processor(transition) + + # Verify output structure + obs = result[TransitionKey.OBSERVATION] + + # Check video features exist and have correct shape + assert "video_features" in obs + video_features = obs["video_features"] + assert video_features.shape[0] == batch_size + assert video_features.shape[1] == num_frames + assert video_features.shape[2] == 512 # CLIP embedding dim + + # Check state features exist and have correct shape + assert "state_features" in obs + state_features = obs["state_features"] + assert state_features.shape[0] == batch_size + assert state_features.shape[1] == num_frames + assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim + + # Check text features exist and have correct shape + assert "text_features" in obs + text_features = obs["text_features"] + assert text_features.shape[0] == batch_size + assert text_features.shape[1] == 512 # CLIP embedding dim + + # Check lengths tensor + assert "lengths" in obs + lengths = obs["lengths"] + assert lengths.shape[0] == batch_size + assert lengths.dtype == torch.int32 + + # Check sparse_targets exist + assert "sparse_targets" in obs + sparse_targets = obs["sparse_targets"] + assert sparse_targets.shape == (batch_size, num_frames) + # All targets should be in [0, max_stages] range (stage.tau format) + assert (sparse_targets >= 0).all() + + # Check dense_targets exist (for dual mode) + assert "dense_targets" in obs + dense_targets = obs["dense_targets"] + assert dense_targets.shape == (batch_size, num_frames) + assert (dense_targets >= 0).all() + + def test_call_with_batched_input(self, mock_clip_model): + """Test processor __call__ with a batched input (multiple frames) in dual mode.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=30, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["reach", "grasp"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["step1", "step2", "step3"], + dense_temporal_proportions=[0.33, 0.34, 0.33], + ) + + # Two episodes with different lengths, each with sparse+dense annotations + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 200, + "task": "task A", + "sparse_subtask_names": ["reach", "grasp"], + "sparse_subtask_start_frames": [0, 100], + "sparse_subtask_end_frames": [100, 200], + "dense_subtask_names": ["step1", "step2", "step3"], + "dense_subtask_start_frames": [0, 66, 133], + "dense_subtask_end_frames": [66, 133, 200], + }, + { + "dataset_from_index": 200, + "dataset_to_index": 500, + "task": "task B", + "sparse_subtask_names": ["reach", "grasp"], + "sparse_subtask_start_frames": [200, 350], + "sparse_subtask_end_frames": [350, 500], + "dense_subtask_names": ["step1", "step2", "step3"], + "dense_subtask_start_frames": [200, 300, 400], + "dense_subtask_end_frames": [300, 400, 500], + }, + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + batch_size = 2 + num_frames = config.num_frames + + # Image: (B, T, C, H, W) format + dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": np.array([100, 350]), # One frame from each episode + "episode_index": np.array([0, 1]), + "task": ["task A", "task B"], + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + + # Verify batch dimension is preserved for all outputs + assert obs["video_features"].shape[0] == batch_size + assert obs["state_features"].shape[0] == batch_size + assert obs["lengths"].shape[0] == batch_size + assert obs["sparse_targets"].shape[0] == batch_size + assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets + + def test_targets_increase_with_progress(self, mock_clip_model): + """Test that both sparse and dense targets increase as frame index progresses.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=30, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["phase1", "phase2"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["a", "b", "c", "d"], + dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25], + ) + + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 300, + "task": "test task", + "sparse_subtask_names": ["phase1", "phase2"], + "sparse_subtask_start_frames": [0, 150], + "sparse_subtask_end_frames": [150, 300], + "dense_subtask_names": ["a", "b", "c", "d"], + "dense_subtask_start_frames": [0, 75, 150, 225], + "dense_subtask_end_frames": [75, 150, 225, 300], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at early, middle, and late points in episode + frame_indices = [30, 150, 270] + sparse_center_targets = [] + dense_center_targets = [] + + for frame_idx in frame_indices: + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": frame_idx, + "episode_index": 0, + "task": "test task", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + # Get target at center frame (index 4 in 9-frame observation window) + sparse_center_targets.append(obs["sparse_targets"][0, 4].item()) + dense_center_targets.append(obs["dense_targets"][0, 4].item()) + + # Both sparse and dense targets should increase with frame index + assert sparse_center_targets[0] < sparse_center_targets[2], ( + f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})" + ) + assert dense_center_targets[0] < dense_center_targets[2], ( + f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})" + ) + + def test_progress_labels_exact_values(self, mock_clip_model): + """Test that progress labels (stage.tau) are computed correctly for known positions.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, # Smaller gap for easier calculation + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["A", "B"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["d1", "d2", "d3", "d4"], + dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25], + ) + + # Episode: frames 0-99, sparse stages at [0-49], [50-99] + # Dense stages at [0-24], [25-49], [50-74], [75-99] + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 100, + "task": "test", + "sparse_subtask_names": ["A", "B"], + "sparse_subtask_start_frames": [0, 50], + "sparse_subtask_end_frames": [50, 100], + "dense_subtask_names": ["d1", "d2", "d3", "d4"], + "dense_subtask_start_frames": [0, 25, 50, 75], + "dense_subtask_end_frames": [25, 50, 75, 100], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at frame 50 (center of episode) + # With frame_gap=10, n_obs_steps=8: + # obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 50, + "episode_index": 0, + "task": "test", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + sparse_targets = obs["sparse_targets"][0] # (13,) + dense_targets = obs["dense_targets"][0] # (13,) + + # First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind) + # Check that obs frames have non-zero targets + obs_sparse = sparse_targets[:9] + obs_dense = dense_targets[:9] + + # Verify targets are monotonically increasing for observation frames + for i in range(1, 9): + assert obs_sparse[i] >= obs_sparse[i - 1], ( + f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}" + ) + assert obs_dense[i] >= obs_dense[i - 1], ( + f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}" + ) + + # Rewind slots should be zero when rewind is disabled + rewind_targets = sparse_targets[9:] + assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled" + + # Check stage transitions: frame 50 is at boundary of sparse stage A->B + # Center frame (index 4) corresponds to actual frame 50 + center_sparse = obs_sparse[4].item() + # At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0) + assert 0.9 <= center_sparse <= 1.1, ( + f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}" + ) + + def test_rewind_augmentation_applied(self, mock_clip_model): + """Test that rewind augmentation correctly extends sequence and generates targets.""" + import random + + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, + rewind_probability=1.0, # Always apply rewind + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["A", "B"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["d1", "d2"], + dense_temporal_proportions=[0.5, 0.5], + ) + + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 200, + "task": "test", + "sparse_subtask_names": ["A", "B"], + "sparse_subtask_start_frames": [0, 100], + "sparse_subtask_end_frames": [100, 200], + "dense_subtask_names": ["d1", "d2"], + "dense_subtask_start_frames": [0, 100], + "dense_subtask_end_frames": [100, 200], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames # 13 + + # Test at frame 150 (center of bidirectional window) + # With n_obs_steps=8, half_steps=4, frame_gap=10: + # - Earliest obs frame = 150 - 4*10 = 110 + # - Rewind can go back from 110 to frames like 100, 90, 80, 70 + # - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 150, + "episode_index": 0, + "task": "test", + }, + } + + # Seed random for reproducibility + random.seed(42) + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + + lengths = obs["lengths"][0].item() + sparse_targets = obs["sparse_targets"][0] + + # With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind) + assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}" + assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}" + + # Rewind targets should be non-zero for frames within valid length + n_obs_frames = 9 + rewind_count = lengths - n_obs_frames + + if rewind_count > 0: + # Check that rewind frames have targets + rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count] + # Rewind frames are from BEFORE the earliest obs frame (110) + # These frames (100, 90, 80, 70) are earlier in the episode + earliest_obs_target = sparse_targets[0].item() # Frame 110 + + # Rewind targets should be less than earliest obs (they're from earlier frames) + for i, rt in enumerate(rewind_targets): + assert rt.item() < earliest_obs_target, ( + f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})" + ) + + # Rewind targets should be decreasing (going further back in time) + for i in range(1, len(rewind_targets)): + assert rewind_targets[i] <= rewind_targets[i - 1], ( + f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}" + ) + + def test_full_sequence_target_consistency(self, mock_clip_model): + """Test that the full sequence of targets is consistent with frame positions.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.policies.sarm.sarm_utils import find_stage_and_tau + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["s1", "s2", "s3"], + sparse_temporal_proportions=[0.33, 0.34, 0.33], + dense_subtask_names=["d1", "d2"], + dense_temporal_proportions=[0.5, 0.5], + ) + + # 3 sparse stages: [0-33), [33-66), [66-99] + # 2 dense stages: [0-50), [50-100) + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 100, + "task": "test", + "sparse_subtask_names": ["s1", "s2", "s3"], + "sparse_subtask_start_frames": [0, 33, 66], + "sparse_subtask_end_frames": [33, 66, 100], + "dense_subtask_names": ["d1", "d2"], + "dense_subtask_start_frames": [0, 50], + "dense_subtask_end_frames": [50, 100], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at frame 50 (middle of episode) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 50, + "episode_index": 0, + "task": "test", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + sparse_targets = obs["sparse_targets"][0] + dense_targets = obs["dense_targets"][0] + + # Manually compute expected targets for observation frames + # With frame_gap=10, n_obs_steps=8, center at 50: + # obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90] + expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90] + + sparse_names = ["s1", "s2", "s3"] + sparse_starts = [0, 33, 66] + sparse_ends = [33, 66, 100] + sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33} + + dense_names = ["d1", "d2"] + dense_starts = [0, 50] + dense_ends = [50, 100] + dense_props = {"d1": 0.5, "d2": 0.5} + + for i, frame in enumerate(expected_obs_frames): + expected_sparse = find_stage_and_tau( + frame, + 100, + sparse_names, + sparse_starts, + sparse_ends, + sparse_names, + sparse_props, + return_combined=True, + ) + expected_dense = find_stage_and_tau( + frame, + 100, + dense_names, + dense_starts, + dense_ends, + dense_names, + dense_props, + return_combined=True, + ) + + actual_sparse = sparse_targets[i].item() + actual_dense = dense_targets[i].item() + + assert abs(actual_sparse - expected_sparse) < 0.01, ( + f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}" + ) + assert abs(actual_dense - expected_dense) < 0.01, ( + f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}" + ) diff --git a/tests/policies/test_sarm_subtask_annotations.py b/tests/policies/test_sarm_subtask_annotations.py new file mode 100644 index 000000000..0dc087288 --- /dev/null +++ b/tests/policies/test_sarm_subtask_annotations.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +pytest.importorskip("transformers") + +from lerobot.data_processing.sarm_annotations.subtask_annotation import ( + Subtask, + SubtaskAnnotation, + Timestamp, + compute_temporal_proportions, +) + + +def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation: + """Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec).""" + return SubtaskAnnotation( + subtasks=[ + Subtask( + name=name, + timestamps=Timestamp( + start=f"{start // 60:02d}:{start % 60:02d}", end=f"{end // 60:02d}:{end % 60:02d}" + ), + ) + for name, start, end in subtasks + ] + ) + + +class TestComputeTemporalProportions: + """Tests for compute_temporal_proportions (SARM Paper Formula 1). + + Formula: αΎ±_k = (1/M) Γ— Ξ£_i (L_{i,k} / T_i) + + Key insight: This averages the PROPORTION of each subtask within each trajectory, + giving equal weight to all trajectories regardless of absolute length. + """ + + def test_basic_two_trajectories_equal_proportions(self): + """Test with two trajectories that have equal proportions.""" + # Both trajectories: subtask1 = 50%, subtask2 = 50% + # Traj 1: T=100s, subtask1=50s, subtask2=50s + # Traj 2: T=200s, subtask1=100s, subtask2=100s + annotations = { + 0: make_annotation([("subtask1", 0, 50), ("subtask2", 50, 100)]), + 1: make_annotation([("subtask1", 0, 100), ("subtask2", 100, 200)]), + } + + result = compute_temporal_proportions(annotations) + + # Both should be 0.5 + assert abs(result["subtask1"] - 0.5) < 1e-6 + assert abs(result["subtask2"] - 0.5) < 1e-6 + + def test_paper_example_different_from_avg_durations(self): + """Test that compute_temporal_proportions differs from naive average duration approach. + + This is the key test showing the difference between: + - Paper formula: average of (L_i,k / T_i) + - Naive approach: mean(L_i,k) / sum(mean(L_i,j)) + """ + # Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2) + # Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8) + annotations = { + 0: make_annotation([("subtask1", 0, 80), ("subtask2", 80, 100)]), + 1: make_annotation([("subtask1", 0, 40), ("subtask2", 40, 200)]), + } + + result = compute_temporal_proportions(annotations) + + # Paper formula: + # αΎ±_1 = (1/2) Γ— (80/100 + 40/200) = (1/2) Γ— (0.8 + 0.2) = 0.5 + # αΎ±_2 = (1/2) Γ— (20/100 + 160/200) = (1/2) Γ— (0.2 + 0.8) = 0.5 + assert abs(result["subtask1"] - 0.5) < 1e-6 + assert abs(result["subtask2"] - 0.5) < 1e-6 + + def test_single_trajectory(self): + """Test with a single trajectory.""" + # T=100s, reach=30s, grasp=20s, lift=50s + annotations = { + 0: make_annotation([("reach", 0, 30), ("grasp", 30, 50), ("lift", 50, 100)]), + } + + result = compute_temporal_proportions(annotations) + + assert abs(result["reach"] - 0.3) < 1e-6 + assert abs(result["grasp"] - 0.2) < 1e-6 + assert abs(result["lift"] - 0.5) < 1e-6 + + def test_sum_to_one(self): + """Test that proportions always sum to 1.""" + # Three episodes with varying proportions + annotations = { + 0: make_annotation([("a", 0, 10), ("b", 10, 50), ("c", 50, 100)]), # 0.1, 0.4, 0.5 + 1: make_annotation([("a", 0, 20), ("b", 20, 70), ("c", 70, 100)]), # 0.2, 0.5, 0.3 + 2: make_annotation([("a", 0, 30), ("b", 30, 90), ("c", 90, 100)]), # 0.3, 0.6, 0.1 + } + + result = compute_temporal_proportions(annotations) + + total = sum(result.values()) + assert abs(total - 1.0) < 1e-6 + + def test_empty_annotations_returns_empty(self): + """Test that empty annotations returns empty dict.""" + result = compute_temporal_proportions({}) + assert result == {} + + def test_uniform_proportions(self): + """Test with uniform proportions across subtasks.""" + # Each subtask takes 25% of each episode + annotations = { + 0: make_annotation([("a", 0, 25), ("b", 25, 50), ("c", 50, 75), ("d", 75, 100)]), + 1: make_annotation([("a", 0, 50), ("b", 50, 100), ("c", 100, 150), ("d", 150, 200)]), + } + + result = compute_temporal_proportions(annotations) + + for name in ["a", "b", "c", "d"]: + assert abs(result[name] - 0.25) < 1e-6 diff --git a/tests/policies/test_sarm_utils.py b/tests/policies/test_sarm_utils.py new file mode 100644 index 000000000..510477ec8 --- /dev/null +++ b/tests/policies/test_sarm_utils.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from lerobot.policies.sarm.sarm_utils import ( + apply_rewind_augmentation, + compute_absolute_indices, + compute_tau, + find_stage_and_tau, + normalize_stage_tau, + temporal_proportions_to_breakpoints, +) + + +class TestProgressLabelsWithModes: + """End-to-end tests for progress label generation in different modes.""" + + def test_sparse_mode_single_stage(self): + """Sparse mode with single stage should give linear progress.""" + episode_length = 300 + global_names = ["task"] + proportions = {"task": 1.0} + + # Test at various frames + for frame in [0, 100, 200, 299]: + stage, tau = find_stage_and_tau( + frame, episode_length, None, None, None, global_names, proportions + ) + + expected_tau = frame / (episode_length - 1) + assert stage == 0 + assert abs(tau - expected_tau) < 1e-5 + + def test_sparse_mode_multi_stage(self): + """Sparse mode with multiple stages.""" + global_names = ["reach", "grasp", "lift", "place"] + proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3} + + subtask_names = ["reach", "grasp", "lift", "place"] + subtask_starts = [0, 60, 120, 210] + subtask_ends = [59, 119, 209, 299] + + # Check stages are correctly identified + stage_at_30, _ = find_stage_and_tau( + 30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_30 == 0 + + stage_at_90, _ = find_stage_and_tau( + 90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_90 == 1 + + stage_at_150, _ = find_stage_and_tau( + 150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_150 == 2 + + def test_dense_mode_more_stages(self): + """Dense mode should work with more fine-grained stages.""" + global_names = ["a", "b", "c", "d", "e", "f", "g", "h"] + proportions = dict.fromkeys(global_names, 1 / 8) + + subtask_names = global_names + subtask_starts = [i * 50 for i in range(8)] + subtask_ends = [(i + 1) * 50 - 1 for i in range(8)] + + # Each stage should occupy 50 frames + for stage_idx in range(8): + mid_frame = stage_idx * 50 + 25 + stage, _ = find_stage_and_tau( + mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == stage_idx + + +class TestComputeAbsoluteIndices: + """Tests for compute_absolute_indices (bidirectional sampling).""" + + def test_no_clamping_when_in_middle(self): + """When frame is in middle of episode, no clamping should occur.""" + frame_idx = 300 + ep_start = 0 + ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # All should be valid (no out of bounds) + assert out_of_bounds.sum() == 0 + + # Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center + half_steps = n_obs_steps // 2 + expected = ( + [frame_idx - frame_gap * i for i in range(half_steps, 0, -1)] + + [frame_idx] + + [frame_idx + frame_gap * i for i in range(1, half_steps + 1)] + ) + assert indices.tolist() == expected + + # Center frame (index 4) should be frame_idx + assert indices[half_steps] == frame_idx + + def test_clamping_at_episode_start(self): + """Early frames should be clamped to episode start.""" + frame_idx = 50 # Not enough history for full past window + ep_start = 0 + ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # Some past frames should be clamped (out_of_bounds = 1) + assert out_of_bounds.sum() > 0 + + # All indices should be >= ep_start + assert (indices >= ep_start).all() + + # Center index should be frame_idx + half_steps = n_obs_steps // 2 + assert indices[half_steps] == frame_idx + + def test_clamping_at_episode_end(self): + """Late frames should be clamped to episode end.""" + frame_idx = 950 # Not enough future for full window + ep_start = 0 + ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # Some future frames should be clamped + assert out_of_bounds.sum() > 0 + + # All indices should be < ep_end + assert (indices < ep_end).all() + + # Center index should be frame_idx + half_steps = n_obs_steps // 2 + assert indices[half_steps] == frame_idx + + def test_sequence_is_monotonic(self): + """Frame indices should be monotonically increasing.""" + for frame_idx in [50, 100, 300, 950]: + indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30) + + # Check monotonic (non-decreasing due to clamping) + diffs = indices[1:] - indices[:-1] + assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}" + + +class TestComputeTau: + """Tests for compute_tau (within-subtask progress). + + Formula: Ο„_t = (t - s_k) / (e_k - s_k) ∈ [0, 1] + """ + + def test_at_start(self): + """Ο„ should be 0 at subtask start.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def test_at_end(self): + """Ο„ should be 1 at subtask end.""" + tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_at_middle(self): + """Ο„ should be 0.5 at subtask midpoint.""" + tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50) + assert abs(tau - 0.5) < 1e-6 + + def test_quarter_progress(self): + """Test Ο„ at 25% through subtask.""" + tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80) + assert abs(tau - 0.25) < 1e-6 + + def test_zero_duration_subtask(self): + """Ο„ should be 1.0 for zero-duration subtask.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10) + assert tau == 1.0 + + def test_clamps_below_zero(self): + """Ο„ should be clamped to 0 if frame is before subtask.""" + tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def test_clamps_above_one(self): + """Ο„ should be clamped to 1 if frame is after subtask.""" + tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_float_inputs(self): + """Test with float frame indices (from interpolation).""" + tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0) + expected = (25.5 - 10.0) / (50.0 - 10.0) + assert abs(tau - expected) < 1e-6 + + +class TestFindStageAndTau: + """Tests for find_stage_and_tau logic. + + This function is the core of progress label computation. It determines + which stage a frame belongs to and the within-stage progress (tau). + """ + + def test_single_stage_mode_linear_progress(self): + """Single-stage mode should give linear progress from 0 to 1.""" + episode_length = 100 + + # Frame 0 -> tau = 0 + stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 0.0) < 1e-6 + + # Frame 50 -> tau = 0.505 (50/99) + stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 50 / 99) < 1e-6 + + # Frame 99 -> tau = 1.0 + stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 1.0) < 1e-6 + + def test_multi_stage_within_subtask(self): + """Test finding stage when frame is within a subtask.""" + global_names = ["reach", "grasp", "lift"] + proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5} + + subtask_names = ["reach", "grasp", "lift"] + subtask_starts = [0, 30, 50] + subtask_ends = [29, 49, 99] + + # Frame 15 in "reach" stage (index 0) + stage, tau = find_stage_and_tau( + 15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert abs(tau - 15 / 29) < 1e-6 + + # Frame 40 in "grasp" stage (index 1) + stage, tau = find_stage_and_tau( + 40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 1 + # tau = (40 - 30) / (49 - 30) = 10/19 + assert abs(tau - 10 / 19) < 1e-6 + + # Frame 75 in "lift" stage (index 2) + stage, tau = find_stage_and_tau( + 75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 2 + # tau = (75 - 50) / (99 - 50) = 25/49 + assert abs(tau - 25 / 49) < 1e-6 + + def test_frame_at_subtask_boundaries(self): + """Test frames exactly at subtask boundaries.""" + global_names = ["a", "b"] + proportions = {"a": 0.5, "b": 0.5} + + subtask_names = ["a", "b"] + subtask_starts = [0, 50] + subtask_ends = [49, 99] + + # Frame at start of first subtask + stage, tau = find_stage_and_tau( + 0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert tau == 0.0 + + # Frame at end of first subtask + stage, tau = find_stage_and_tau( + 49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert tau == 1.0 + + # Frame at start of second subtask + stage, tau = find_stage_and_tau( + 50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 1 + assert tau == 0.0 + + def test_frame_after_last_subtask(self): + """Frames after last subtask should return last stage with high tau.""" + global_names = ["a", "b"] + proportions = {"a": 0.5, "b": 0.5} + + subtask_names = ["a", "b"] + subtask_starts = [0, 30] + subtask_ends = [29, 59] + + # Frame 80 is after last subtask + stage, tau = find_stage_and_tau( + 80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 1 # Last stage + assert tau == 0.999 # Nearly complete + + +class TestEndToEndProgressLabeling: + """End-to-end tests for progress label computation using normalize_stage_tau.""" + + def test_consistent_semantic_meaning(self): + """Test that same subtask completion maps to same progress across trajectories. + + This is the key semantic property: "end of subtask 1" should always + mean the same progress value regardless of trajectory speed. + """ + proportions = [0.3, 0.5, 0.2] + + # Fast trajectory: subtask 1 ends at frame 30 (of 100) + tau_fast = compute_tau(30, 0, 30) # = 1.0 + y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions) + + # Slow trajectory: subtask 1 ends at frame 90 (of 300) + tau_slow = compute_tau(90, 0, 90) # = 1.0 + y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions) + + # Both should map to same progress (0.3 = end of subtask 1) + assert abs(y_fast - y_slow) < 1e-6 + assert abs(y_fast - 0.3) < 1e-6 + + def test_monotonic_within_subtask(self): + """Test that progress is monotonically increasing within a subtask.""" + proportions = [0.4, 0.6] + + prev_y = -1 + for tau in np.linspace(0, 1, 11): + y = normalize_stage_tau(0 + tau, temporal_proportions=proportions) + assert y > prev_y or (tau == 0 and y == 0) + prev_y = y + + def test_continuous_across_subtasks(self): + """Test that progress is continuous at subtask boundaries.""" + proportions = [0.3, 0.5, 0.2] + + # End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0 + y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions) + + # Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0 + y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions) + + # Should be equal (P_1 = 0.3) + assert abs(y_end_0 - y_start_1) < 1e-6 + + # End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0 + y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions) + + # Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0 + y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions) + + # Should be equal (P_2 = 0.8) + assert abs(y_end_1 - y_start_2) < 1e-6 + + +class TestTemporalProportionsToBreakpoints: + """Tests for temporal_proportions_to_breakpoints. + + Converts temporal proportions to cumulative breakpoints for normalization. + Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0] + """ + + def test_basic_conversion(self): + """Test basic conversion from proportions to breakpoints.""" + proportions = [0.3, 0.5, 0.2] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + assert breakpoints is not None + assert len(breakpoints) == 4 + assert breakpoints[0] == 0.0 + assert abs(breakpoints[1] - 0.3) < 1e-6 + assert abs(breakpoints[2] - 0.8) < 1e-6 + assert breakpoints[3] == 1.0 + + def test_dict_input(self): + """Test with dict input.""" + proportions = {"a": 0.25, "b": 0.25, "c": 0.5} + breakpoints = temporal_proportions_to_breakpoints(proportions) + + assert breakpoints is not None + assert len(breakpoints) == 4 + assert breakpoints[0] == 0.0 + assert breakpoints[-1] == 1.0 + + def test_dict_with_subtask_names_order(self): + """Test that subtask_names determines order for dict input.""" + proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order + subtask_names = ["a", "b", "c"] # Different order + + breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names) + + # Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5 + assert abs(breakpoints[1] - 0.2) < 1e-6 # a + assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5 + assert breakpoints[3] == 1.0 # a + b + c = 1.0 + + def test_uniform_proportions(self): + """Test with uniform proportions.""" + proportions = [0.25, 0.25, 0.25, 0.25] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + expected = [0.0, 0.25, 0.5, 0.75, 1.0] + for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)): + assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch" + + def test_none_input(self): + """Test that None input returns None.""" + result = temporal_proportions_to_breakpoints(None) + assert result is None + + def test_normalization(self): + """Test that non-normalized proportions are normalized.""" + # Proportions sum to 2.0, not 1.0 + proportions = [0.6, 1.0, 0.4] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + # Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0] + assert breakpoints[-1] == 1.0 + assert abs(breakpoints[1] - 0.3) < 1e-6 + + +class TestNormalizeStageTau: + """Tests for normalize_stage_tau. + + Normalizes stage+tau values to [0, 1] using breakpoints. + """ + + def test_linear_fallback(self): + """Test linear normalization when only num_stages is provided.""" + # 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0] + + # Stage 0 start + assert normalize_stage_tau(0.0, num_stages=4) == 0.0 + + # Stage 0 end / Stage 1 start + assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6 + + # Stage 1 middle + assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6 + + # Stage 3 end + assert normalize_stage_tau(4.0, num_stages=4) == 1.0 + + def test_with_custom_breakpoints(self): + """Test with custom breakpoints.""" + # Non-linear breakpoints + breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages + + # Stage 0: maps [0, 1) to [0.0, 0.1) + assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6 + + # Stage 1: maps [1, 2) to [0.1, 0.5) + assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6 + + # Stage 2: maps [2, 3) to [0.5, 1.0) + assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6 + + def test_with_temporal_proportions(self): + """Test with temporal proportions (auto-computed breakpoints).""" + proportions = {"a": 0.2, "b": 0.3, "c": 0.5} + subtask_names = ["a", "b", "c"] + + # Stage 0 end should map to 0.2 + result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names) + assert abs(result - 0.2) < 1e-6 + + # Stage 1 end should map to 0.5 + result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names) + assert abs(result - 0.5) < 1e-6 + + def test_tensor_input(self): + """Test with tensor input.""" + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0]) + breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages + + result = normalize_stage_tau(x, breakpoints=breakpoints) + + assert isinstance(result, torch.Tensor) + assert result.shape == x.shape + assert abs(result[0].item() - 0.0) < 1e-6 + assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0 + assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1 + + def test_clamping(self): + """Test that output is clamped to [0, 1].""" + # Below 0 + assert normalize_stage_tau(-0.5, num_stages=4) == 0.0 + + # Above num_stages + assert normalize_stage_tau(5.0, num_stages=4) == 1.0 + + def test_batch_tensor(self): + """Test with batched tensor.""" + x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3) + + result = normalize_stage_tau(x, num_stages=3) + + assert result.shape == (2, 3) + assert (result >= 0).all() + assert (result <= 1).all() + + def test_requires_one_of_inputs(self): + """Test that at least one input method is required.""" + with pytest.raises(ValueError): + normalize_stage_tau(1.0) + + +class TestRewindAugmentation: + """Tests for rewind augmentation logic with bidirectional observation sampling. + + Rewind appends frames before the earliest observation frame, going backwards. + With bidirectional sampling centered at frame_idx: + - Earliest obs frame = frame_idx - half_steps * frame_gap + - Rewind goes backwards from that point + """ + + def test_rewind_indices_go_backwards_from_earliest_obs(self): + """Rewind indices should go backwards from earliest observation frame.""" + frame_idx = 300 # Center of bidirectional window + ep_start = 0 + n_obs_steps = 4 # half_steps = 2 + frame_gap = 30 + + # Earliest obs frame = 300 - 2*30 = 240 + # Rewind goes backwards: 210, 180 + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps=n_obs_steps, + max_rewind_steps=2, + frame_gap=frame_gap, + rewind_step=2, + ) + + assert rewind_step == 2 + assert len(rewind_indices) == 2 + # First rewind frame is closest to obs window, second is further back + assert rewind_indices[0] == 210 # 240 - 30 + assert rewind_indices[1] == 180 # 240 - 60 + assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending" + + def test_rewind_goes_backward_through_history(self): + """Rewind frames should go backward before the observation window.""" + frame_idx = 450 # Center of bidirectional window + ep_start = 0 + n_obs_steps = 8 # half_steps = 4 + frame_gap = 30 + + # Earliest obs frame = 450 - 4*30 = 330 + # Rewind from 330: [300, 270, 240] + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps=n_obs_steps, + max_rewind_steps=4, + frame_gap=frame_gap, + rewind_step=3, + ) + + assert rewind_step == 3 + expected = [300, 270, 240] # Going backwards from 330 + assert rewind_indices == expected + + def test_no_rewind_when_obs_window_at_episode_start(self): + """No rewind when observation window reaches episode start.""" + frame_idx = 120 # Center of window + ep_start = 0 + n_obs_steps = 8 # half_steps = 4 + frame_gap = 30 + + # Earliest obs frame = 120 - 4*30 = 0 (at episode start) + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap + ) + + # No room for rewind + assert rewind_step == 0 + assert rewind_indices == [] + + def test_rewind_targets_are_decreasing(self): + """Progress targets for rewind frames should be decreasing.""" + # Simulate progress values + obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress + + # Rewind reverses progress + rewind_indices = [4, 3, 2] # Go backwards through indices + rewind_progress = [obs_progress[i] for i in rewind_indices] + + # Should be decreasing + for i in range(len(rewind_progress) - 1): + assert rewind_progress[i] > rewind_progress[i + 1] diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py new file mode 100644 index 000000000..837907041 --- /dev/null +++ b/tests/policies/wall_x/test_wallx.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!""" + +import os + +import pytest +import torch + +# Skip if openpi or transformers is not available +pytest.importorskip("peft") +pytest.importorskip("transformers==4.49.0") + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local Wall-X installation and is not meant for CI", +) + +from lerobot.policies.factory import make_policy_config # noqa: E402 +from lerobot.policies.wall_x import WallXConfig # noqa: E402 +from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402 +from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 + + +def test_policy_instantiation(): + # Create config + set_seed(42) + config = WallXConfig(device="cuda") + + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(7,), + ), + "observation.images.face_view": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "observation.images.face_view": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + } + + # Instantiate policy + policy = WallXPolicy(config) + preprocessor, postprocessor = make_wall_x_pre_post_processors(config=config, dataset_stats=dataset_stats) + # Test forward pass with dummy data + batch_size = 1 + device = config.device + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + loss, loss_dict = policy.forward(batch) + print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + + # Test inference + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + with torch.no_grad(): + action = policy.select_action(batch) + action = postprocessor(action) + print(f"Action: {action}") + print(f"Action prediction successful. Action shape: {action.shape}") + except Exception as e: + print(f"Action prediction failed: {e}") + raise + + +def test_config_creation(): + """Test policy config creation through factory.""" + try: + config = make_policy_config( + policy_type="wall_x", + ) + print("Config created successfully through factory") + print(f" Config type: {type(config).__name__}") + except Exception as e: + print(f"Config creation failed: {e}") + raise + + +if __name__ == "__main__": + test_policy_instantiation() + test_config_creation()