Compare commits
149 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3316301693 | |||
| feedababd2 | |||
| 480ee3299f | |||
| 2d1fb0f508 | |||
| b1a55b0666 | |||
| 24af996f82 | |||
| 8d7eec79c8 | |||
| ccced0c9fc | |||
| 4166eeb7da | |||
| 1f93a74d8c | |||
| b16e2f25f7 | |||
| 9cc841c674 | |||
| 63c28ea395 | |||
| 98c33a4748 | |||
| 4428248a01 | |||
| 7d6f113072 | |||
| 7ac05c838d | |||
| c85f1692d6 | |||
| 9fd329713a | |||
| 97d068e5a2 | |||
| e5bea36387 | |||
| cf1d8c3d5b | |||
| 464b65cfb0 | |||
| 90145426b4 | |||
| c76bc4cdea | |||
| 20f0381f81 | |||
| a447c652cb | |||
| 8277dbf0dc | |||
| eb0918249d | |||
| 640a7889fc | |||
| 03c6ee5f9a | |||
| dfd229ae4f | |||
| aba42c805f | |||
| 8b6b41f8dc | |||
| 1771da222b | |||
| 0514616c87 | |||
| f15872293d | |||
| a97255e3d1 | |||
| 1716d599c1 | |||
| c07ab7e1fa | |||
| 5ba9fbd9ca | |||
| 38b814f3d4 | |||
| 48a963793b | |||
| 9833b84bf8 | |||
| 27eeff7535 | |||
| 60efd875fa | |||
| 12043b3b5c | |||
| a06f4b9140 | |||
| 20c22a2799 | |||
| 2f238fce15 | |||
| ff271e8b51 | |||
| a142c365dd | |||
| b2ef6ae720 | |||
| a64f2fd322 | |||
| 17c5a0774f | |||
| 0071b1ff6e | |||
| 00b5f65752 | |||
| 2f6c870c4b | |||
| 0bd1969d0a | |||
| f04958527e | |||
| 4a151a9682 | |||
| 8667b9ef08 | |||
| 86eee5c1e2 | |||
| 469b855e42 | |||
| 202a493c14 | |||
| eadd4c0856 | |||
| 3434a5d5df | |||
| 1ba51a6d02 | |||
| c62ca6c5d2 | |||
| 4831195310 | |||
| c514d9ffe2 | |||
| 9ae4477356 | |||
| 0e545e5177 | |||
| a0c9a7d85d | |||
| 9ce6dd9e25 | |||
| 51bd288f1a | |||
| fc6262e23d | |||
| d2b16afb12 | |||
| a754c86f64 | |||
| 76e6dc1ba1 | |||
| d10d3ef251 | |||
| feebca050a | |||
| a8e7a2967c | |||
| 2cf509795e | |||
| d3846b0beb | |||
| 292333cafc | |||
| f0c98e23f1 | |||
| 7621af5acd | |||
| 08d2ed8015 | |||
| 4bcd14b8de | |||
| c34935090d | |||
| 9cfd56587e | |||
| ff8584a025 | |||
| 6bc1e5186a | |||
| 69dc8165ae | |||
| 021bca2ad9 | |||
| 4e0ee0d643 | |||
| 0a8aa85871 | |||
| 76ddd8b948 | |||
| bf08733068 | |||
| e38f56c071 | |||
| 19fe69dac0 | |||
| 14319ee608 | |||
| 9b04fd25b6 | |||
| 40e98ba690 | |||
| 894d65d58a | |||
| f58d508df2 | |||
| e22b909e7c | |||
| 09f1673cbf | |||
| 4744f99990 | |||
| 01c1735739 | |||
| 6808a42455 | |||
| fff719cb4f | |||
| e2c00f6ed8 | |||
| 0f90db23c5 | |||
| 96b192f2ae | |||
| ecdc34a699 | |||
| fa6a2fb9b7 | |||
| b011643dc9 | |||
| 30c10c1c6e | |||
| 56e2360072 | |||
| 92fdbe9bbf | |||
| b303d1ab38 | |||
| b1d162f333 | |||
| 2b304eeb84 | |||
| 4e6048a221 | |||
| 81ebcac8d7 | |||
| a6c3a0fa09 | |||
| c2fb644613 | |||
| 1d07a4aefd | |||
| ce348a3460 | |||
| cb920235c4 | |||
| 7f40b3bf82 | |||
| 2e9c9fd832 | |||
| f9cb5e659c | |||
| 0217e1e3ad | |||
| d79dd6d31f | |||
| 56b43cc888 | |||
| 77fe5a09ed | |||
| 89ae7813a7 | |||
| e003108cf8 | |||
| 5766eea377 | |||
| f8a4cf225b | |||
| 43b0f17eb9 | |||
| b0b755471b | |||
| 35c5a27352 | |||
| afb90e17e7 | |||
| 9ec9ee781a | |||
| 0b497fc37d |
@@ -12,57 +12,83 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
name: "🚀 Issue / Bug / Request"
|
||||
description: Report a bug, suggest an improvement, or ask a technical question.
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to submit a bug report! 🐛
|
||||
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
|
||||
### Thanks for contributing to LeRobot! 🙌
|
||||
Please choose the most relevant sections below. If this is a general "how-to" question, consider our [Discord](https://discord.gg/s3KuuzsPFb) for faster community support.
|
||||
|
||||
- type: dropdown
|
||||
id: issue-type
|
||||
attributes:
|
||||
label: Ticket Type
|
||||
description: What kind of ticket are you opening?
|
||||
options:
|
||||
- "🐛 Bug Report (Something isn't working)"
|
||||
- "💡 Feature Request / Improvement"
|
||||
- "❓ Technical Question"
|
||||
- "🧹 Maintenance / Documentation"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below.
|
||||
label: Environment & System Info
|
||||
description: |
|
||||
For bugs or technical questions, please run `lerobot-info` and paste the output.
|
||||
(Optional for feature requests).
|
||||
render: Shell
|
||||
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
|
||||
placeholder: lerobot version, OS, python version, etc.
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Description
|
||||
description: |
|
||||
Provide a clear summary of the issue or your proposal.
|
||||
- **Bugs:** What is happening?
|
||||
- **Features:** What is the goal/use case?
|
||||
- **Questions:** What are you trying to achieve?
|
||||
placeholder: |
|
||||
A clear and concise description of the issue or suggestion.
|
||||
|
||||
- type: textarea
|
||||
id: context-repro
|
||||
attributes:
|
||||
label: Context & Reproduction
|
||||
description: |
|
||||
Provide a code snippet, steps to reproduce a bug, or technical details about your proposal.
|
||||
Please use code blocks for scripts and CLI commands.
|
||||
placeholder: |
|
||||
Steps to reproduce / Usage example:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant logs or stack trace
|
||||
description: If applicable, paste relevant error logs here.
|
||||
render: Shell
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
id: extras
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
label: Checklist
|
||||
options:
|
||||
- label: "One of the scripts in the examples/ folder of LeRobot"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
- label: I have searched existing tickets to ensure this isn't a duplicate.
|
||||
- label: I am using the latest version of the `main` branch.
|
||||
- label: I have verified this is not an environment-specific problem.
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
id: workaround
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
Sharing error messages or stack traces could be useful as well!
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
label: Additional Info / Workarounds
|
||||
description: Anything else we should know? If you have a workaround, please share it!
|
||||
|
||||
@@ -1,41 +1,54 @@
|
||||
## What this does
|
||||
## Title
|
||||
|
||||
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
|
||||
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
|
||||
|
||||
Examples:
|
||||
| Title | Label |
|
||||
|----------------------|-----------------|
|
||||
| Fixes #[issue] | (🐛 Bug) |
|
||||
| Adds new dataset | (🗃️ Dataset) |
|
||||
| Optimizes something | (⚡️ Performance) |
|
||||
## Type / Scope
|
||||
|
||||
## How it was tested
|
||||
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
|
||||
- **Scope**: (optional — name of module or package affected)
|
||||
|
||||
Explain/show how you tested your changes.
|
||||
## Summary / Motivation
|
||||
|
||||
Examples:
|
||||
- One-paragraph description of what changes and why.
|
||||
- Why this change is needed and any trade-offs or design notes.
|
||||
|
||||
- Added `test_something` in `tests/test_stuff.py`.
|
||||
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
|
||||
- Optimized `some_function`, it now runs X times faster than previously.
|
||||
## Related issues
|
||||
|
||||
## How to checkout & try? (for the reviewer)
|
||||
- Fixes / Closes: # (if any)
|
||||
- Related: # (if any)
|
||||
|
||||
Provide a simple way for the reviewer to try out your changes.
|
||||
## What changed
|
||||
|
||||
Examples:
|
||||
- Short, concrete bullets of the modifications (files/behaviour).
|
||||
- Short note if this introduces breaking changes and migration steps.
|
||||
|
||||
```bash
|
||||
pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
## How was this tested
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
- Tests added: list new tests or test files.
|
||||
- Manual checks / dataset runs performed.
|
||||
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
## How to run locally (reviewer)
|
||||
|
||||
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
||||
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
|
||||
- Run the relevant tests:
|
||||
|
||||
**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
||||
```bash
|
||||
pytest -q tests/ -k <keyword>
|
||||
```
|
||||
|
||||
- Run a quick example or CLI (if applicable):
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## Checklist (required before merge)
|
||||
|
||||
- [ ] Linting/formatting run (`pre-commit run -a`)
|
||||
- [ ] All tests pass locally (`pytest`)
|
||||
- [ ] Documentation updated
|
||||
- [ ] CI is green
|
||||
|
||||
## Reviewer notes
|
||||
|
||||
- Anything the reviewer should focus on (performance, edge-cases, specific files) or general notes.
|
||||
- Anyone in the community is free to review the PR.
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
CI:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '.github/**'
|
||||
- 'docker/**'
|
||||
|
||||
github_actions:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '.github/**'
|
||||
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*.md'
|
||||
- '**/*.mdx'
|
||||
- 'docs/**'
|
||||
|
||||
examples:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'examples/**'
|
||||
|
||||
tests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
sensors:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/cameras/**'
|
||||
|
||||
configuration:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/configs/**'
|
||||
|
||||
dataset:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/datasets/**'
|
||||
|
||||
evaluation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/envs/**'
|
||||
|
||||
robots:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- 'src/lerobot/teleoperators/**'
|
||||
- 'src/lerobot/robots/**'
|
||||
- 'src/lerobot/motors/**'
|
||||
|
||||
policies:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/policies/**'
|
||||
|
||||
processor:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'src/lerobot/processor/**'
|
||||
@@ -31,7 +31,8 @@ jobs:
|
||||
name: Upload Preview and Comment
|
||||
if: >
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||
with:
|
||||
package_name: lerobot
|
||||
|
||||
@@ -33,6 +33,9 @@ on:
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
@@ -42,14 +45,16 @@ jobs:
|
||||
# This job builds and deploys the official documentation.
|
||||
build_main_docs:
|
||||
name: Build Main Docs
|
||||
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
|
||||
if: >
|
||||
(github.event_name == 'push' || github.event_name == 'workflow_dispatch' || github.event_name == 'release') &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
@@ -58,7 +63,7 @@ jobs:
|
||||
# The result of this job triggers the 'Upload PR Documentation' workflow.
|
||||
build_pr_docs:
|
||||
name: Build PR Docs
|
||||
if: github.event_name == 'pull_request'
|
||||
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
@@ -45,7 +45,6 @@ permissions:
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||
|
||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||
concurrency:
|
||||
@@ -63,7 +62,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv --maxfail=10
|
||||
@@ -127,7 +127,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow automatically labels issues based on their content.
|
||||
name: Issue Labeler
|
||||
on:
|
||||
# Trigger on new issues and edits to existing issues
|
||||
issues:
|
||||
types: [opened, edited]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
label-issue:
|
||||
name: Auto Label Issue
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
steps:
|
||||
- uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
// Setup Input Text
|
||||
const body = (context.payload.issue.body || '');
|
||||
const title = (context.payload.issue.title || '');
|
||||
const cleanBody = body.replace(/```[\s\S]*?```/g, '');
|
||||
const text = `${title}\n${cleanBody}`.toLowerCase();
|
||||
const labelsToAdd = new Set();
|
||||
const matches = (re) => re.test(text);
|
||||
|
||||
// Keyword Heuristics
|
||||
|
||||
if (matches(/\b(bug|error|crash|exception)\b/i)) labelsToAdd.add('bug');
|
||||
if (matches(/\b(new feature|enhancement|improvement|proposal|feature request)\b/i)) labelsToAdd.add('enhancement');
|
||||
if (matches(/\b(question|how to|clarify|explain|how do i|help me|question about)\b/i)) labelsToAdd.add('question');
|
||||
if (matches(/\b(documentation|docs?|readme|tutorial|wiki|typo|docstring)\b/i)) labelsToAdd.add('documentation');
|
||||
if (matches(/\b(example|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples');
|
||||
if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset');
|
||||
if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation');
|
||||
if (matches(/\b(train|training|optimizer|gradient|wandb|sac)\b/i)) labelsToAdd.add('training');
|
||||
if (matches(/\b(rerun|plot|render|rendering|visualizer)/i)) labelsToAdd.add('visualization');
|
||||
if (matches(/\b(cameras?|opencv|realsense|lidars?|sensors?|imus?|microphones?|rgbd|encoders?)\b/i)) labelsToAdd.add('sensors');
|
||||
if (matches(/\b(urdf|actuators?|calibration|end-effector|kinematics)\b/i)) labelsToAdd.add('robots');
|
||||
if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators');
|
||||
if (matches(/\b(policy|policies|model?)\b/i)) labelsToAdd.add('policies');
|
||||
if (matches(/\b(processor|pipeline|preprocessor|postprocessor)s?\b/i)) labelsToAdd.add('processor');
|
||||
if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmarks?)\b/i)) labelsToAdd.add('evaluation');
|
||||
if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests');
|
||||
if (matches(/\b(ci|github actions?|github workflows?|gha|docker|pypi)\b/i)) labelsToAdd.add('CI');
|
||||
if (matches(/\b(perf|latency|throughput|fps|speed|performance|slow|fast|slower|faster|memory usage)\b/i)) labelsToAdd.add('performance');
|
||||
if (matches(/\b(dependency|dependencies|pip|install error|importerror|package not found|pyproject)\b/i)) labelsToAdd.add('dependencies');
|
||||
if (matches(/\b(configuration|config|arguments?|input feature|dracuss)\b/i)) labelsToAdd.add('configuration');
|
||||
|
||||
// Apply Labels
|
||||
const labels = Array.from(labelsToAdd).filter(Boolean);
|
||||
|
||||
if (labels.length > 0) {
|
||||
console.log(`Adding labels: ${labels.join(', ')}`);
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels,
|
||||
});
|
||||
}
|
||||
@@ -43,6 +43,7 @@ jobs:
|
||||
name: Build CPU Docker for Nightly
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
|
||||
steps:
|
||||
@@ -51,7 +52,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -77,6 +78,7 @@ jobs:
|
||||
name: Build GPU Docker for Nightly
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
|
||||
steps:
|
||||
@@ -85,7 +87,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This workflow labels pull requests based on the files that were changed.
|
||||
name: Pull Request Labeler
|
||||
|
||||
on:
|
||||
# Allows labeling pull requests when they are opened or updated
|
||||
# zizmor: ignore[dangerous-triggers] Needed to label PRs from forks
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
name: Label PR
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot' && !github.event.pull_request.draft
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
sync-labels: true # Removes labels if files are removed from the PR
|
||||
@@ -43,12 +43,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ jobs:
|
||||
build-and-publish:
|
||||
name: Build and publish Python distributions
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
version: ${{ steps.extract_info.outputs.tag_version }}
|
||||
permissions:
|
||||
@@ -37,12 +38,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
@@ -134,7 +135,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -176,4 +177,3 @@ jobs:
|
||||
|
||||
# TODO(Steven): Publish draft/pre-release and to test pypi weekly
|
||||
# TODO(Steven): Separate build and publish job
|
||||
# TODO(Steven): Tag documentation with the same version as the package
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses]
|
||||
uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
@@ -45,6 +45,7 @@ jobs:
|
||||
stale:
|
||||
name: Close Stale Issues and PRs
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
actions: write
|
||||
contents: write # only for delete-branch option
|
||||
|
||||
@@ -43,12 +43,13 @@ jobs:
|
||||
full-tests:
|
||||
name: Full Unbound Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
@@ -77,7 +78,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
@@ -100,7 +101,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
@@ -87,7 +87,7 @@ repos:
|
||||
# TODO(Steven): Uncomment when ready to use
|
||||
##### Static Analysis & Typing #####
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
rev: v1.19.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=pyproject.toml]
|
||||
|
||||
@@ -52,7 +52,7 @@ decisions when appropriate.
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official email address,
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
@@ -60,7 +60,7 @@ representative at an online or offline event.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
[feedback@huggingface.co](mailto:feedback@huggingface.co).
|
||||
feedback@huggingface.co.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
|
||||
@@ -1,323 +1,83 @@
|
||||
# How to contribute to 🤗 LeRobot?
|
||||
# How to contribute to 🤗 LeRobot
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
is thus not the only way to help the community. Answering questions, helping
|
||||
others, reaching out and improving the documentations are immensely valuable to
|
||||
the community.
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
It also helps us if you spread the word: reference the library from blog posts
|
||||
on the awesome projects it made possible, shout out on Twitter when it has
|
||||
helped you, or simply ⭐️ the repo to say "thank you".
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our
|
||||
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
|
||||
## Ways to Contribute
|
||||
|
||||
## You can contribute in so many ways!
|
||||
You can contribute in many ways:
|
||||
|
||||
Some of the ways you can contribute to 🤗 LeRobot:
|
||||
- **Fixing issues:** Resolve bugs or improve existing code.
|
||||
- **New features:** Develop new features.
|
||||
- **Extend:** Implement new models/policies, robots, or simulation environments and upload datasets to the Hugging Face Hub.
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
|
||||
- Fixing outstanding issues with the existing code.
|
||||
- Implementing new models, datasets or simulation environments.
|
||||
- Contributing to the examples or to the documentation.
|
||||
- Submitting issues related to bugs or desired new features.
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
|
||||
## Development Setup
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
To contribute code, you need to set up a development environment.
|
||||
|
||||
## Submitting a new issue or feature request
|
||||
### 1. Fork and Clone
|
||||
|
||||
Do your best to follow these guidelines when submitting an issue or a feature
|
||||
request. It will make it easier for us to come back to you quickly and with good
|
||||
feedback.
|
||||
|
||||
### Did you find a bug?
|
||||
|
||||
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
|
||||
the problems they encounter. So thank you for reporting an issue.
|
||||
|
||||
First, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on Github under Issues).
|
||||
|
||||
Did not find it? :( So we can act quickly on it, please follow these steps:
|
||||
|
||||
- Include your **OS type and version**, the versions of **Python** and **PyTorch**.
|
||||
- A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
- The full traceback if an exception is raised.
|
||||
- Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
### Do you want a new feature?
|
||||
|
||||
A good feature request addresses the following points:
|
||||
|
||||
1. Motivation first:
|
||||
|
||||
- Is it related to a problem/frustration with the library? If so, please explain
|
||||
why. Providing a code snippet that demonstrates the problem is best.
|
||||
- Is it related to something you would need for a project? We'd love to hear
|
||||
about it!
|
||||
- Is it something you worked on and think could benefit the community?
|
||||
Awesome! Tell us what problem it solved for you.
|
||||
|
||||
2. Write a _paragraph_ describing the feature.
|
||||
3. Provide a **code snippet** that demonstrates its future use.
|
||||
4. In case this is related to a paper, please attach a link.
|
||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you
|
||||
post it.
|
||||
|
||||
## Adding new policies, datasets or environments
|
||||
|
||||
Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/lerobot) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
git clone git@github.com:<your Github handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
|
||||
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
|
||||
|
||||
Set up a development environment with conda:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
If you're using `uv`, it can manage python versions so you can instead do:
|
||||
|
||||
```bash
|
||||
uv venv --python 3.10 && source .venv/bin/activate
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --extras "dev test"
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --extra dev --extra test
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry sync --all-extras
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
using `poetry`
|
||||
|
||||
```bash
|
||||
poetry lock
|
||||
```
|
||||
|
||||
using `uv`
|
||||
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
6. Follow our style.
|
||||
|
||||
`lerobot` relies on `ruff` to format its source code
|
||||
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
|
||||
automatically as Git commit hooks.
|
||||
|
||||
Install `pre-commit` hooks:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
You can run these hooks whenever you need on staged files with:
|
||||
|
||||
```bash
|
||||
pre-commit
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
Note, if you already committed some changes that have a wrong formatting, you can use:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
7. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
8. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
Fork the repository on GitHub, then clone your fork:
|
||||
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
git clone https://github.com/<your-handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
### 2. Environment Installation
|
||||
|
||||
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
|
||||
|
||||
## Running Tests & Quality Checks
|
||||
|
||||
### Code Style (Pre-commit)
|
||||
|
||||
Install `pre-commit` hooks to run checks automatically before you commit:
|
||||
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
To run checks manually on all files:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use `pytest`. First, ensure you have test artifacts by installing **git-lfs**:
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
Run the full suite (this may require extras installed):
|
||||
|
||||
```bash
|
||||
python -m pytest -sv ./tests
|
||||
pytest -sv ./tests
|
||||
```
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
Or run a specific test file during development:
|
||||
|
||||
```bash
|
||||
pytest -sv tests/test_specific_feature.py
|
||||
```
|
||||
|
||||
## Submitting Issues & Pull Requests
|
||||
|
||||
Use the templates for required fields and examples.
|
||||
|
||||
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
|
||||
|
||||
One member of the LeRobot team will then review your contribution.
|
||||
|
||||
Thank you for contributing to LeRobot!
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
<p align="center">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lerobot-logo-thumbnail.png" width="100%">
|
||||
<br/>
|
||||
<br/>
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="./media/readme/lerobot-logo-thumbnail.png" width="100%">
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -12,323 +10,130 @@
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
<!-- [](https://codecov.io/gh/huggingface/lerobot) -->
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
Build Your Own HopeJR Robot!</a></p>
|
||||
</h2>
|
||||
**LeRobot** aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry so that everyone can contribute to and benefit from shared datasets and pretrained models.
|
||||
|
||||
<div align="center">
|
||||
<img
|
||||
src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/hope_jr/hopejr.png"
|
||||
alt="HopeJR robot"
|
||||
title="HopeJR robot"
|
||||
width="60%"
|
||||
/>
|
||||
🤗 A hardware-agnostic, Python-native interface that standardizes control across diverse platforms, from low-cost arms (SO-100) to humanoids.
|
||||
|
||||
<p><strong>Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!</strong></p>
|
||||
<p>Control it with exoskeletons and gloves for precise hand movements.</p>
|
||||
<p>Perfect for advanced manipulation tasks! 🤖</p>
|
||||
🤗 A standardized, scalable LeRobotDataset format (Parquet + MP4 or images) hosted on the Hugging Face Hub, enabling efficient storage, streaming and visualization of massive robotic datasets.
|
||||
|
||||
<p><a href="https://huggingface.co/docs/lerobot/hope_jr">
|
||||
See the full HopeJR tutorial here.</a></p>
|
||||
</div>
|
||||
🤗 State-of-the-art policies that have been shown to transfer to the real-world ready for training and deployment.
|
||||
|
||||
<br/>
|
||||
🤗 Comprehensive support for the open-source ecosystem to democratize physical AI.
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
Build Your Own SO-101 Robot!</a></p>
|
||||
</h2>
|
||||
## Quick Start
|
||||
|
||||
<div align="center">
|
||||
<table>
|
||||
<tr>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101.webp" alt="SO-101 follower arm" title="SO-101 follower arm" width="90%"/></td>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101-leader.webp" alt="SO-101 leader arm" title="SO-101 leader arm" width="90%"/></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<p><strong>Meet the updated SO100, the SO-101 – Just €114 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://huggingface.co/docs/lerobot/so101">
|
||||
See the full SO-101 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://huggingface.co/docs/lerobot/lekiwi">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lekiwi/kiwi.webp" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
|
||||
<h3 align="center">
|
||||
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
#### Examples of pretrained models on simulation environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Installation
|
||||
|
||||
LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `conda`, install `ffmpeg` in your environment:
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
>
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
>
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
>
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
### Install LeRobot 🤗
|
||||
|
||||
#### From Source
|
||||
|
||||
First, clone the repository and navigate into the directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
|
||||
```bash
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
LeRobot can be installed directly from PyPI.
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
lerobot-info
|
||||
```
|
||||
|
||||
_This installs only the default dependencies._
|
||||
> [!IMPORTANT]
|
||||
> For detailed installation guide, please see the [Installation Documentation](https://huggingface.co/docs/lerobot/installation).
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
## Robots & Control
|
||||
|
||||
<div align="center">
|
||||
<img src="./media/readme/robots_control_video.webp" width="640px" alt="Reachy 2 Demo">
|
||||
</div>
|
||||
|
||||
LeRobot provides a unified `Robot` class interface that decouples control logic from hardware specifics. It supports a wide range of robots and teleoperation devices.
|
||||
|
||||
```python
|
||||
from lerobot.robots.myrobot import MyRobot
|
||||
|
||||
# Connect to a robot
|
||||
robot = MyRobot(config=...)
|
||||
robot.connect()
|
||||
|
||||
# Read observation and send action
|
||||
obs = robot.get_observation()
|
||||
action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
For detailed hardware setup guides, see the [Hardware Documentation](https://huggingface.co/docs/lerobot/integrate_hardware).
|
||||
|
||||
## LeRobot Dataset
|
||||
|
||||
To solve the data fragmentation problem in robotics, we utilize the **LeRobotDataset** format.
|
||||
|
||||
- **Structure:** Synchronized MP4 videos (or images) for vision and Parquet files for state/action data.
|
||||
- **HF Hub Integration:** Explore thousands of robotics datasets on the [Hugging Face Hub](https://huggingface.co/lerobot).
|
||||
- **Tools:** Seamlessly delete episodes, split by indices/fractions, add/remove features, and merge multiple datasets.
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset from the Hub
|
||||
dataset = LeRobotDataset("lerobot/aloha_mobile_cabinet")
|
||||
|
||||
# Access data (automatically handles video decoding)
|
||||
episode_index=0
|
||||
print(f"{dataset[episode_index]['action'].shape=}\n")
|
||||
```
|
||||
|
||||
Learn more about it in the [LeRobotDataset Documentation](https://huggingface.co/docs/lerobot/lerobot-dataset-v3)
|
||||
|
||||
## SoTA Models
|
||||
|
||||
LeRobot implements state-of-the-art policies in pure PyTorch, covering Imitation Learning, Reinforcement Learning, and Vision-Language-Action (VLA) models, with more coming soon. It also provides you with the tools to instrument and inspect your training process.
|
||||
|
||||
<p align="center">
|
||||
<img alt="Gr00t Architecture" src="./media/readme/VLA_architecture.jpg" width="640px">
|
||||
</p>
|
||||
|
||||
Training a policy is as simple as running a script configuration:
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
|
||||
pip install 'lerobot[feetech]' # Feetech motor support
|
||||
lerobot-train \
|
||||
--policy=act \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
_Replace `[...]` with your desired features._
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
**Available Tags:**
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
|
||||
### Weights & Biases
|
||||
## Inference & Evaluation
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
Evaluate your policies in simulation or on real hardware using the unified evaluation script. LeRobot supports standard benchmarks like **LIBERO**, **MetaWorld** and more to come.
|
||||
|
||||
```bash
|
||||
wandb login
|
||||
# Evaluate a policy on the LIBERO benchmark
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
(note: you will also need to enable WandB in the configuration. See below.)
|
||||
Learn how to implement your own simulation environment or benchmark and distribute it from the HF Hub by following the [EnvHub Documentation](https://huggingface.co/docs/lerobot/envhub)
|
||||
|
||||
### Visualize datasets
|
||||
## Resources
|
||||
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
|
||||
|
||||
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
|
||||
|
||||
Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
|
||||
|
||||
### The `LeRobotDataset` format
|
||||
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
|
||||
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
|
||||
|
||||
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
|
||||
|
||||
```
|
||||
dataset attributes:
|
||||
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
|
||||
│ ├ observation.images.cam_high (VideoFrame):
|
||||
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
|
||||
│ ├ observation.state (list of float32): position of an arm joints (for instance)
|
||||
│ ... (more observations)
|
||||
│ ├ action (list of float32): goal position of an arm joints (for instance)
|
||||
│ ├ episode_index (int64): index of the episode for this sample
|
||||
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp (float32): timestamp in the episode
|
||||
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
|
||||
│ └ index (int64): general index in the whole dataset
|
||||
├ meta: a LeRobotDatasetMetadata object containing:
|
||||
│ ├ info: a dictionary of metadata on the dataset
|
||||
│ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
|
||||
│ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
|
||||
│ │ ├ features (dict): all features contained in the dataset with their shapes and types
|
||||
│ │ ├ total_episodes (int): total number of episodes in the dataset
|
||||
│ │ ├ total_frames (int): total number of frames in the dataset
|
||||
│ │ ├ robot_type (str): robot type used for recording
|
||||
│ │ ├ data_path (str): formattable string for the parquet files
|
||||
│ │ └ video_path (str): formattable string for the video files (if using videos)
|
||||
│ ├ episodes: a DataFrame containing episode metadata with columns:
|
||||
│ │ ├ episode_index (int): index of the episode
|
||||
│ │ ├ tasks (list): list of tasks for this episode
|
||||
│ │ ├ length (int): number of frames in this episode
|
||||
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
|
||||
│ │ └ dataset_to_index (int): end index of this episode in the dataset
|
||||
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
|
||||
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
|
||||
│ │ └ ...
|
||||
│ └ tasks: a DataFrame containing task information with task names as index and task_index as values
|
||||
├ root (Path): local directory where the dataset is stored
|
||||
├ image_transforms (Callable): optional image transformations to apply to visual modalities
|
||||
└ delta_timestamps (dict): optional delta timestamps for temporal queries
|
||||
```
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
|
||||
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
If you want, you can cite this work with:
|
||||
If you use LeRobot in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
@@ -339,6 +144,14 @@ If you want, you can cite this work with:
|
||||
}
|
||||
```
|
||||
|
||||
## Star History
|
||||
## Contribute
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
<p align="center">
|
||||
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<sub>Built by the <a href="https://huggingface.co/lerobot">LeRobot</a> team at <a href="https://huggingface.co">Hugging Face</a> with ❤️</sub>
|
||||
</div>
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -37,7 +39,15 @@
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: walloss
|
||||
title: WALL-OSS
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
@@ -81,11 +91,17 @@
|
||||
title: Reachy 2
|
||||
- local: unitree_g1
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug-visualize-queue-size` flag is passed, for various levels of
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
# Bring Your Own Policies
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
|
||||
### Package Structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
|
||||
### Package Configuration
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.11"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
horizon: Action prediction horizon
|
||||
n_action_steps: Number of action steps to execute
|
||||
hidden_dim: Hidden dimension for the policy network
|
||||
# Add your policy-specific parameters here
|
||||
"""
|
||||
# ...PreTrainedConfig fields...
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Add any validation logic here
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
# Implement validation logic for your policy's requirements
|
||||
pass
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
|
||||
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
...
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
|
||||
Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Dict, Any
|
||||
import torch
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
config,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessing and postprocessing functions for your policy."""
|
||||
pass # Define your preprocessing and postprocessing logic here
|
||||
|
||||
```
|
||||
|
||||
## Step 5: Package Initialization
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
"""Custom policy package for LeRobot."""
|
||||
|
||||
try:
|
||||
import lerobot # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
@@ -0,0 +1,206 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.10 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
|
||||
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
|
||||
|
||||
1. Download and install the SDK:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Frodobots/earth-rovers-sdk.git
|
||||
cd earth-rovers-sdk
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Start the SDK:
|
||||
|
||||
```bash
|
||||
hypercorn main:app --reload
|
||||
```
|
||||
|
||||
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||
|
||||
The SDK gives you:
|
||||
|
||||
- Live video from front and rear cameras
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The SDK must be running before you can use the robot.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Follow our [Installation Guide](./installation) to install LeRobot.
|
||||
|
||||
In addition to the base installation, install the EarthRover Mini dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The robot uses the internet to communicate:
|
||||
|
||||
- **Movement commands**: Sent through the SDK
|
||||
- **Camera video**: Received from the SDK
|
||||
- **Robot info**: Battery, location, speed from the SDK
|
||||
|
||||
You don't need to plug anything in - it all works through the SDK.
|
||||
|
||||
## Calibration
|
||||
|
||||
No calibration needed! The robot is ready to use as soon as the SDK is running.
|
||||
|
||||
## Controlling the Robot
|
||||
|
||||
You control the robot using your keyboard - just like playing a video game with WASD keys.
|
||||
|
||||
### Keyboard Controls
|
||||
|
||||
| Key | Action |
|
||||
| --- | -------------------------------- |
|
||||
| W | Move forward |
|
||||
| S | Move backward |
|
||||
| A | Turn left (with forward motion) |
|
||||
| D | Turn right (with forward motion) |
|
||||
| Q | Rotate left in place |
|
||||
| E | Rotate right in place |
|
||||
| X | Stop all movement |
|
||||
| +/= | Increase speed |
|
||||
| - | Decrease speed |
|
||||
| ESC | Disconnect |
|
||||
|
||||
### Speed Settings
|
||||
|
||||
You can adjust how fast the robot moves:
|
||||
|
||||
- **Forward/backward speed**: Default is full speed (1.0)
|
||||
- **Turning speed**: Default is full speed (1.0)
|
||||
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
|
||||
|
||||
### Try It Out
|
||||
|
||||
Test driving the robot before recording data:
|
||||
|
||||
```python
|
||||
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
|
||||
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
|
||||
|
||||
# Initialize robot
|
||||
robot_config = EarthRoverMiniPlusConfig()
|
||||
robot = EarthRoverMiniPlus(robot_config)
|
||||
|
||||
# Initialize teleoperator
|
||||
teleop_config = KeyboardRoverTeleopConfig(
|
||||
linear_speed=1.0,
|
||||
angular_speed=1.0,
|
||||
speed_increment=0.1
|
||||
)
|
||||
teleop = KeyboardRoverTeleop(teleop_config)
|
||||
|
||||
# Connect
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Teleoperate (use keyboard controls)
|
||||
try:
|
||||
while True:
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||
|
||||
## Recording Data
|
||||
|
||||
Once you can drive the robot well, you can start recording data to train AI models. The system records:
|
||||
|
||||
- **What you do**: How you move the robot (forward, backward, turning)
|
||||
- **What the robot sees**:
|
||||
- Videos from both cameras
|
||||
- Robot speed and direction
|
||||
- Battery level and location
|
||||
- GPS position and signal
|
||||
- Other sensor data
|
||||
- **When it happened**: Timestamps for everything
|
||||
|
||||
### Setting Up Hugging Face
|
||||
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
### Start Recording
|
||||
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.fps=10 \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
|
||||
|
||||
### What Gets Saved
|
||||
|
||||
Your dataset includes:
|
||||
|
||||
**Your Actions (2 things)**:
|
||||
|
||||
- How much you moved forward/backward
|
||||
- How much you turned left/right
|
||||
|
||||
**Robot Observations (12 things)**:
|
||||
|
||||
- Front camera video
|
||||
- Rear camera video
|
||||
- Current speed
|
||||
- Battery level
|
||||
- Which way the robot is facing
|
||||
- GPS location (latitude, longitude, signal strength)
|
||||
- Network signal strength
|
||||
- Vibration level
|
||||
- Lamp status (on/off)
|
||||
|
||||
### Where Your Data Goes
|
||||
|
||||
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
|
||||
|
||||
After recording, your data automatically uploads to your Hugging Face page:
|
||||
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
|
||||
```
|
||||
|
||||
Your dataset will be tagged with `LeRobot` for community discovery.
|
||||
@@ -201,7 +201,8 @@ from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.processor import make_default_processors
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -209,12 +210,19 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
@@ -243,6 +251,9 @@ init_rerun(session_name="recording")
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
@@ -251,6 +262,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
@@ -265,6 +279,9 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
@@ -428,7 +445,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
@@ -485,7 +502,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
|
||||
@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
|
||||
To install these for linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
```
|
||||
|
||||
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
@@ -62,6 +62,11 @@ lerobot-eval \
|
||||
|
||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||
|
||||
### Control Mode
|
||||
|
||||
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
|
||||
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
|
||||
|
||||
### Policy inputs and outputs
|
||||
|
||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
# OpenArms Robot
|
||||
|
||||
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
|
||||
|
||||
## Hardware Overview
|
||||
|
||||
- **7 DOF per arm** (14 DOF total for dual arm setup)
|
||||
- **1 gripper per arm** (2 grippers total)
|
||||
- **Damiao motors** with 4 different types:
|
||||
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
|
||||
- **DM4340** for shoulder rotation and elbow (J3, J4)
|
||||
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
|
||||
- **24V power supply** required
|
||||
- **CAN interface device**:
|
||||
- **Linux**: Any SocketCAN-compatible adapter
|
||||
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
|
||||
- Proper CAN wiring (CANH, CANL, 120Ω termination)
|
||||
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
|
||||
|
||||
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|
||||
|-------|-------|------------|---------------|-------------|-------------|
|
||||
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
|
||||
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
|
||||
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
|
||||
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
|
||||
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
|
||||
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
|
||||
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
|
||||
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
|
||||
|
||||
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install system dependencies
|
||||
sudo apt install can-utils iproute2
|
||||
|
||||
# Install LeRobot with OpenArms support
|
||||
pip install -e ".[openarms]"
|
||||
```
|
||||
|
||||
## Setup Guide
|
||||
|
||||
### Step 1: Motor ID Configuration
|
||||
|
||||
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
|
||||
|
||||
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
|
||||
|
||||
Key points:
|
||||
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
|
||||
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
|
||||
- Use the Damiao Debugging Tools to set these IDs
|
||||
|
||||
### Step 2: Setup CAN Interface
|
||||
|
||||
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
|
||||
|
||||
#### Linux (SocketCAN)
|
||||
|
||||
```bash
|
||||
# Find your CAN interface
|
||||
ip link show
|
||||
|
||||
# Configure can0, 1, 2, 3
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 type can bitrate 1000000
|
||||
sudo ip link set can0 up
|
||||
|
||||
sudo ip link set can1 down
|
||||
sudo ip link set can1 type can bitrate 1000000
|
||||
sudo ip link set can1 up
|
||||
|
||||
sudo ip link set can2 down
|
||||
sudo ip link set can2 type can bitrate 1000000
|
||||
sudo ip link set can2 up
|
||||
|
||||
sudo ip link set can3 down
|
||||
sudo ip link set can3 type can bitrate 1000000
|
||||
sudo ip link set can3 up
|
||||
|
||||
# Verify configuration
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
or run:
|
||||
|
||||
`examples/openarms/setup_can.sh`
|
||||
|
||||
### Testing canbus and motor connection
|
||||
|
||||
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarms import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
# Configure for dual arm setup
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
can_interface="socketcan", # Or "auto" for auto-detection
|
||||
id="openarms_dual",
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
robot = OpenArmsFollower(config)
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
On first use, you'll need to calibrate the robot:
|
||||
|
||||
```python
|
||||
robot.calibrate()
|
||||
```
|
||||
|
||||
The calibration process will:
|
||||
1. Disable torque on all motors
|
||||
2. Ask you to position arms in **hanging position with grippers closed**
|
||||
3. Set this as the zero position
|
||||
4. Ask you to move each joint through its full range
|
||||
5. Record min/max positions for each joint
|
||||
6. Save calibration to file
|
||||
|
||||
### Reading Observations
|
||||
|
||||
The robot provides comprehensive state information:
|
||||
|
||||
```python
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Observation includes for each motor:
|
||||
# - {motor_name}.pos: Position in degrees
|
||||
# - {motor_name}.vel: Velocity in degrees/second
|
||||
# - {motor_name}.torque: Motor torque
|
||||
# - {camera_name}: Camera images (if configured)
|
||||
|
||||
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
|
||||
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
|
||||
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
|
||||
```
|
||||
|
||||
### Sending Actions
|
||||
|
||||
```python
|
||||
# Send target positions (in degrees)
|
||||
action = {
|
||||
"right_joint_1.pos": 45.0,
|
||||
"right_joint_2.pos": -30.0,
|
||||
# ... all joints
|
||||
"right_gripper.pos": 45.0, # Half-closed
|
||||
}
|
||||
|
||||
actual_action = robot.send_action(action)
|
||||
```
|
||||
|
||||
### Gripper Control
|
||||
|
||||
```python
|
||||
# Open gripper
|
||||
robot.open_gripper(arm="right")
|
||||
|
||||
# Close gripper
|
||||
robot.close_gripper(arm="right")
|
||||
```
|
||||
|
||||
## Safety Features
|
||||
|
||||
### 1. Maximum Relative Target
|
||||
|
||||
Limits how far a joint can move in a single command to prevent sudden movements:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Limit all joints to 10 degrees per command
|
||||
max_relative_target=10.0,
|
||||
|
||||
# Or set per-motor limits
|
||||
max_relative_target={
|
||||
"right_joint_1": 15.0, # Slower moving joint
|
||||
"right_joint_2": 10.0,
|
||||
"right_gripper": 5.0, # Very slow gripper
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
|
||||
|
||||
### 2. Torque Limits
|
||||
|
||||
Control maximum torque output, especially important for grippers and teleoperation:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
# Gripper torque limit (fraction of motor's max torque)
|
||||
gripper_torque_limit=0.5, # 50% of max torque
|
||||
)
|
||||
```
|
||||
|
||||
Lower torque limits prevent damage when gripping delicate objects.
|
||||
|
||||
### 3. MIT Control Gains
|
||||
|
||||
Control responsiveness and stability via PID-like gains:
|
||||
|
||||
```python
|
||||
config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
position_kp=10.0, # Position gain (higher = more responsive)
|
||||
position_kd=0.5, # Velocity damping (higher = more damped)
|
||||
)
|
||||
```
|
||||
|
||||
**Guidelines**:
|
||||
- **For following (robot)**: Higher gains for responsiveness
|
||||
- `position_kp=10.0`, `position_kd=0.5`
|
||||
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
|
||||
- `manual_control=True` (torque disabled)
|
||||
|
||||
### 4. Velocity Limits
|
||||
|
||||
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
|
||||
- Max velocity: 30 rad/s ≈ 1718°/s
|
||||
|
||||
The motors will automatically limit velocity to safe values.
|
||||
|
||||
## Teleoperation
|
||||
|
||||
### Leader Arm Setup
|
||||
|
||||
The leader arm is moved manually (torque disabled) to generate commands:
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarms import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
config = OpenArmsLeaderConfig(
|
||||
port="can1", # Separate CAN interface for leader
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Torque disabled for manual movement
|
||||
is_dual_arm=True,
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position as action
|
||||
action = leader.get_action()
|
||||
# action contains positions for all joints in degrees
|
||||
```
|
||||
|
||||
### Safety Considerations for Teleoperation
|
||||
|
||||
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
|
||||
2. **Enable max_relative_target** on follower to smooth abrupt movements
|
||||
3. **Lower torque limits** on follower to prevent damage from tracking errors
|
||||
4. **Test with one arm** before enabling dual arm teleoperation
|
||||
5. **Have emergency stop** ready (power switch or CAN disable)
|
||||
|
||||
```python
|
||||
# Recommended follower config for teleoperation
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port="can0",
|
||||
max_relative_target=5.0, # Small steps for smooth following
|
||||
gripper_torque_limit=0.3, # Low torque for safety
|
||||
position_kp=5.0, # Lower gains for gentler following
|
||||
position_kd=0.3,
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Motor Shaking/Unstable
|
||||
|
||||
- **Lower control gains**: Reduce `position_kp` and `position_kd`
|
||||
- **Check calibration**: Re-run calibration procedure
|
||||
- **Verify power**: Insufficient current can cause instability
|
||||
- **Check mechanical**: Loose connections, binding, or damaged components
|
||||
|
||||
### CAN Bus Errors
|
||||
|
||||
```bash
|
||||
# Check for errors
|
||||
ip -s link show can0
|
||||
|
||||
# Reset CAN interface
|
||||
sudo ip link set can0 down
|
||||
sudo ip link set can0 up
|
||||
```
|
||||
|
||||
### Control Mode
|
||||
|
||||
OpenArms uses **MIT control mode** which allows simultaneous control of:
|
||||
- Position (degrees)
|
||||
- Velocity (degrees/second)
|
||||
- Torque (N·m)
|
||||
- Position gain (Kp)
|
||||
- Velocity damping (Kd)
|
||||
|
||||
### Communication
|
||||
|
||||
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
|
||||
- **Frame format**: Standard 11-bit IDs
|
||||
- **Update rate**: Typically 50-100 Hz depending on motor count
|
||||
- **Latency**: ~10-20ms per motor command
|
||||
|
||||
## References
|
||||
|
||||
- [OpenArm Official Documentation](https://docs.openarm.dev/)
|
||||
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
|
||||
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
|
||||
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
|
||||
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
|
||||
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
|
||||
- [Enactic GitHub](https://github.com/enactic/openarm_can)
|
||||
@@ -0,0 +1,35 @@
|
||||
# WALL-OSS
|
||||
|
||||
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| ------------------ | ----------------------------------------------------- | --- |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zhai2025igniting,
|
||||
title = {Igniting VLMs Toward the Embodied Space},
|
||||
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
|
||||
journal = {arXiv preprint arXiv:2509.11766},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**.
|
||||
@@ -0,0 +1,291 @@
|
||||
# RaC: Recovery and Correction Training
|
||||
|
||||
RaC (Recovery and Correction) is a human-in-the-loop data collection and training paradigm that improves robot policy performance on long-horizon tasks by explicitly teaching recovery and correction behaviors.
|
||||
|
||||
**Key References:**
|
||||
- [RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction](https://arxiv.org/abs/2509.07953) (Hu et al., 2025)
|
||||
- [HG-DAgger: Interactive Imitation Learning with Human Experts](https://arxiv.org/abs/1810.02890) (Kelly et al., 2019)
|
||||
- [π∗0.6: a VLA That Learns From Experience](https://pi.website/blog/pistar06) (Physical Intelligence, 2025)
|
||||
- [SARM: Stage-Aware Reward Modeling](https://arxiv.org/abs/2509.25358) (Chen et al., 2025)
|
||||
|
||||
---
|
||||
|
||||
## Why RaC? The Problem with Standard Data Collection
|
||||
|
||||
### Standard Behavioral Cloning Data Collection Limitations
|
||||
|
||||
Standard behavior cloning trains policies on successful demonstrations. This approach can be sensitive to distribution shift and compounding errors. Because during deployment small errors can cascade and push the robot into states never seen during training.
|
||||
This is where RaC and methods like Dagger and HG-DAgger come in.
|
||||
|
||||
### Prior Human-in-the-Loop Methods
|
||||
|
||||
**DAgger** (Dataset Aggregation) addresses distribution shift by:
|
||||
- Running the novice policy to collect states
|
||||
- Querying expert for correct actions at those states
|
||||
- Aggregating new labels into training set
|
||||
|
||||
**HG-DAgger** (Human-Gated DAgger) improves on DAgger by:
|
||||
- Giving human full control authority during interventions
|
||||
- Human takes over when unsafe, provides correction, returns control
|
||||
- Better action labels because human has uninterrupted control
|
||||
|
||||
### RaC
|
||||
|
||||
RaC explicitly collects **recovery + correction** data:
|
||||
|
||||
```
|
||||
BC/DAgger: policy → mistake → human corrects → continue
|
||||
RaC: policy → mistake → human RECOVERS (teleop back) → CORRECTS → END
|
||||
```
|
||||
|
||||
The critical insight is **Rule 1 (Recover then Correct)**:
|
||||
- Every intervention starts with human teleoperating back to an in-distribution state
|
||||
- Then human provides correction to complete the current subtask
|
||||
- Both segments are recorded as training data
|
||||
- This teaches the policy: "when things go wrong, go back and retry"
|
||||
|
||||
**Rule 2 (Terminate after Intervention)**:
|
||||
- Episode ends after correction completes
|
||||
- Avoids mixed policy/human data on later subtasks
|
||||
- Keeps data distribution clean
|
||||
|
||||
---
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Method | Data Type | Recovery Behavior | Correction Behavior |
|
||||
|--------|-----------|-------------------|---------------------|
|
||||
| BC | Success only | ✗ | ✗ |
|
||||
| DAgger | Success + corrections | ✗ | ✓ |
|
||||
| HG-DAgger | Success + corrections | Sometimes | ✓ |
|
||||
| RaC | Success + recovery + correction | ✓ Explicit | ✓ |
|
||||
|
||||
---
|
||||
|
||||
## The RaC Pipeline
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ RaC Training Pipeline │
|
||||
├─────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 1. PRE-TRAINING (Standard BC) │
|
||||
│ └─> Train initial policy on clean demonstrations │
|
||||
│ │
|
||||
│ 2. RAC DATA COLLECTION (Human-in-the-loop) │
|
||||
│ ├─> Policy runs autonomously │
|
||||
│ ├─> Human monitors and intervenes when failure imminent │
|
||||
│ │ ├─> RECOVERY: Human teleoperates robot back to good state │
|
||||
│ │ └─> CORRECTION: Human completes the current subtask │
|
||||
│ └─> Episode terminates after correction (Rule 2) │
|
||||
│ │
|
||||
│ 3. REWARD LABELING (Optional: SARM) │
|
||||
│ └─> Compute progress rewards for advantage-weighted training │
|
||||
│ │
|
||||
│ 4. FINE-TUNING │
|
||||
│ └─> Train on combined demos + RaC data (optionally with RA-BC) │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Pre-train a Base Policy
|
||||
|
||||
First, train a policy on your demonstration dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/demo-dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=outputs/pretrain \
|
||||
--batch_size=32 \
|
||||
--steps=50000
|
||||
```
|
||||
|
||||
### Step 2: Collect RaC Data
|
||||
|
||||
Run the RaC data collection script with your pre-trained policy:
|
||||
|
||||
```bash
|
||||
python examples/rac/rac_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--dataset.single_task="Pick up the cube and place it in the bowl" \
|
||||
--dataset.num_episodes=50
|
||||
```
|
||||
|
||||
**Controls (Keyboard + Foot Pedal):**
|
||||
|
||||
| Key / Pedal | Action |
|
||||
|-------------|--------|
|
||||
| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) |
|
||||
| **c** / Left pedal | Take control (start correction, recording resumes) |
|
||||
| **→** / Right pedal | End episode (save) - when in correction mode |
|
||||
| **←** | Re-record episode |
|
||||
| **ESC** | Stop session and push to hub |
|
||||
| Any key/pedal during reset | Start next episode |
|
||||
|
||||
**The RaC Protocol:**
|
||||
|
||||
1. Watch the policy run autonomously (teleop is idle/free)
|
||||
2. When you see imminent failure, press **SPACE** or **right pedal** to pause
|
||||
- Policy stops
|
||||
- Teleoperator moves to match robot position (torque enabled)
|
||||
- No frames recorded during pause
|
||||
3. Press **c** or **left pedal** to take control
|
||||
- Teleoperator torque disabled, free to move
|
||||
- **RECOVERY**: Teleoperate back to a good state
|
||||
- **CORRECTION**: Complete the subtask
|
||||
- All movements are recorded
|
||||
4. Press **→** or **right pedal** to save and end episode
|
||||
5. **RESET**: Teleop moves to robot position, you can move robot to starting position
|
||||
6. Press any key/pedal to start next episode
|
||||
|
||||
The recovery and correction segments teach the policy how to recover from errors.
|
||||
|
||||
**Foot Pedal Setup (Linux):**
|
||||
|
||||
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
|
||||
```bash
|
||||
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
|
||||
```
|
||||
|
||||
### Step 3: (Optional) Compute SARM Rewards
|
||||
|
||||
For advantage-weighted training (RA-BC / Pi0.6-style), compute SARM progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/rac-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Step 4: Fine-tune Policy
|
||||
|
||||
Fine-tune on the RaC data:
|
||||
|
||||
```bash
|
||||
# Without RA-BC (standard fine-tuning)
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/rac_finetune \
|
||||
--steps=20000
|
||||
|
||||
# With RA-BC (advantage-weighted, Pi0.6-style)
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/rac-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/rac_finetune_rabc \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Connection to Pi0.6 / RECAP
|
||||
|
||||
Pi0.6's RECAP method shares similar principles:
|
||||
- Collect autonomous rollouts + expert interventions
|
||||
- Use value function to compute **advantages**: A(s,a) = V(s') - V(s)
|
||||
- **Advantage conditioning**: Weight training based on expected improvement
|
||||
|
||||
In LeRobot, we can use **SARM** as the value function:
|
||||
- SARM progress φ(s) ∈ [0,1] measures task completion
|
||||
- Progress delta = φ(s') - φ(s) approximates advantage
|
||||
- RA-BC uses these to weight training samples (higher weight for good corrections)
|
||||
|
||||
---
|
||||
|
||||
## Tips for Effective RaC Collection
|
||||
|
||||
### When to Intervene
|
||||
|
||||
Intervene when you see:
|
||||
- Robot about to make an irreversible mistake
|
||||
- Robot hesitating or showing uncertain behavior
|
||||
- Robot deviating from expected trajectory
|
||||
|
||||
### Recovery: Teleoperating Back to Good State
|
||||
|
||||
During recovery, teleoperate the robot back to a state where:
|
||||
- The robot is in a familiar, in-distribution configuration
|
||||
- The current subtask can still be completed
|
||||
- The recovery trajectory itself is informative training data
|
||||
|
||||
### Quality of Corrections
|
||||
|
||||
During correction:
|
||||
- Provide **confident, clean** trajectories
|
||||
- Complete the current subtask fully
|
||||
- Don't overcorrect or add unnecessary movements
|
||||
|
||||
---
|
||||
|
||||
## Iterative Improvement
|
||||
|
||||
RaC can be applied iteratively:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Policy v0 (demos) │
|
||||
│ ↓ │
|
||||
│ RaC Collection (target current failure modes) → Policy v1 │
|
||||
│ ↓ │
|
||||
│ RaC Collection (target new failure modes) → Policy v2 │
|
||||
│ ↓ │
|
||||
│ ... (repeat until satisfactory performance) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
Each iteration:
|
||||
1. Deploy current policy
|
||||
2. Collect RaC interventions on failure cases
|
||||
3. Fine-tune on accumulated data
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
```bibtex
|
||||
@article{hu2025rac,
|
||||
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
|
||||
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
|
||||
journal={arXiv preprint arXiv:2509.07953},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{kelly2019hgdagger,
|
||||
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
|
||||
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
|
||||
journal={arXiv preprint arXiv:1810.02890},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
@article{pi2025recap,
|
||||
title={π∗0.6: a VLA That Learns From Experience},
|
||||
author={Physical Intelligence},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -0,0 +1,586 @@
|
||||
# SARM: Stage-Aware Reward Modeling
|
||||
|
||||
SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).
|
||||
|
||||
**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
|
||||
|
||||
## Why Reward Models?
|
||||
|
||||
Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
|
||||
|
||||
## Overview
|
||||
|
||||
SARM has following features:
|
||||
|
||||
1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
|
||||
2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
|
||||
3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
|
||||
|
||||
SARM trains on a compact **stage+tau** target for each frame:
|
||||
|
||||
- **stage**: integer stage index `k ∈ {0, ..., K-1}`
|
||||
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
|
||||
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)
|
||||
|
||||
At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).
|
||||
|
||||
This matches **Formula (2)** from the paper:
|
||||
|
||||
```
|
||||
progress_t = P_{k-1} + α̅_k × τ_t
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time
|
||||
- `P_{k-1}` is cumulative prior (sum of previous subtask proportions)
|
||||
- `α̅_k` is the temporal proportion for subtask k
|
||||
|
||||
This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
|
||||
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)
|
||||
|
||||
At minimum, each training sample needs:
|
||||
|
||||
- `task` (string): task description
|
||||
- `policy.image_key` images and `policy.state_key` states from the dataset
|
||||
|
||||
---
|
||||
|
||||
## Annotation Modes
|
||||
|
||||
You can choose from **3 annotation modes** that determine how progress labels are computed:
|
||||
|
||||
| Mode | Annotations Required | Heads | Use Case |
|
||||
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
|
||||
| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
|
||||
| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
|
||||
| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |
|
||||
|
||||
### Mode Details
|
||||
|
||||
<hfoptions id="mode_explanation">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress
|
||||
- **Dense head**: Not used
|
||||
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available
|
||||
|
||||
## Set Up Your Environment
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install SARM dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[sarm]"
|
||||
```
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.
|
||||
|
||||
- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
|
||||
- **Dense head**: Multiple fine-grained stages from VLM annotations
|
||||
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.
|
||||
|
||||
- **Sparse head**: High-level stages from VLM annotations
|
||||
- **Dense head**: Fine-grained stages from VLM annotations
|
||||
- **Best for**: Complex multi-stage tasks where both granularities are useful
|
||||
|
||||
Workflow:
|
||||
|
||||
```
|
||||
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
---
|
||||
|
||||
## Step 1: Subtask Annotation
|
||||
|
||||
<hfoptions id="annotation_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--dense-only \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
|
||||
--dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
|
||||
--video-key observation.images.base \
|
||||
--num-workers 4 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
|
||||
- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
|
||||
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
|
||||
- Per-episode columns in `episodes/*.parquet`:
|
||||
- `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
|
||||
- `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
|
||||
- (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Annotation Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | ------------------------------------------------------------------------------- |
|
||||
| `--repo-id` | HuggingFace dataset repository ID |
|
||||
| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
|
||||
| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
|
||||
| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
|
||||
| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
|
||||
| `--num-workers` | Number of parallel GPU workers (default: 1) |
|
||||
| `--episodes` | Specific episode indices to annotate (default: all) |
|
||||
| `--skip-existing` | Skip episodes that already have annotations |
|
||||
| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
|
||||
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |
|
||||
|
||||
> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.
|
||||
|
||||
---
|
||||
|
||||
## Step 2: Verify Annotations
|
||||
|
||||
<hfoptions id="verify_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
**No verification needed!** Skip this step.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type dense \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Visualize annotations using the `--visualize-only` flag:
|
||||
|
||||
```bash
|
||||
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--visualize-only \
|
||||
--visualize-type both \
|
||||
--num-visualizations 5 \
|
||||
--video-key observation.images.base \
|
||||
--output-dir ./subtask_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This generates visualizations showing video frames with subtask boundaries overlaid and timeline of subtasks.
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | -------------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize existing annotations (no generation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |
|
||||
|
||||
**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.
|
||||
|
||||
---
|
||||
|
||||
## Step 3: Train SARM
|
||||
|
||||
<hfoptions id="train_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_single \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dense \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
--policy.image_key=observation.images.base \
|
||||
--output_dir=outputs/train/sarm_dual \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=sarm \
|
||||
--policy.repo_id=your-username/your-model-name
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Multi-GPU Training
|
||||
|
||||
Add `accelerate launch --multi_gpu --num_processes=4` to use multiple GPUs for training.
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description | Default |
|
||||
| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
|
||||
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
|
||||
| `--policy.image_key` | Camera key for images | `observation.images.top` |
|
||||
| `--policy.state_key` | Key for joint states | `observation.state` |
|
||||
| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
|
||||
| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
|
||||
|
||||
---
|
||||
|
||||
## Step 4: Visualize Predictions
|
||||
|
||||
Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.
|
||||
|
||||
<hfoptions id="viz_mode">
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode sparse \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode dense \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
--num-visualizations 5 \
|
||||
--head-mode both \
|
||||
--output-dir ./sarm_viz
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The visualization shows:
|
||||
|
||||
- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`)
|
||||
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
|
||||
- **Sample frames**: Key frames from the episode with progress/stage labels
|
||||
|
||||
### Visualization Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ---------------------- | --------------------------------------------------------- |
|
||||
| `--visualize-only` | Only visualize predictions (no RABC computation) |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
|
||||
| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
|
||||
|
||||
---
|
||||
|
||||
## Step 5 (Optional): Train Policy with RA-BC
|
||||
|
||||
Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:
|
||||
|
||||
1. **Precompute progress values** for all frames using the trained SARM model
|
||||
2. **Train policy** with RA-BC weighting using the precomputed values
|
||||
|
||||
### How RA-BC Works
|
||||
|
||||
For each training sample, RA-BC computes the progress delta:
|
||||
|
||||
```
|
||||
r_i = φ(o_{t+Δ}) - φ(o_t)
|
||||
```
|
||||
|
||||
Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.
|
||||
|
||||
The weighting follows **Equations 8-9** from the paper:
|
||||
|
||||
- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
|
||||
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
|
||||
|
||||
### Step 5a: Compute SARM Progress Values
|
||||
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
--num-visualizations 5 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
This script:
|
||||
|
||||
- Processes all frames and computes progress values
|
||||
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
|
||||
- Generates visualizations of the first N episodes (default: 5)
|
||||
|
||||
**Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | -------------------------------------------------------------- | ---------- |
|
||||
| `--reward-model-path` | Path to trained SARM model | (required) |
|
||||
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
|
||||
| `--device` | Device for inference | `cuda` |
|
||||
| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
|
||||
| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` |
|
||||
|
||||
**Output format** (`sarm_progress.parquet`):
|
||||
|
||||
| Column | Description |
|
||||
| ----------------- | ---------------------------------------------- |
|
||||
| `index` | Global frame index in dataset |
|
||||
| `episode_index` | Episode number |
|
||||
| `frame_index` | Local frame index within episode |
|
||||
| `progress_sparse` | Sparse head progress value [0, 1] |
|
||||
| `progress_dense` | Dense head progress value [0, 1] (if computed) |
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
The training script automatically:
|
||||
|
||||
- Loads the precomputed progress values from the parquet file
|
||||
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
|
||||
- Computes sample weights based on progress improvement
|
||||
- Applies weighted loss during training
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.
|
||||
|
||||
**How the weighting works:**
|
||||
|
||||
| Condition | Weight |
|
||||
| ------------------- | ----------------------- |
|
||||
| `delta > kappa` | 1.0 (hard threshold) |
|
||||
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
|
||||
| `delta < 0` | 0.0 (negative progress) |
|
||||
|
||||
**Diagnosing kappa issues:**
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
|
||||
If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
|
||||
|
||||
### Multi-GPU Training with RA-BC
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=4 \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tips & Best Practices
|
||||
|
||||
### Choosing a Mode
|
||||
|
||||
- **Start with `single_stage`** for quick experiments - no annotation overhead
|
||||
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
|
||||
- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful
|
||||
|
||||
### Annotation Quality
|
||||
|
||||
1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
|
||||
2. **Verify with visualization**: Always check a few episodes before training
|
||||
3. **Consistent naming**: Use the same subtask names across all episodes
|
||||
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2025sarm,
|
||||
title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
|
||||
author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
|
||||
journal={arXiv preprint arXiv:2509.25358},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -30,131 +30,6 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
|
||||
| Wrist Roll | 5 | 1 / 147 |
|
||||
| Gripper | 6 | 1 / 147 |
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated with each arm
|
||||
@@ -340,6 +215,131 @@ leader.setup_motors()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# PyTorch accelerators
|
||||
|
||||
LeRobot supports multiple hardware acceleration options for both training and inference.
|
||||
|
||||
These options include:
|
||||
|
||||
- **CPU**: CPU executes all computations, no dedicated accelerator is used
|
||||
- **CUDA**: acceleration with NVIDIA & AMD GPUs
|
||||
- **MPS**: acceleration with Apple Silicon GPUs
|
||||
- **XPU**: acceleration with Intel integrated and discrete GPUs
|
||||
|
||||
## Getting Started
|
||||
|
||||
To use particular accelerator, a suitable version of PyTorch should be installed.
|
||||
|
||||
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
|
||||
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
|
||||
|
||||
### Verifying the installation
|
||||
|
||||
After installation, accelerator availability can be verified by running
|
||||
|
||||
```python
|
||||
import torch
|
||||
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
|
||||
```
|
||||
|
||||
## How to run training or evaluation
|
||||
|
||||
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
|
||||
|
||||
```bash
|
||||
lerobot-train
|
||||
--policy.device=mps ...
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.device=mps ...
|
||||
```
|
||||
|
||||
However, in most cases, presence of an accelerator is detected automatically and `policy.device` parameter can be omitted from CLI commands.
|
||||
@@ -4,11 +4,12 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
|
||||
|
||||
## About the Unitree G1
|
||||
|
||||
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
|
||||
We offer support for both 29 and 23 DOF G1. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level communication with the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
- **MuJoCo simulation mode** for testing policies without the physical robot
|
||||
|
||||
---
|
||||
|
||||
@@ -191,6 +192,10 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Extra: Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
|
||||
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
#### Convert to Video
|
||||
|
||||
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
|
||||
|
||||
```bash
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
@@ -96,7 +159,45 @@ lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
|
||||
# Dataset Visualization
|
||||
|
||||
## Online Visualization
|
||||
|
||||
When you record a dataset using `lerobot`, it automatically uploads to the Hugging Face Hub unless you specify otherwise. To view the dataset online, use our **LeRobot Dataset Visualizer**, available at:
|
||||
https://huggingface.co/spaces/lerobot/visualize_dataset
|
||||
|
||||
## Local Visualization
|
||||
|
||||
You can also visualize episodes from a dataset locally using our command-line tool.
|
||||
|
||||
**From the Hugging Face Hub:**
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
**From a local folder:**
|
||||
Add the `--root` option and set `--mode local`. For example, to search in `./my_local_data_dir/lerobot/pusht`:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--mode local \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
|
||||
|
||||
For advanced usage—including visualizing datasets stored on a remote server—run:
|
||||
|
||||
```bash
|
||||
lerobot-dataset-viz --help
|
||||
```
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# WALL-OSS
|
||||
|
||||
WALL-OSS is an open-source foundation model for embodied intelligence, proposed by the [XSquare Robot](https://x2robot.com/en/research/68bc2cde8497d7f238dde690) team in 2025. The LeRobot implementation is adapted from their open-source [WallX](https://github.com/X-Square-Robot/wall-x) repository.
|
||||
|
||||
X Square Robot’s WALL-OSS is now integrated into Hugging Face’s LeRobot ecosystem. This is an exciting collaborative project between the LeRobot and X Square Robot teams. You can now post-train, evaluate, and deploy WALL-OSS directly through LeRobot. With this, we’re aiming to make it easier for the open-source robotics community to customize and deploy WALL-OSS foundation models. Read and explore WALL-OSS [paper](https://arxiv.org/pdf/2509.11766) and [code](https://github.com/X-Square-Robot/wall-x).
|
||||
|
||||
## Model Overview
|
||||
|
||||
The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it.
|
||||
|
||||
Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include:
|
||||
|
||||
- **Embodied perception–enhanced multimodal pretraining**: Large-scale training on unified vision–language–action data to strengthen spatial, causal, and manipulation understanding.
|
||||
- **Unified Cross-Level Chain-of-Thought (Uni-CoT)**: A single differentiable framework that unifies high-level instruction reasoning, sub-task decomposition, and fine-grained action synthesis, forming a continuous chain from “understanding” to “execution.”
|
||||
- **Mixture-of-Experts (MoE) action heads**: Dynamically activating experts depending on the task phase and modeling actions in discrete or continuous space to maintain stable VLM priors.
|
||||
- **Two-stage training paradigm**:
|
||||
- **Inspiration stage**: Injecting discrete action priors to strengthen spatial understanding and semantic-action alignment.
|
||||
- **Integration stage**: Using flow matching to achieve high-frequency continuous control.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following our [Installation Guide](./installation).
|
||||
2. Install WallX dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[wallx]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use WallX in LeRobot, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=wall_x
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
--job_name=wallx_training \
|
||||
--policy.repo_id=your_repo_id \
|
||||
--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow \
|
||||
--policy.prediction_mode=diffusion \
|
||||
--policy.attn_implementation=eager \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--batch_size=32
|
||||
```
|
||||
|
||||
### Training Arguments
|
||||
|
||||
| Argument | Description |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--dataset.repo_id` | The Hugging Face Hub repository ID for your training dataset (e.g., `lerobot/aloha_sim_insertion_human`) |
|
||||
| `--policy.type` | Specifies using the WallX policy architecture |
|
||||
| `--output_dir` | Local directory where training checkpoints and logs will be saved |
|
||||
| `--job_name` | A name identifier for this training run (used in logging/tracking) |
|
||||
| `--policy.repo_id` | Your Hugging Face Hub repo ID where the trained model will be pushed |
|
||||
| `--policy.pretrained_path` | Path to pretrained WallX weights to initialize from (the official WALL-OSS checkpoint) |
|
||||
| `--policy.prediction_mode` | The action prediction strategy: `diffusion` or `fast` - `diffusion` uses iterative denoising for action generation, `fast` uses next token prediction instead |
|
||||
| `--policy.attn_implementation` | Attention implementation backend - `eager` uses standard PyTorch attention (alternatives include `flash_attention_2` or `sdpa`) |
|
||||
| `--steps` | Total number of training steps to run |
|
||||
| `--policy.device` | Device to train on (`cuda` for GPU, `cpu` for CPU) |
|
||||
| `--batch_size` | Number of samples per training batch |
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x).
|
||||
@@ -0,0 +1,528 @@
|
||||
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
|
||||
|
||||
## Overview
|
||||
|
||||
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
|
||||
|
||||
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
|
||||
|
||||
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
|
||||
|
||||
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||
alt="XVLA Architecture"
|
||||
style="max-width: 100%; height: auto; width: 800px;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
||||
alt="XVLA Architecture 2"
|
||||
style="width: 60%; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
|
||||
alt="XVLA fold visualization"
|
||||
style="width: 95%; max-width: 1100px; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
|
||||
|
||||
## Installation
|
||||
|
||||
After installing LeRobot, install the X-VLA dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .[xvla]
|
||||
```
|
||||
|
||||
After the new release, you'll be able to do:
|
||||
|
||||
```bash
|
||||
pip install lerobot[xvla]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
To use X-VLA in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```bash
|
||||
policy.type=xvla
|
||||
```
|
||||
|
||||
### Evaluating Pre-trained Checkpoints
|
||||
|
||||
Example evaluation with LIBERO:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="lerobot/xvla-libero" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_goal,libero_10 \
|
||||
--env.control_mode=absolute \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--env.episode_length=800 \
|
||||
--seed=142
|
||||
```
|
||||
|
||||
## Available Checkpoints
|
||||
|
||||
### 🎯 Base Model
|
||||
|
||||
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
|
||||
|
||||
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
|
||||
|
||||
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
|
||||
|
||||
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
|
||||
|
||||
### Simulation Checkpoints
|
||||
|
||||
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
|
||||
|
||||
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
|
||||
|
||||
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
|
||||
|
||||
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
|
||||
|
||||
### 🤖 Real-World Checkpoints
|
||||
|
||||
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
|
||||
|
||||
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
|
||||
|
||||
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
|
||||
|
||||
Optimized for AgileX robot dexterous manipulation tasks.
|
||||
|
||||
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
|
||||
|
||||
Adapted for Google Robot platforms.
|
||||
|
||||
## Training X-VLA
|
||||
|
||||
### Recommended Training Configuration
|
||||
|
||||
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
```
|
||||
|
||||
### Training Parameters Explained
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------- | ------- | ---------------------------------------------- |
|
||||
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
|
||||
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
|
||||
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
|
||||
| `train_soft_prompts` | `true` | Allow soft prompts to train |
|
||||
|
||||
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
|
||||
|
||||
### Example: Training on Bimanual Robot
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=so101_bimanual \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true
|
||||
```
|
||||
|
||||
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||
|
||||
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||
|
||||
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
|
||||
❕Note
|
||||
|
||||
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||
We encourage implementing this in your customized training pipeline for optimal results.
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Action Modes
|
||||
|
||||
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
|
||||
|
||||
#### Available Action Modes
|
||||
|
||||
| Action Mode | Action Dim | Description | Use Case |
|
||||
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
|
||||
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
|
||||
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
|
||||
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
|
||||
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
|
||||
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
|
||||
|
||||
#### Why Action Modes Matter
|
||||
|
||||
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
|
||||
|
||||
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
|
||||
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
|
||||
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
|
||||
|
||||
#### Example: BimanualSO101 Action Space
|
||||
|
||||
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
|
||||
|
||||
```python
|
||||
# Model outputs 20 dimensions for compatibility
|
||||
dim_action = 20
|
||||
|
||||
# Real robot only needs 12 dimensions
|
||||
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
|
||||
REAL_DIM = 12
|
||||
|
||||
# Preprocessing: Pad 12D actions to 20D for training
|
||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||
```
|
||||
|
||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
|
||||
#### Auto Action Mode (Recommended)
|
||||
|
||||
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.action_mode=auto \
|
||||
--policy.max_action_dim=20 \
|
||||
...
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
|
||||
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
|
||||
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
|
||||
- Postprocess trims output back to `real_dim` for robot control
|
||||
|
||||
This eliminates the need to create custom action modes for most robots.
|
||||
|
||||
### 2. Domain IDs
|
||||
|
||||
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
|
||||
|
||||
- Different robots (Robot 1 vs Robot 2)
|
||||
- Different camera configurations (cam1 vs cam2)
|
||||
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
|
||||
|
||||
#### Setting Domain IDs
|
||||
|
||||
**During Training**: By default, domain_id is set to 0 for general training.
|
||||
|
||||
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
|
||||
|
||||
```python
|
||||
# Example: LIBERO checkpoint uses domain_id=3
|
||||
domain_id = 3
|
||||
```
|
||||
|
||||
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||
|
||||
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
|
||||
|
||||
#### Fine-tuning Datasets
|
||||
|
||||
| Dataset Name | Domain ID |
|
||||
| ---------------- | --------- |
|
||||
| Bridge | 0 |
|
||||
| RT1 | 1 |
|
||||
| Calvin | 2 |
|
||||
| libero | 3 |
|
||||
| widowx-air | 4 |
|
||||
| AIR-AGILEX-HQ | 5 |
|
||||
| robotwin2_abs_ee | 6 |
|
||||
| robotwin2_clean | 6 |
|
||||
| robocasa-human | 7 |
|
||||
| VLABench | 8 |
|
||||
| AGIBOT-challenge | 9 |
|
||||
| AIR-AGILEX | 10 |
|
||||
| AIRBOT | 18 |
|
||||
|
||||
### 3. Processor Steps
|
||||
|
||||
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||
|
||||
#### Required Preprocessing Steps
|
||||
|
||||
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
|
||||
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
|
||||
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
|
||||
|
||||
#### Example Custom Processor
|
||||
|
||||
For LIBERO environments, a custom processor handles the specific observation format:
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
|
||||
|
||||
processor = LiberoProcessorStep()
|
||||
# Handles robot_state dictionary, converts rotation matrices to 6D representation
|
||||
# Applies 180° image rotation for camera convention
|
||||
```
|
||||
|
||||
### 4. Configuration Parameters
|
||||
|
||||
Key configuration parameters for X-VLA:
|
||||
|
||||
```python
|
||||
# Observation and action
|
||||
n_obs_steps: int = 1 # Number of observation timesteps
|
||||
chunk_size: int = 32 # Action sequence length
|
||||
n_action_steps: int = 32 # Number of action steps to execute
|
||||
|
||||
# Model architecture
|
||||
hidden_size: int = 1024 # Transformer hidden dimension
|
||||
depth: int = 24 # Number of transformer layers
|
||||
num_heads: int = 16 # Number of attention heads
|
||||
num_domains: int = 30 # Maximum number of domain IDs
|
||||
len_soft_prompts: int = 32 # Length of soft prompt embeddings
|
||||
|
||||
# Action space
|
||||
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
|
||||
use_proprio: bool = True # Use proprioceptive state
|
||||
max_state_dim: int = 32 # Maximum state dimension
|
||||
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
|
||||
|
||||
# Vision
|
||||
num_image_views: int | None # Number of camera views
|
||||
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
|
||||
|
||||
# Training
|
||||
num_denoising_steps: int = 10 # Flow matching denoising steps
|
||||
```
|
||||
|
||||
## Creating Custom Action Modes
|
||||
|
||||
If your robot has a unique action space, you can create a custom action mode:
|
||||
|
||||
### Step 1: Define Your Action Space
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
|
||||
import torch.nn as nn
|
||||
|
||||
@register_action("my_custom_robot")
|
||||
class MyCustomActionSpace(BaseActionSpace):
|
||||
"""Custom action space for my robot."""
|
||||
|
||||
dim_action = 15 # Your robot's action dimension
|
||||
gripper_idx = (7, 14) # Gripper channel indices
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""Define your loss computation."""
|
||||
# Example: MSE for joints, BCE for grippers
|
||||
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
|
||||
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
|
||||
target[:, :, self.gripper_idx])
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Preprocess actions before training."""
|
||||
# Example: Zero out grippers in proprioception
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone() if action is not None else None
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
if action_m is not None:
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action):
|
||||
"""Post-process predictions for deployment."""
|
||||
# Example: Apply sigmoid to gripper logits
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
```
|
||||
|
||||
### Step 2: Use Your Custom Action Mode
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.action_mode=my_custom_robot \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
...
|
||||
```
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Multi-Camera Support
|
||||
|
||||
X-VLA supports multiple camera views through the `num_image_views` parameter:
|
||||
|
||||
```python
|
||||
# Configure for 3 camera views
|
||||
policy.num_image_views=3
|
||||
|
||||
# Add empty cameras if you have fewer physical cameras
|
||||
policy.empty_cameras=1 # Adds 1 zero-padded camera view
|
||||
```
|
||||
|
||||
### Custom Preprocessing Pipeline
|
||||
|
||||
Create a custom preprocessing pipeline for your environment:
|
||||
|
||||
```python
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAImageToFloatProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
)
|
||||
|
||||
# Build custom pipeline
|
||||
preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
YourCustomProcessorStep(), # Your custom processing
|
||||
XVLAImageToFloatProcessorStep(), # Required: convert to float
|
||||
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
|
||||
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Handling Different Action Dimensions
|
||||
|
||||
When your dataset has fewer action dimensions than the pretrained model:
|
||||
|
||||
**Option 1 (Recommended)**: Use `auto` action mode
|
||||
|
||||
```bash
|
||||
# Automatically detects your dataset's action dimension
|
||||
# Works with any robot without custom code
|
||||
policy.action_mode=auto
|
||||
policy.max_action_dim=20 # Match pretrained model
|
||||
```
|
||||
|
||||
**Option 2**: Use a predefined action mode with built-in padding
|
||||
|
||||
```python
|
||||
# Model expects 20D, dataset has 12D
|
||||
# Action mode handles padding internally
|
||||
action_mode = "so101_bimanual" # Pads 12 → 20
|
||||
```
|
||||
|
||||
**Option 2**: Create a custom action mode that maps dimensions explicitly
|
||||
|
||||
```python
|
||||
@register_action("my_mapped_action")
|
||||
class MappedActionSpace(BaseActionSpace):
|
||||
dim_action = 20
|
||||
REAL_DIM = 12
|
||||
|
||||
def _pad_to_model_dim(self, x):
|
||||
# Custom padding logic
|
||||
...
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Issue**: "Action dimension mismatch"
|
||||
|
||||
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
|
||||
|
||||
**Issue**: "Image values outside [0, 1] range"
|
||||
|
||||
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
|
||||
|
||||
**Issue**: "Domain ID not found"
|
||||
|
||||
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
|
||||
|
||||
**Issue**: "Low success rate on new embodiment"
|
||||
|
||||
- **Solution**:
|
||||
1. Verify your action_mode is correct
|
||||
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
|
||||
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
|
||||
4. Consider increasing training steps
|
||||
|
||||
**Issue**: "Out of memory during training"
|
||||
|
||||
- **Solution**:
|
||||
1. Reduce `chunk_size` (e.g., from 32 to 16)
|
||||
2. Enable gradient checkpointing
|
||||
3. Reduce batch size
|
||||
4. Freeze more components
|
||||
|
||||
## Citation
|
||||
|
||||
If you use X-VLA in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zheng2025x,
|
||||
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
|
||||
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
|
||||
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
|
||||
journal = {arXiv preprint arXiv:2510.10274},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
|
||||
@@ -1,245 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate EgoDex shards into a single dataset.
|
||||
|
||||
After distributed processing creates multiple shards, this script combines
|
||||
them into a single unified dataset.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class AggregateEgoDexDatasets(PipelineStep):
|
||||
"""Datatrove pipeline step for aggregating EgoDex shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
aggregated_repo_id: str,
|
||||
local_dir: Path | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
hf_repo_id: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else None
|
||||
self.push_to_hub = push_to_hub
|
||||
self.hf_repo_id = hf_repo_id if hf_repo_id else aggregated_repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Only worker 0 performs aggregation (aggregate_datasets handles parallelism internally)
|
||||
if rank == 0:
|
||||
logging.info(f"Starting aggregation of {len(self.repo_ids)} shards into {self.aggr_repo_id}")
|
||||
|
||||
# Build roots list if local_dir is specified
|
||||
roots = None
|
||||
if self.local_dir:
|
||||
roots = [self.local_dir / repo_id for repo_id in self.repo_ids]
|
||||
# Filter to only existing directories
|
||||
existing_roots = [r for r in roots if r.exists()]
|
||||
if len(existing_roots) != len(self.repo_ids):
|
||||
logging.warning(
|
||||
f"Only {len(existing_roots)} of {len(self.repo_ids)} shard directories found. "
|
||||
"Missing shards will be skipped."
|
||||
)
|
||||
# Update repo_ids to match existing roots
|
||||
existing_repo_ids = [
|
||||
repo_id for repo_id, r in zip(self.repo_ids, roots, strict=False) if r.exists()
|
||||
]
|
||||
roots = existing_roots
|
||||
self.repo_ids = existing_repo_ids
|
||||
|
||||
if len(self.repo_ids) == 0:
|
||||
logging.error("No shard directories found. Nothing to aggregate.")
|
||||
return
|
||||
|
||||
aggr_root = self.local_dir / self.aggr_repo_id if self.local_dir else None
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=self.repo_ids,
|
||||
aggr_repo_id=self.aggr_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
logging.info("Aggregation complete!")
|
||||
|
||||
# Push to Hugging Face Hub if requested
|
||||
if self.push_to_hub:
|
||||
logging.info(f"Pushing to Hugging Face Hub as {self.hf_repo_id}...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=self.aggr_repo_id,
|
||||
root=aggr_root,
|
||||
)
|
||||
# Update repo_id for pushing to different HF account if specified
|
||||
dataset.repo_id = self.hf_repo_id
|
||||
dataset.push_to_hub(
|
||||
tags=["egodex", "hand", "dexterous", "lerobot"],
|
||||
license="cc-by-nc-nd-4.0",
|
||||
)
|
||||
logging.info("Push to hub complete!")
|
||||
else:
|
||||
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_id,
|
||||
num_shards,
|
||||
job_name,
|
||||
logs_dir,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
push_to_hub,
|
||||
hf_repo_id,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for aggregating EgoDex shards."""
|
||||
# Generate repo IDs for all shards
|
||||
repo_ids = [f"{repo_id}_world_{num_shards}_rank_{rank}" for rank in range(num_shards)]
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub, hf_repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1, # Only need 1 task for aggregation
|
||||
"workers": 1, # Only need 1 worker
|
||||
"time": "24:00:00", # 24 hours for aggregation
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate EgoDex dataset shards into a single unified dataset."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (base name without shard suffix, e.g., pepijn/egodex-test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of shards to aggregate (must match --workers from slurm_port_egodex.py)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggr_egodex",
|
||||
help="Job name used in SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch locally (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition (ideally CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of CPUs for aggregation task",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Memory per CPU for aggregation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory where shards are stored. If not specified, uses default HF cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push aggregated dataset to Hugging Face Hub after aggregation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repo ID for upload (e.g., username/dataset-name). Defaults to --repo-id.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
aggregate_executor = make_aggregate_executor(**kwargs)
|
||||
aggregate_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Download EgoDex dataset
|
||||
# Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
#
|
||||
# Usage: ./download_egodex.sh [output_dir] [parts...]
|
||||
#
|
||||
# Examples:
|
||||
# ./download_egodex.sh ./data test # Download test set only (16 GB)
|
||||
# ./download_egodex.sh ./data part1 part2 # Download training parts 1 and 2
|
||||
# ./download_egodex.sh ./data all # Download everything (~1.7 TB)
|
||||
#
|
||||
# Available parts:
|
||||
# test - Test set (16 GB)
|
||||
# part1 - Training set part 1 (300 GB)
|
||||
# part2 - Training set part 2 (300 GB)
|
||||
# part3 - Training set part 3 (300 GB)
|
||||
# part4 - Training set part 4 (300 GB)
|
||||
# part5 - Training set part 5 (300 GB)
|
||||
# extra - Additional data (200 GB)
|
||||
# all - Download all parts (~1.7 TB total)
|
||||
|
||||
set -e
|
||||
|
||||
BASE_URL="https://ml-site.cdn-apple.com/datasets/egodex"
|
||||
|
||||
# Map part names to filenames
|
||||
declare -A PART_FILES=(
|
||||
["test"]="test.zip"
|
||||
["part1"]="part1.zip"
|
||||
["part2"]="part2.zip"
|
||||
["part3"]="part3.zip"
|
||||
["part4"]="part4.zip"
|
||||
["part5"]="part5.zip"
|
||||
["extra"]="extra.zip"
|
||||
)
|
||||
|
||||
ALL_PARTS=("test" "part1" "part2" "part3" "part4" "part5" "extra")
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 <output_dir> <parts...>"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 ./data test # Download test set only (16 GB)"
|
||||
echo " $0 ./data part1 part2 # Download training parts 1 and 2"
|
||||
echo " $0 ./data all # Download everything (~1.7 TB)"
|
||||
echo ""
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
}
|
||||
|
||||
download_part() {
|
||||
local output_dir="$1"
|
||||
local part="$2"
|
||||
local filename="${PART_FILES[$part]}"
|
||||
local url="${BASE_URL}/${filename}"
|
||||
local output_file="${output_dir}/${filename}"
|
||||
|
||||
echo "----------------------------------------"
|
||||
echo "Downloading: ${part} (${filename})"
|
||||
echo "URL: ${url}"
|
||||
echo "Output: ${output_file}"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Download with curl, showing progress
|
||||
curl -L --progress-bar "${url}" -o "${output_file}"
|
||||
|
||||
# Unzip
|
||||
echo "Extracting ${filename}..."
|
||||
unzip -q "${output_file}" -d "${output_dir}"
|
||||
|
||||
# Optionally remove zip file to save space
|
||||
# Uncomment the next line if you want to delete zips after extraction
|
||||
# rm "${output_file}"
|
||||
|
||||
echo "Done: ${part}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# Check arguments
|
||||
if [ $# -lt 2 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
OUTPUT_DIR="$1"
|
||||
shift
|
||||
|
||||
# Create output directory
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
# Determine which parts to download
|
||||
PARTS_TO_DOWNLOAD=()
|
||||
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "all" ]; then
|
||||
PARTS_TO_DOWNLOAD=("${ALL_PARTS[@]}")
|
||||
break
|
||||
elif [ -n "${PART_FILES[$arg]}" ]; then
|
||||
PARTS_TO_DOWNLOAD+=("$arg")
|
||||
else
|
||||
echo "Error: Unknown part '${arg}'"
|
||||
echo "Available parts: test, part1, part2, part3, part4, part5, extra, all"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ ${#PARTS_TO_DOWNLOAD[@]} -eq 0 ]; then
|
||||
echo "Error: No valid parts specified"
|
||||
usage
|
||||
fi
|
||||
|
||||
echo "========================================"
|
||||
echo "EgoDex Dataset Download"
|
||||
echo "========================================"
|
||||
echo "Output directory: ${OUTPUT_DIR}"
|
||||
echo "Parts to download: ${PARTS_TO_DOWNLOAD[*]}"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# Download each part
|
||||
for part in "${PARTS_TO_DOWNLOAD[@]}"; do
|
||||
download_part "${OUTPUT_DIR}" "${part}"
|
||||
done
|
||||
|
||||
echo "========================================"
|
||||
echo "Download complete!"
|
||||
echo "Data saved to: ${OUTPUT_DIR}"
|
||||
echo "========================================"
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Distributed EgoDex dataset porting using SLURM and datatrove.
|
||||
|
||||
EgoDex is a large-scale dataset for egocentric dexterous manipulation collected
|
||||
with ARKit on Apple Vision Pro. This script converts EgoDex data to LeRobot format.
|
||||
|
||||
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Image dimensions
|
||||
DEFAULT_IMAGE_HEIGHT = 1080
|
||||
DEFAULT_IMAGE_WIDTH = 1920
|
||||
|
||||
class PortEgoDexShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str,
|
||||
local_dir: Path | str = None,
|
||||
percentage: float = 100.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
self.local_dir = Path(local_dir) if local_dir else Path("data/local_datasets")
|
||||
self.percentage = percentage
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import mediapy as mpy
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
def _get_state_for_single_frame(transforms_group, frame_idx):
|
||||
"""
|
||||
Construct 48D hand state representation from EgoDex.
|
||||
|
||||
State vector composition (per hand = 24D, total = 48D):
|
||||
- Wrist 3D position (3)
|
||||
- Wrist orientation in 6D representation (6)
|
||||
- 5 fingertip 3D positions (15)
|
||||
"""
|
||||
state_vector = []
|
||||
fingertip_joints = {
|
||||
"left": [
|
||||
"leftThumbTip",
|
||||
"leftIndexFingerTip",
|
||||
"leftMiddleFingerTip",
|
||||
"leftRingFingerTip",
|
||||
"leftLittleFingerTip",
|
||||
],
|
||||
"right": [
|
||||
"rightThumbTip",
|
||||
"rightIndexFingerTip",
|
||||
"rightMiddleFingerTip",
|
||||
"rightRingFingerTip",
|
||||
"rightLittleFingerTip",
|
||||
],
|
||||
}
|
||||
|
||||
for hand_side in ["left", "right"]:
|
||||
hand_key = f"{hand_side}Hand"
|
||||
hand_transform = transforms_group[hand_key][frame_idx]
|
||||
|
||||
# 1. Wrist 3D position
|
||||
hand_position = hand_transform[:3, 3]
|
||||
state_vector.extend(hand_position)
|
||||
|
||||
# 2. Wrist orientation in compact 6D representation
|
||||
rotation_matrix = hand_transform[:3, :3]
|
||||
rotation_6d = np.concatenate([rotation_matrix[:, 0], rotation_matrix[:, 1]])
|
||||
state_vector.extend(rotation_6d)
|
||||
|
||||
# 3. 3D positions of 5 fingertips
|
||||
for fingertip in fingertip_joints[hand_side]:
|
||||
fingertip_transform = transforms_group[fingertip][frame_idx]
|
||||
fingertip_pos = fingertip_transform[:3, 3]
|
||||
state_vector.extend(fingertip_pos)
|
||||
|
||||
# Also return camera extrinsics for optional coordinate frame transformations
|
||||
return np.array(state_vector, dtype=np.float32), transforms_group["camera"][frame_idx]
|
||||
|
||||
def get_state_and_action_from_egodex_annotations(demo):
|
||||
"""
|
||||
Convert EgoDex demo annotations into states and actions.
|
||||
|
||||
The "action" is the state at time t+1 (next-pose prediction).
|
||||
"""
|
||||
transforms_group = demo["transforms"]
|
||||
total_frames = list(transforms_group.values())[0].shape[0]
|
||||
|
||||
states_list, extrinsics_list = [], []
|
||||
for frame_idx in range(total_frames):
|
||||
state_vector, extrinsics = _get_state_for_single_frame(transforms_group, frame_idx)
|
||||
states_list.append(state_vector)
|
||||
extrinsics_list.append(extrinsics.flatten()) # Flatten 4x4 to 16D
|
||||
|
||||
state = np.array(states_list, dtype=np.float32)
|
||||
extrinsics = np.array(extrinsics_list, dtype=np.float32)
|
||||
|
||||
# Shift by 1 timestep to convert state to action
|
||||
action = np.roll(state, -1, axis=0)
|
||||
|
||||
return state, action, extrinsics
|
||||
|
||||
def process_demo(hdf5_file_path, video_path):
|
||||
"""Process a single EgoDex demo and return frames for LeRobot."""
|
||||
video = mpy.read_video(str(video_path))
|
||||
video = np.asarray(video)
|
||||
num_frames = video.shape[0]
|
||||
frames = []
|
||||
|
||||
with h5py.File(hdf5_file_path, "r") as demo:
|
||||
state, action, extrinsics = get_state_and_action_from_egodex_annotations(demo)
|
||||
|
||||
# Get natural language task description
|
||||
if demo.attrs.get("llm_type") == "reversible":
|
||||
direction = demo.attrs.get("which_llm_description", "1")
|
||||
lang_instruction = demo.attrs.get(
|
||||
"llm_description" if direction == "1" else "llm_description2",
|
||||
"manipulation task",
|
||||
)
|
||||
else:
|
||||
lang_instruction = demo.attrs.get("llm_description", "manipulation task")
|
||||
|
||||
for step_idx in range(num_frames):
|
||||
# Resize image to default dimensions
|
||||
image_resized = cv2.resize(
|
||||
video[step_idx],
|
||||
(DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT),
|
||||
interpolation=cv2.INTER_AREA,
|
||||
)
|
||||
frame = {
|
||||
"task": lang_instruction,
|
||||
"observation.image": image_resized,
|
||||
"observation.state": state[step_idx],
|
||||
"observation.extrinsics": extrinsics[step_idx],
|
||||
"action": action[step_idx],
|
||||
}
|
||||
frames.append(frame)
|
||||
|
||||
return frames
|
||||
|
||||
init_logging()
|
||||
|
||||
# Define EgoDex features
|
||||
EGODEX_FEATURES = {
|
||||
"observation.image": {
|
||||
"dtype": "video",
|
||||
"shape": (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [
|
||||
# Left hand wrist position (3)
|
||||
"left_wrist_x",
|
||||
"left_wrist_y",
|
||||
"left_wrist_z",
|
||||
# Left hand wrist rotation 6D (6)
|
||||
"left_rot_0",
|
||||
"left_rot_1",
|
||||
"left_rot_2",
|
||||
"left_rot_3",
|
||||
"left_rot_4",
|
||||
"left_rot_5",
|
||||
# Left fingertips (15)
|
||||
"left_thumb_x",
|
||||
"left_thumb_y",
|
||||
"left_thumb_z",
|
||||
"left_index_x",
|
||||
"left_index_y",
|
||||
"left_index_z",
|
||||
"left_middle_x",
|
||||
"left_middle_y",
|
||||
"left_middle_z",
|
||||
"left_ring_x",
|
||||
"left_ring_y",
|
||||
"left_ring_z",
|
||||
"left_little_x",
|
||||
"left_little_y",
|
||||
"left_little_z",
|
||||
# Right hand wrist position (3)
|
||||
"right_wrist_x",
|
||||
"right_wrist_y",
|
||||
"right_wrist_z",
|
||||
# Right hand wrist rotation 6D (6)
|
||||
"right_rot_0",
|
||||
"right_rot_1",
|
||||
"right_rot_2",
|
||||
"right_rot_3",
|
||||
"right_rot_4",
|
||||
"right_rot_5",
|
||||
# Right fingertips (15)
|
||||
"right_thumb_x",
|
||||
"right_thumb_y",
|
||||
"right_thumb_z",
|
||||
"right_index_x",
|
||||
"right_index_y",
|
||||
"right_index_z",
|
||||
"right_middle_x",
|
||||
"right_middle_y",
|
||||
"right_middle_z",
|
||||
"right_ring_x",
|
||||
"right_ring_y",
|
||||
"right_ring_z",
|
||||
"right_little_x",
|
||||
"right_little_y",
|
||||
"right_little_z",
|
||||
],
|
||||
},
|
||||
"observation.extrinsics": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [f"extrinsic_{i}" for i in range(16)],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (48,),
|
||||
"names": [f"action_{i}" for i in range(48)],
|
||||
},
|
||||
}
|
||||
|
||||
# 1. Discover all HDF5 files
|
||||
files = sorted(list(self.raw_dir.rglob("*.hdf5")))
|
||||
if not files:
|
||||
print(f"No HDF5 files found in {self.raw_dir}")
|
||||
return
|
||||
|
||||
# 2. Apply percentage filter
|
||||
if self.percentage < 100:
|
||||
num_files = max(1, int(len(files) * self.percentage / 100))
|
||||
files = files[:num_files]
|
||||
print(f"Processing {self.percentage}% of dataset: {num_files} files")
|
||||
|
||||
# 3. Assign files to this worker
|
||||
my_files = files[rank::world_size]
|
||||
if not my_files:
|
||||
print(f"Rank {rank} has no files to process.")
|
||||
return
|
||||
|
||||
print(f"Rank {rank} processing {len(my_files)} files out of {len(files)} total.")
|
||||
|
||||
# 4. Create a LeRobot dataset for this shard
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
shard_root = self.local_dir / shard_repo_id if self.local_dir else None
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=shard_repo_id,
|
||||
fps=30,
|
||||
robot_type="hand",
|
||||
features=EGODEX_FEATURES,
|
||||
root=shard_root,
|
||||
)
|
||||
|
||||
# 5. Process each file
|
||||
for input_h5 in my_files:
|
||||
try:
|
||||
# Derive corresponding video path
|
||||
video_path = input_h5.with_suffix(".mp4")
|
||||
if not video_path.exists():
|
||||
print(f"Warning: Video file not found for {input_h5}, skipping.")
|
||||
continue
|
||||
|
||||
# Process demo and add frames
|
||||
frames = process_demo(input_h5, video_path)
|
||||
for frame in frames:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up to avoid OOM
|
||||
del frames
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {input_h5}: {e}")
|
||||
continue
|
||||
|
||||
# 6. Finalize the dataset
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir,
|
||||
repo_id,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
local_dir,
|
||||
percentage,
|
||||
slurm=True,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortEgoDexShards(raw_dir, repo_id, local_dir, percentage),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers,
|
||||
"workers": workers,
|
||||
"time": "10:00:00", # EgoDex is large, allow more time
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": workers,
|
||||
"workers": 1, # Run locally sequentially for debugging
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert EgoDex dataset to LeRobot format using SLURM."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input EgoDex data (HDF5 + MP4 files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier (e.g., user/egodex-lerobot).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("logs"),
|
||||
help="Path to logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_egodex",
|
||||
help="Job name used in SLURM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM. Use --slurm 0 to launch sequentially (useful for debugging).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of SLURM workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="SLURM partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of CPUs per worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentage",
|
||||
type=float,
|
||||
default=100.0,
|
||||
help="Percentage of dataset to process (e.g., 1.0 for 1%%). Useful for testing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to save the LeRobot dataset. Defaults to data/local_datasets.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,416 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive debug script for OpenArms CAN FD communication.
|
||||
Tests all 4 CAN interfaces with CAN FD support.
|
||||
"""
|
||||
|
||||
import can
|
||||
import time
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
def check_can_interface(port):
|
||||
"""Check if CAN interface is UP and configured."""
|
||||
try:
|
||||
result = subprocess.run(['ip', 'link', 'show', port],
|
||||
capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
return False, "Interface not found", None
|
||||
|
||||
output = result.stdout
|
||||
if 'UP' not in output:
|
||||
return False, "Interface is DOWN", None
|
||||
|
||||
# Check if CAN FD is enabled
|
||||
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
|
||||
|
||||
return True, "Interface is UP", is_fd
|
||||
except FileNotFoundError:
|
||||
return None, "Cannot check (ip command not found)", None
|
||||
|
||||
|
||||
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
|
||||
"""
|
||||
Test a single motor and return all responses.
|
||||
|
||||
Returns:
|
||||
list of (arbitration_id, data) tuples for all responses received
|
||||
"""
|
||||
# Send enable command
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
|
||||
try:
|
||||
bus.send(enable_msg)
|
||||
except Exception as e:
|
||||
return None, f"Send error: {e}"
|
||||
|
||||
# Listen for responses
|
||||
responses = []
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
msg = bus.recv(timeout=0.1)
|
||||
if msg:
|
||||
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
|
||||
|
||||
# Send disable command
|
||||
disable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
|
||||
is_extended_id=False,
|
||||
is_fd=use_fd
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return responses, None
|
||||
|
||||
|
||||
def test_interface(port, interface_type="socketcan", use_can_fd=True):
|
||||
"""Test all 8 motors on a single CAN interface."""
|
||||
|
||||
results = {
|
||||
'interface': port,
|
||||
'status': None,
|
||||
'is_fd': use_can_fd,
|
||||
'motors': {}
|
||||
}
|
||||
|
||||
# Check interface status
|
||||
status_ok, status_msg, interface_has_fd = check_can_interface(port)
|
||||
|
||||
if interface_has_fd is not None:
|
||||
results['interface_fd_enabled'] = interface_has_fd
|
||||
if use_can_fd and not interface_has_fd:
|
||||
status_msg += " (CAN FD NOT enabled on interface!)"
|
||||
elif interface_has_fd:
|
||||
status_msg += " (CAN FD enabled)"
|
||||
|
||||
results['status'] = status_msg
|
||||
|
||||
if status_ok is False:
|
||||
return results
|
||||
|
||||
# Try to connect
|
||||
try:
|
||||
if use_can_fd:
|
||||
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
|
||||
bus = can.interface.Bus(
|
||||
channel=port,
|
||||
interface=interface_type,
|
||||
bitrate=1000000
|
||||
)
|
||||
except Exception as e:
|
||||
results['status'] = f"Connection failed: {e}"
|
||||
return results
|
||||
|
||||
try:
|
||||
# Clear any pending messages
|
||||
while bus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
# Test each motor (0x01 to 0x08)
|
||||
for motor_id in range(0x01, 0x09):
|
||||
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
|
||||
|
||||
if error:
|
||||
results['motors'][motor_id] = {'error': error}
|
||||
elif responses:
|
||||
results['motors'][motor_id] = {
|
||||
'found': True,
|
||||
'responses': responses
|
||||
}
|
||||
else:
|
||||
results['motors'][motor_id] = {
|
||||
'found': False,
|
||||
'responses': []
|
||||
}
|
||||
|
||||
time.sleep(0.05) # Small delay between motors
|
||||
|
||||
finally:
|
||||
bus.shutdown()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_results(all_results):
|
||||
"""Print formatted results for all interfaces."""
|
||||
|
||||
print("SUMMARY - Motors Found on Each Interface")
|
||||
|
||||
motor_names = {
|
||||
0x01: "joint_1 (Shoulder pan)",
|
||||
0x02: "joint_2 (Shoulder lift)",
|
||||
0x03: "joint_3 (Shoulder rotation)",
|
||||
0x04: "joint_4 (Elbow flex)",
|
||||
0x05: "joint_5 (Wrist roll)",
|
||||
0x06: "joint_6 (Wrist pitch)",
|
||||
0x07: "joint_7 (Wrist rotation)",
|
||||
0x08: "gripper",
|
||||
}
|
||||
|
||||
total_found = 0
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
status = result['status']
|
||||
|
||||
print(f"{interface}: {status}")
|
||||
if result.get('is_fd'):
|
||||
print(f" Mode: CAN FD")
|
||||
else:
|
||||
print(f" Mode: CAN 2.0")
|
||||
|
||||
if 'Connection failed' in status or 'DOWN' in status:
|
||||
print(f" ⚠ Cannot test {interface}")
|
||||
continue
|
||||
|
||||
motors_found = 0
|
||||
|
||||
for motor_id in range(0x01, 0x09):
|
||||
motor_data = result['motors'].get(motor_id, {})
|
||||
motor_name = motor_names.get(motor_id, "Unknown")
|
||||
|
||||
if motor_data.get('error'):
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
|
||||
elif motor_data.get('found'):
|
||||
motors_found += 1
|
||||
total_found += 1
|
||||
responses = motor_data['responses']
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
|
||||
|
||||
for resp_id, data, is_fd in responses:
|
||||
data_hex = data.hex()
|
||||
fd_flag = " [FD]" if is_fd else " [2.0]"
|
||||
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
|
||||
else:
|
||||
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
|
||||
|
||||
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
|
||||
|
||||
# Overall summary
|
||||
print("OVERALL SUMMARY")
|
||||
print(f"Total motors found across all interfaces: {total_found}")
|
||||
|
||||
# Analyze configuration
|
||||
print("DIAGNOSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
|
||||
if motors_found == 0:
|
||||
print(f"\n⚠ {interface}: NO MOTORS FOUND")
|
||||
print(" Possible issues:")
|
||||
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
|
||||
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
|
||||
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
|
||||
print(" 4. CANH/CANL wiring issue")
|
||||
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
|
||||
|
||||
# Check FD mismatch
|
||||
if result.get('is_fd') and not result.get('interface_fd_enabled'):
|
||||
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
|
||||
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
|
||||
|
||||
elif motors_found < 8:
|
||||
print(f"\n⚠ {interface}: Only {motors_found}/8 motors responding")
|
||||
print(" Check power and connections for missing motors")
|
||||
else:
|
||||
print(f"\n✓ {interface}: All 8 motors responding correctly!")
|
||||
|
||||
# Check for unexpected response IDs
|
||||
print("RESPONSE ID ANALYSIS")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
unexpected = []
|
||||
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
expected_id = motor_id + 0x10
|
||||
actual_ids = [resp[0] for resp in motor_data['responses']]
|
||||
|
||||
if expected_id not in actual_ids:
|
||||
unexpected.append((motor_id, actual_ids))
|
||||
|
||||
if unexpected:
|
||||
print(f"\n⚠ {interface}: Unexpected response IDs detected")
|
||||
for motor_id, actual_ids in unexpected:
|
||||
expected_id = motor_id + 0x10
|
||||
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
|
||||
f"got {[f'0x{id:02X}' for id in actual_ids]}")
|
||||
print(" → Motor Master IDs need reconfiguration")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
if motors_found > 0:
|
||||
print(f"\n✓ {interface}: All responding motors use correct IDs")
|
||||
|
||||
|
||||
def test_communication_speed(interface, motor_id, num_iterations=100):
|
||||
"""
|
||||
Test communication speed with a motor.
|
||||
|
||||
Returns:
|
||||
tuple: (hz, avg_latency_ms) or (None, None) if test failed
|
||||
"""
|
||||
try:
|
||||
# Connect to interface
|
||||
bus = can.interface.Bus(
|
||||
channel=interface,
|
||||
interface="socketcan",
|
||||
bitrate=1000000,
|
||||
data_bitrate=5000000,
|
||||
fd=True
|
||||
)
|
||||
|
||||
# Send refresh commands and measure round-trip time
|
||||
latencies = []
|
||||
successful = 0
|
||||
|
||||
for _ in range(num_iterations):
|
||||
start = time.perf_counter()
|
||||
|
||||
# Send enable command (lightweight operation)
|
||||
enable_msg = can.Message(
|
||||
arbitration_id=motor_id,
|
||||
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
|
||||
is_extended_id=False,
|
||||
is_fd=True
|
||||
)
|
||||
bus.send(enable_msg)
|
||||
|
||||
# Wait for response
|
||||
msg = bus.recv(timeout=0.1)
|
||||
|
||||
if msg:
|
||||
latency = (time.perf_counter() - start) * 1000 # Convert to ms
|
||||
latencies.append(latency)
|
||||
successful += 1
|
||||
|
||||
bus.shutdown()
|
||||
|
||||
if successful > 0:
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
|
||||
return hz, avg_latency
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
print(f" Speed test error: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to test all CAN interfaces with CAN FD."""
|
||||
|
||||
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
|
||||
print("Testing motors 0x01-0x08 on each interface")
|
||||
print()
|
||||
print("Make sure:")
|
||||
print(" ✓ Motors are powered (24V)")
|
||||
print(" ✓ CAN interfaces configured with FD mode:")
|
||||
print(" ./examples/openarms/setup_can.sh")
|
||||
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
|
||||
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
|
||||
print()
|
||||
|
||||
input("Press ENTER to start testing...")
|
||||
|
||||
# Test all 4 interfaces with CAN FD
|
||||
all_results = []
|
||||
|
||||
for i in range(4):
|
||||
interface = f"can{i}"
|
||||
print(f"Testing {interface}...")
|
||||
|
||||
result = test_interface(interface, use_can_fd=True)
|
||||
all_results.append(result)
|
||||
|
||||
# Quick status
|
||||
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
|
||||
print(f" ⚠ {interface}: {result['status']}")
|
||||
else:
|
||||
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
|
||||
print(f" {interface}: {motors_found}/8 motors found")
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# Print detailed results
|
||||
print_results(all_results)
|
||||
|
||||
print("Testing Complete!")
|
||||
|
||||
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
|
||||
|
||||
if all_found == 0:
|
||||
print("\n⚠️ CRITICAL: No motors found on any interface!")
|
||||
print("\nTop issues to check:")
|
||||
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
|
||||
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
|
||||
print(" 3. Missing termination resistors")
|
||||
print("\nTry:")
|
||||
print(" a) Check motor parameters with Damiao Debugging Tools")
|
||||
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
|
||||
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
|
||||
else:
|
||||
# Run speed test on interfaces with motors
|
||||
print("COMMUNICATION SPEED TEST")
|
||||
print("\nTesting maximum communication frequency...")
|
||||
|
||||
for result in all_results:
|
||||
interface = result['interface']
|
||||
|
||||
# Find first responding motor
|
||||
responding_motor = None
|
||||
for motor_id, motor_data in result['motors'].items():
|
||||
if motor_data.get('found'):
|
||||
responding_motor = motor_id
|
||||
break
|
||||
|
||||
if responding_motor:
|
||||
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
|
||||
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
|
||||
|
||||
if hz:
|
||||
print(f" ✓ Max frequency: {hz:.1f} Hz")
|
||||
print(f" ✓ Avg latency: {latency:.2f} ms")
|
||||
print(f" ✓ Commands per second: ~{int(hz)}")
|
||||
else:
|
||||
print(f" ✗ Speed test failed")
|
||||
else:
|
||||
print(f"\n{interface}: No motors found, skipping speed test")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nTesting interrupted by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot by running inference and recording
|
||||
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate.py
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
|
||||
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
|
||||
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
|
||||
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 300
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# Robot CAN interfaces
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
# If enabled, you can manually reset the environment between evaluation episodes
|
||||
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
|
||||
LEADER_LEFT_PORT = "can2"
|
||||
LEADER_RIGHT_PORT = "can3"
|
||||
|
||||
# Camera configuration
|
||||
CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
print("OpenArms Policy Evaluation")
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"Reset Duration: {RESET_TIME_SEC}s")
|
||||
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
port_right=FOLLOWER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
|
||||
leader = None
|
||||
if USE_LEADER_FOR_RESETS:
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left=LEADER_LEFT_PORT,
|
||||
port_right=LEADER_RIGHT_PORT,
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
leader.connect(calibrate=False)
|
||||
|
||||
if not leader.is_connected:
|
||||
raise RuntimeError("Leader robot failed to connect!")
|
||||
|
||||
# Enable gravity compensation
|
||||
if leader.pin_robot is not None:
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
|
||||
else:
|
||||
print(f"Leader connected but gravity compensation unavailable (no URDF)")
|
||||
|
||||
# Build default processors for action and observation
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Build dataset features from robot features and processors
|
||||
# For actions, only include positions (no velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if dataset already exists
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||
if dataset_path.exists():
|
||||
print(f"Evaluation dataset already exists at: {dataset_path}")
|
||||
print("This will append new episodes to the existing dataset.")
|
||||
choice = input(" Continue? (y/n): ").strip().lower()
|
||||
if choice != 'y':
|
||||
print(" Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
if leader:
|
||||
leader.disconnect()
|
||||
return
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
|
||||
# Load policy config from pretrained model and create policy using factory
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
)
|
||||
|
||||
print(f"\nRunning evaluation...")
|
||||
# Initialize keyboard listener and visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_evaluation")
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\nRunning inference for episode {episode_idx + 1}...")
|
||||
|
||||
# Run inference with policy
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Save episode
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Reset environment between episodes (if not last episode)
|
||||
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
|
||||
if USE_LEADER_FOR_RESETS and leader:
|
||||
log_say("Reset the environment using leader arms")
|
||||
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
|
||||
|
||||
# Use leader for manual reset with gravity compensation
|
||||
import numpy as np
|
||||
|
||||
dt = 1 / FPS
|
||||
reset_start_time = time.perf_counter()
|
||||
|
||||
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
|
||||
if events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity and friction torques
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=1.0
|
||||
)
|
||||
|
||||
# Combine torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply compensation
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor, kp=0.0, kd=kd,
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower
|
||||
follower_action = {}
|
||||
for joint in leader_positions_deg.keys():
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
print("Reset complete")
|
||||
else:
|
||||
log_say("Waiting for manual reset")
|
||||
print(f"Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
print(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
if leader:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
|
||||
follower.disconnect()
|
||||
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
print("\nUploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation with Real-Time Chunking (RTC)
|
||||
|
||||
Evaluates a trained policy on the OpenArms robot using RTC for smooth, continuous motion.
|
||||
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
|
||||
despite high inference latency by asynchronously generating action chunks.
|
||||
|
||||
Features:
|
||||
- Thread-based asynchronous action generation and execution
|
||||
- RTC for smooth transitions between action chunks
|
||||
- Dataset recording for evaluation episodes
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate_with_rtc.py
|
||||
|
||||
# With custom RTC parameters
|
||||
python examples/openarms/evaluate_with_rtc.py \
|
||||
--rtc.execution_horizon=12 \
|
||||
--rtc.max_guidance_weight=10.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Default Configuration Constants
|
||||
# ============================================================================
|
||||
|
||||
DEFAULT_HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0"
|
||||
DEFAULT_HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_rtc"
|
||||
DEFAULT_TASK_DESCRIPTION = "three-folds-dataset"
|
||||
|
||||
DEFAULT_NUM_EPISODES = 1
|
||||
DEFAULT_FPS = 30
|
||||
DEFAULT_EPISODE_TIME_SEC = 300
|
||||
DEFAULT_RESET_TIME_SEC = 60
|
||||
|
||||
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
|
||||
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
DEFAULT_CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Thread-Safe Robot Wrapper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: OpenArmsFollower):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmsRTCEvalConfig(HubMixin):
|
||||
"""Configuration for OpenArms evaluation with RTC."""
|
||||
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=10.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
model_id: str = DEFAULT_HF_MODEL_ID
|
||||
eval_dataset_id: str = DEFAULT_HF_EVAL_DATASET_ID
|
||||
task: str = DEFAULT_TASK_DESCRIPTION
|
||||
|
||||
num_episodes: int = DEFAULT_NUM_EPISODES
|
||||
fps: float = DEFAULT_FPS
|
||||
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
|
||||
reset_time_sec: float = DEFAULT_RESET_TIME_SEC
|
||||
|
||||
follower_left_port: str = DEFAULT_FOLLOWER_LEFT_PORT
|
||||
follower_right_port: str = DEFAULT_FOLLOWER_RIGHT_PORT
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
# Should be higher than inference_delay + execution_horizon
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
record_dataset: bool = True
|
||||
push_to_hub: bool = True
|
||||
|
||||
interpolation: bool = False
|
||||
|
||||
use_torch_compile: bool = False
|
||||
torch_compile_backend: str = "inductor"
|
||||
torch_compile_mode: str = "default"
|
||||
torch_compile_disable_cudagraphs: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.model_id = policy_path
|
||||
elif self.model_id:
|
||||
self.policy = PreTrainedConfig.from_pretrained(self.model_id)
|
||||
self.policy.pretrained_path = self.model_id
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Generation Thread
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_actions_thread(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: OpenArmsRTCEvalConfig,
|
||||
episode_active: Event,
|
||||
):
|
||||
"""Thread function to asynchronously generate action chunks from the policy."""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting action generation thread")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.fps
|
||||
|
||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
hw_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task]
|
||||
obs_with_policy_features["robot_type"] = robot.name
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
|
||||
"Should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
|
||||
f"delay={new_delay}, queue_size={action_queue.qsize()}"
|
||||
)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[GET_ACTIONS] Action generation thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Execution Thread
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def actor_thread(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: OpenArmsRTCEvalConfig,
|
||||
episode_active: Event,
|
||||
dataset: LeRobotDataset | None,
|
||||
dataset_lock: Lock,
|
||||
teleop_action_processor,
|
||||
robot_observation_processor,
|
||||
):
|
||||
"""Thread function to execute actions on the robot."""
|
||||
try:
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
if cfg.interpolation:
|
||||
interp_factor = 2
|
||||
robot_interval = 1.0 / (cfg.fps * interp_factor)
|
||||
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.fps * interp_factor}Hz (2x)")
|
||||
else:
|
||||
interp_factor = 1
|
||||
robot_interval = 1.0 / cfg.fps
|
||||
logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz")
|
||||
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = time.perf_counter()
|
||||
last_dataset_time = 0.0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
prev_action = None
|
||||
interpolated_actions = []
|
||||
interp_idx = 0
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = time.perf_counter()
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
policy_consume_count += 1
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
action_dict[key] = action_to_send[i].item()
|
||||
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
robot_send_count += 1
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
now = time.perf_counter()
|
||||
if now - last_dataset_time >= (1.0 / cfg.fps):
|
||||
last_dataset_time = now
|
||||
with dataset_lock:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||
frame = {}
|
||||
for key, value in obs_processed.items():
|
||||
frame[f"observation.{key}"] = value
|
||||
for key, value in action_for_dataset.items():
|
||||
frame[f"action.{key}"] = value
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_hz_print >= 5.0:
|
||||
elapsed = now - last_hz_print
|
||||
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
|
||||
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}")
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = now
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, robot_interval - dt_s - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info("[ACTOR] Shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Evaluation Function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: OpenArmsRTCEvalConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method."""
|
||||
if policy.name in ["pi05", "pi0"]:
|
||||
return policy
|
||||
|
||||
try:
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: OpenArmsRTCEvalConfig):
|
||||
"""Main evaluation function with RTC."""
|
||||
init_logging()
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Policy Evaluation with RTC")
|
||||
print("=" * 60)
|
||||
print(f"\nModel: {cfg.model_id}")
|
||||
print(f"Evaluation Dataset: {cfg.eval_dataset_id}")
|
||||
print(f"Task: {cfg.task}")
|
||||
print(f"Episodes: {cfg.num_episodes}")
|
||||
print(f"Episode Duration: {cfg.episode_time_sec}s")
|
||||
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
||||
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
||||
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
||||
print(f"Policy Hz: {cfg.fps}")
|
||||
print(f"Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
|
||||
print(f"Interpolation: {cfg.interpolation}")
|
||||
print(f"Device: {cfg.device}")
|
||||
print("=" * 60)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
episode_active = Event()
|
||||
|
||||
# Initialize Robot
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=cfg.follower_left_port,
|
||||
port_right=cfg.follower_right_port,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=DEFAULT_CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
robot = RobotWrapper(follower)
|
||||
logger.info("Follower robot connected")
|
||||
|
||||
# Build Processors and Dataset Features
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create or Load Dataset
|
||||
dataset = None
|
||||
dataset_lock = Lock()
|
||||
|
||||
if cfg.record_dataset:
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / cfg.eval_dataset_id
|
||||
if dataset_path.exists():
|
||||
logger.info(f"Evaluation dataset exists at: {dataset_path}")
|
||||
logger.info("New episodes will be appended.")
|
||||
choice = input("Continue? (y/n): ").strip().lower()
|
||||
if choice != "y":
|
||||
logger.info("Aborting evaluation.")
|
||||
follower.disconnect()
|
||||
return
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=cfg.eval_dataset_id,
|
||||
fps=int(cfg.fps),
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
logger.info(f"Dataset created: {cfg.eval_dataset_id}")
|
||||
|
||||
# Load Policy
|
||||
logger.info(f"Loading policy from: {cfg.model_id}")
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type in ["pi05", "pi0"]:
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
# Create Action Queue and Start Threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
get_actions_t = Thread(
|
||||
target=get_actions_thread,
|
||||
args=(
|
||||
policy,
|
||||
robot,
|
||||
robot_observation_processor,
|
||||
action_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
episode_active,
|
||||
),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_t.start()
|
||||
logger.info("Started action generation thread")
|
||||
|
||||
actor_t = Thread(
|
||||
target=actor_thread,
|
||||
args=(
|
||||
robot,
|
||||
robot_action_processor,
|
||||
action_queue,
|
||||
shutdown_event,
|
||||
cfg,
|
||||
episode_active,
|
||||
dataset,
|
||||
dataset_lock,
|
||||
teleop_action_processor,
|
||||
robot_observation_processor,
|
||||
),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_t.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
# Run Evaluation Episodes
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < cfg.num_episodes and not shutdown_event.is_set():
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {cfg.num_episodes}")
|
||||
logger.info(f"\n{'='*40}")
|
||||
logger.info(f"Episode {episode_idx + 1} / {cfg.num_episodes}")
|
||||
logger.info(f"{'='*40}")
|
||||
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
episode_active.set()
|
||||
episode_start_time = time.time()
|
||||
|
||||
while (time.time() - episode_start_time) < cfg.episode_time_sec:
|
||||
if shutdown_event.is_set():
|
||||
break
|
||||
|
||||
elapsed = time.time() - episode_start_time
|
||||
if int(elapsed) % 10 == 0 and int(elapsed) > 0:
|
||||
logger.info(
|
||||
f"[MAIN] Episode progress: {elapsed:.0f}/{cfg.episode_time_sec}s, "
|
||||
f"queue_size={action_queue.qsize()}"
|
||||
)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
episode_active.clear()
|
||||
logger.info(f"Episode {episode_idx + 1} completed")
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
with dataset_lock:
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
logger.info(
|
||||
f"Saving episode {episode_idx + 1} "
|
||||
f"({dataset.episode_buffer['size']} frames)"
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
episode_idx += 1
|
||||
|
||||
# Manual reset between episodes
|
||||
if not shutdown_event.is_set() and episode_idx < cfg.num_episodes:
|
||||
log_say("Waiting for manual reset")
|
||||
logger.info("Manually reset the environment and press ENTER to continue")
|
||||
input("Press ENTER when ready...")
|
||||
|
||||
logger.info(f"Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n\nEvaluation interrupted by user")
|
||||
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
episode_active.clear()
|
||||
|
||||
if get_actions_t.is_alive():
|
||||
logger.info("Waiting for action generation thread to finish...")
|
||||
get_actions_t.join(timeout=5.0)
|
||||
|
||||
if actor_t.is_alive():
|
||||
logger.info("Waiting for actor thread to finish...")
|
||||
actor_t.join(timeout=5.0)
|
||||
|
||||
follower.disconnect()
|
||||
logger.info("Follower disconnected")
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
dataset.finalize()
|
||||
if cfg.push_to_hub:
|
||||
logger.info("Uploading to Hugging Face Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,216 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
# Friction model parameters from OpenArms config/follower.yaml
|
||||
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
|
||||
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
FRICTION_PARAMS = {
|
||||
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
|
||||
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
|
||||
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
|
||||
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
|
||||
}
|
||||
|
||||
# Constants from OpenArms C++ implementation
|
||||
AMP_TMP = 1.0
|
||||
COEF_TMP = 0.1
|
||||
|
||||
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
|
||||
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
|
||||
|
||||
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
|
||||
"""
|
||||
Compute friction torque for a single motor using the tanh friction model.
|
||||
|
||||
Args:
|
||||
velocity_rad_per_sec: Angular velocity in rad/s
|
||||
motor_index: Index of the motor (0-7)
|
||||
|
||||
Returns:
|
||||
Friction torque in N·m (scaled for stability)
|
||||
"""
|
||||
|
||||
Fc = FRICTION_PARAMS["Fc"][motor_index]
|
||||
k = FRICTION_PARAMS["k"][motor_index]
|
||||
Fv = FRICTION_PARAMS["Fv"][motor_index]
|
||||
Fo = FRICTION_PARAMS["Fo"][motor_index]
|
||||
|
||||
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
|
||||
friction_torque = (
|
||||
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
|
||||
Fv * velocity_rad_per_sec +
|
||||
Fo
|
||||
)
|
||||
|
||||
# Scale down friction compensation for stability at lower control rates
|
||||
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
|
||||
friction_torque *= FRICTION_SCALE
|
||||
|
||||
return friction_torque
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
print(f"Applying friction compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by friction compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting friction compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# Motor name to index mapping
|
||||
motor_name_to_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions and velocities from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract velocities in degrees per second
|
||||
velocities_deg_per_sec = {}
|
||||
positions_deg = {}
|
||||
|
||||
for motor in follower.bus_right.motors:
|
||||
vel_key = f"right_{motor}.vel"
|
||||
pos_key = f"right_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[pos_key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
vel_key = f"left_{motor}.vel"
|
||||
pos_key = f"left_{motor}.pos"
|
||||
if vel_key in obs:
|
||||
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
|
||||
if pos_key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[pos_key]
|
||||
|
||||
# Convert velocities to rad/s and compute friction torques
|
||||
friction_torques_nm = {}
|
||||
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
|
||||
# Extract motor name without arm prefix
|
||||
if motor_full_name.startswith("right_"):
|
||||
motor_name = motor_full_name.removeprefix("right_")
|
||||
elif motor_full_name.startswith("left_"):
|
||||
motor_name = motor_full_name.removeprefix("left_")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Get motor index for friction parameters
|
||||
motor_index = motor_name_to_index.get(motor_name, 0)
|
||||
|
||||
# Convert velocity to rad/s
|
||||
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
|
||||
|
||||
# Compute friction torque
|
||||
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
|
||||
friction_torques_nm[motor_full_name] = friction_torque
|
||||
|
||||
# Apply friction compensation to right arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply friction compensation to left arm (all joints INCLUDING gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = friction_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get motor index for damping gain
|
||||
motor_index = motor_name_to_index.get(motor, 0)
|
||||
kd = DAMPING_KD[motor_index]
|
||||
|
||||
# Send MIT control command with friction compensation + damping
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping friction compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
from os.path import join, dirname, exists, expanduser
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0,
|
||||
)
|
||||
|
||||
|
||||
print("Initializing robot...")
|
||||
follower = OpenArmsFollower(config)
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Load URDF for Pinocchio dynamics
|
||||
urdf_path = "/home/croissant/Documents/openarm_description/openarm_bimanual_pybullet.urdf"
|
||||
|
||||
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
|
||||
pin_robot.data = pin_robot.model.createData()
|
||||
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
|
||||
|
||||
follower.pin_robot = pin_robot
|
||||
|
||||
print(f"Applying gravity compensation")
|
||||
print(" 1. Support the arm before starting")
|
||||
print(" 2. The arm will be held in place by gravity compensation")
|
||||
print(" 3. You should be able to move it with gentle force")
|
||||
print("\nPress ENTER when ready to start...")
|
||||
input()
|
||||
|
||||
print(f"✓ Motors enabled")
|
||||
print("\nStarting gravity compensation loop...")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get current joint positions from robot
|
||||
obs = follower.get_observation()
|
||||
|
||||
# Extract positions in degrees
|
||||
positions_deg = {}
|
||||
for motor in follower.bus_right.motors:
|
||||
key = f"right_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"right_{motor}"] = obs[key]
|
||||
|
||||
for motor in follower.bus_left.motors:
|
||||
key = f"left_{motor}.pos"
|
||||
if key in obs:
|
||||
positions_deg[f"left_{motor}"] = obs[key]
|
||||
|
||||
# Convert to radians and calculate gravity torques
|
||||
# Use the built-in method from OpenArmsFollower
|
||||
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
|
||||
torques_nm = follower._gravity_from_q(positions_rad)
|
||||
|
||||
# Apply gravity compensation to right arm (all joints except gripper)
|
||||
for motor in follower.bus_right.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"right_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Apply gravity compensation to left arm (all joints except gripper)
|
||||
for motor in follower.bus_left.motors:
|
||||
if motor == "gripper":
|
||||
continue # Skip gripper
|
||||
|
||||
full_name = f"left_{motor}"
|
||||
position = positions_deg.get(full_name, 0.0)
|
||||
torque = torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Send MIT control command with gravity compensation torque
|
||||
follower.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0, # No position control
|
||||
kd=0.0, # No velocity damping
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque
|
||||
)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print status every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping gravity compensation...")
|
||||
|
||||
finally:
|
||||
print("\nDisabling all motors and disconnecting...")
|
||||
follower.bus_right.disable_torque()
|
||||
follower.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
follower.disconnect()
|
||||
print("✓ Safe shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OpenArms Dataset Recording with Gravity + Friction Compensation
|
||||
|
||||
Records a dataset using OpenArms follower robot with leader teleoperator.
|
||||
Leader arms have gravity and friction compensation for weightless, easy movement.
|
||||
Includes 3 cameras: left wrist, right wrist, and base camera.
|
||||
|
||||
Uses the same compensation approach as teleop_with_compensation.py
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
# Recording parameters
|
||||
NUM_EPISODES = 1
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 600
|
||||
RESET_TIME_SEC = 120
|
||||
TASK_DESCRIPTION = "OpenArms task description"
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def record_loop_with_compensation(
|
||||
robot,
|
||||
leader,
|
||||
events,
|
||||
fps,
|
||||
dataset,
|
||||
dataset_features,
|
||||
control_time_s,
|
||||
single_task,
|
||||
display_data=True,
|
||||
):
|
||||
"""
|
||||
Custom record loop that applies gravity + friction compensation to leader.
|
||||
Based on record_loop but with integrated compensation.
|
||||
"""
|
||||
dt = 1 / fps
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
elapsed = loop_start - episode_start_time
|
||||
|
||||
# Check if we should exit
|
||||
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to robot
|
||||
if follower_action:
|
||||
robot.send_action(follower_action)
|
||||
|
||||
# Get observation from robot (includes camera images)
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Add to dataset if we have a dataset
|
||||
if dataset is not None:
|
||||
# Build properly formatted observation frame
|
||||
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
|
||||
|
||||
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
|
||||
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
|
||||
|
||||
# Combine into single frame
|
||||
frame = {**obs_frame, **action_frame}
|
||||
|
||||
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
|
||||
frame["task"] = single_task
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Display data if requested
|
||||
if display_data:
|
||||
log_rerun_data(observation=observation, action=follower_action)
|
||||
|
||||
# Maintain loop rate
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
sleep_time = dt - loop_duration
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main recording loop with gravity compensation."""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Recording with Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Create camera configurations (3 cameras: left wrist, right wrist, base)
|
||||
# Using actual device paths found by lerobot-find-cameras opencv
|
||||
camera_config = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
|
||||
}
|
||||
|
||||
# Configure follower robot with cameras
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=camera_config,
|
||||
)
|
||||
|
||||
# Configure leader teleoperator (no cameras needed)
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize robot and teleoperator
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect devices
|
||||
print("Connecting and calibrating...")
|
||||
follower.connect(calibrate=True)
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Verify URDF is loaded for gravity compensation
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
# Configure the dataset features
|
||||
# For actions, we only want to record positions (not velocity or torque)
|
||||
action_features_hw = {}
|
||||
for key, value in follower.action_features.items():
|
||||
if key.endswith(".pos"):
|
||||
action_features_hw[key] = value
|
||||
|
||||
action_features = hw_to_dataset_features(action_features_hw, "action")
|
||||
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
print("\nCreating dataset...")
|
||||
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
|
||||
|
||||
# Check if dataset already exists and prompt user
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
while dataset_path.exists():
|
||||
print(f"\nDataset already exists at: {dataset_path}")
|
||||
print("\nOptions:")
|
||||
print(" 1. Overwrite existing dataset")
|
||||
print(" 2. Use a different name")
|
||||
print(" 3. Abort")
|
||||
|
||||
choice = input("\nEnter your choice (1/2/3): ").strip()
|
||||
|
||||
if choice == '1':
|
||||
print(f"Removing existing dataset...")
|
||||
shutil.rmtree(dataset_path)
|
||||
print("✓ Existing dataset removed")
|
||||
break
|
||||
elif choice == '2':
|
||||
print("\nCurrent repo_id:", repo_id)
|
||||
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
|
||||
if new_repo_id and '/' in new_repo_id:
|
||||
repo_id = new_repo_id
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
|
||||
print(f"✓ Using new repo_id: {repo_id}")
|
||||
# Loop will continue if this new path also exists
|
||||
else:
|
||||
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
|
||||
elif choice == '3':
|
||||
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
return
|
||||
else:
|
||||
print("Invalid choice. Please enter 1, 2, or 3.")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize keyboard listener and visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_recording")
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(f"Recording {NUM_EPISODES} episodes")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print("=" * 70)
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("\nKeyboard controls:")
|
||||
print(" - Press 'q' to stop recording")
|
||||
print(" - Press 'r' to re-record current episode")
|
||||
print("=" * 70)
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Record episode with compensation active
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop_with_compensation(
|
||||
robot=follower,
|
||||
leader=leader,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=None, # Don't save reset period
|
||||
dataset_features=dataset_features,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Only save episode if frames were recorded
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
else:
|
||||
log_say("No frames recorded, skipping episode save")
|
||||
# Clear the empty buffer
|
||||
dataset.episode_buffer = None
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping recording...")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
# Upload dataset
|
||||
print("\nUploading dataset to Hugging Face Hub...")
|
||||
try:
|
||||
dataset.push_to_hub()
|
||||
print("✓ Dataset uploaded successfully")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to upload dataset: {e}")
|
||||
print("You can manually upload later using: dataset.push_to_hub()")
|
||||
|
||||
print("✓ Recording complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Dataset Replay Example
|
||||
|
||||
Replays position actions from a recorded dataset on an OpenArms follower robot.
|
||||
Only position commands (ending with .pos) are replayed, not velocity or torque.
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/replay.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
# Configuration
|
||||
EPISODE_IDX = 0
|
||||
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
|
||||
DATASET_ROOT = None # Use default cache location, or specify custom path
|
||||
|
||||
# Robot configuration - adjust these to match your setup
|
||||
ROBOT_CONFIG = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for left arm
|
||||
port_right="can3", # CAN interface for right arm
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit: max degrees to move per step
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main replay function."""
|
||||
print("=" * 70)
|
||||
print("OpenArms Dataset Replay")
|
||||
print("=" * 70)
|
||||
print(f"\nDataset: {DATASET_REPO_ID}")
|
||||
print(f"Episode: {EPISODE_IDX}")
|
||||
print(f"Robot: {ROBOT_CONFIG.id}")
|
||||
print(f" Left arm: {ROBOT_CONFIG.port_left}")
|
||||
print(f" Right arm: {ROBOT_CONFIG.port_right}")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
# Initialize the robot
|
||||
print("\n[1/3] Initializing robot...")
|
||||
robot = OpenArmsFollower(ROBOT_CONFIG)
|
||||
|
||||
# Load the dataset
|
||||
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
|
||||
dataset = LeRobotDataset(
|
||||
DATASET_REPO_ID,
|
||||
root=DATASET_ROOT,
|
||||
episodes=[EPISODE_IDX]
|
||||
)
|
||||
|
||||
# Filter dataset to only include frames from the specified episode
|
||||
# (required for dataset V3.0 where episodes are chunked)
|
||||
episode_frames = dataset.hf_dataset.filter(
|
||||
lambda x: x["episode_index"] == EPISODE_IDX
|
||||
)
|
||||
|
||||
if len(episode_frames) == 0:
|
||||
raise ValueError(
|
||||
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
|
||||
)
|
||||
|
||||
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
|
||||
|
||||
# Extract action features from dataset
|
||||
action_features = dataset.features.get(ACTION, {})
|
||||
action_names = action_features.get("names", [])
|
||||
|
||||
# Filter to only position actions (ending with .pos)
|
||||
position_action_names = [name for name in action_names if name.endswith(".pos")]
|
||||
|
||||
if not position_action_names:
|
||||
raise ValueError(
|
||||
f"No position actions found in dataset. Action names: {action_names}"
|
||||
)
|
||||
|
||||
print(f" Found {len(position_action_names)} position actions to replay")
|
||||
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
|
||||
|
||||
# Select only action columns from dataset
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
print(f"\n[3/3] Connecting to robot...")
|
||||
robot.connect(calibrate=False) # Skip calibration for replay
|
||||
|
||||
if not robot.is_connected:
|
||||
raise RuntimeError("Robot failed to connect!")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Ready to replay!")
|
||||
print("=" * 70)
|
||||
print("\nThe robot will replay the recorded positions.")
|
||||
print("Press Ctrl+C to stop at any time.\n")
|
||||
|
||||
input("Press ENTER to start replaying...")
|
||||
|
||||
# Replay loop
|
||||
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
|
||||
|
||||
try:
|
||||
for idx in range(len(episode_frames)):
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Extract action array from dataset
|
||||
action_array = actions[idx][ACTION]
|
||||
|
||||
# Build action dictionary, but only include position actions
|
||||
action = {}
|
||||
for i, name in enumerate(action_names):
|
||||
# Only include position actions (ending with .pos)
|
||||
if name.endswith(".pos"):
|
||||
action[name] = float(action_array[i])
|
||||
|
||||
# Send action to robot
|
||||
robot.send_action(action)
|
||||
|
||||
# Maintain replay rate (use dataset fps)
|
||||
loop_duration = time.perf_counter() - loop_start
|
||||
dt_s = 1.0 / dataset.fps - loop_duration
|
||||
busy_wait(dt_s)
|
||||
|
||||
# Progress indicator every 100 frames
|
||||
if (idx + 1) % 100 == 0:
|
||||
progress = (idx + 1) / len(episode_frames) * 100
|
||||
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
|
||||
|
||||
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
|
||||
log_say("Replay complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nReplay interrupted by user")
|
||||
finally:
|
||||
# Disconnect robot
|
||||
print("\nDisconnecting robot...")
|
||||
robot.disconnect()
|
||||
print("✓ Replay complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
# Setup all OpenArms CAN interfaces with CAN FD
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "OpenArms CAN FD Interface Setup"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Mode: CAN FD"
|
||||
echo " - Nominal bitrate: 1 Mbps"
|
||||
echo " - Data bitrate: 5 Mbps"
|
||||
echo ""
|
||||
echo "Configuring interfaces can0, can1, can2, can3..."
|
||||
echo ""
|
||||
|
||||
# Configure each CAN interface with CAN FD
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
|
||||
# Check if interface exists
|
||||
if ! ip link show "$interface" &> /dev/null; then
|
||||
echo "⚠ $interface: Not found, skipping"
|
||||
continue
|
||||
fi
|
||||
|
||||
# Bring down interface
|
||||
sudo ip link set "$interface" down 2>/dev/null
|
||||
|
||||
# Configure CAN FD mode
|
||||
sudo ip link set "$interface" type can \
|
||||
bitrate 1000000 \
|
||||
dbitrate 5000000 \
|
||||
fd on
|
||||
|
||||
# Bring up interface
|
||||
sudo ip link set "$interface" up
|
||||
|
||||
# Verify configuration
|
||||
if ip link show "$interface" | grep -q "UP"; then
|
||||
echo "✓ $interface: Configured and UP"
|
||||
else
|
||||
echo "✗ $interface: Failed to bring UP"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Verification"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# Show detailed status for each interface
|
||||
for i in 0 1 2 3; do
|
||||
interface="can$i"
|
||||
if ip link show "$interface" &> /dev/null; then
|
||||
echo "$interface:"
|
||||
# Show key parameters
|
||||
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
|
||||
echo ""
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=========================================="
|
||||
echo "Setup Complete!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "All interfaces configured for CAN FD mode"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Test motors: python debug_can_communication.py"
|
||||
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
|
||||
echo ""
|
||||
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
OpenArms Teleoperation Example - Full Dual Arms
|
||||
|
||||
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
|
||||
It first calibrates both devices, then enters a teleoperation loop for both arms.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2", # CAN interface for follower left arm
|
||||
port_right="can3", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=5.0, # Safety limit
|
||||
)
|
||||
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0", # CAN interface for leader left arm
|
||||
port_right="can1", # CAN interface for leader right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_leader",
|
||||
manual_control=True, # Enable manual control (torque disabled)
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("OpenArms Teleoperation - Full Dual Arms")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize devices
|
||||
print("\n[1/4] Initializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("\n[2/4] Connecting and calibrating follower robot...")
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("\n[3/4] Connecting and calibrating leader arm...")
|
||||
print("Note: The leader arm will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
# Wait for user to be ready
|
||||
print("\n[4/4] Ready for teleoperation!")
|
||||
print("\nBoth arms will be controlled (16 motors total):")
|
||||
print(" RIGHT ARM: joints 1-7 + gripper")
|
||||
print(" LEFT ARM: joints 1-7 + gripper")
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("\nTeleoperation started! Move both leader arms.")
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
start_time = time.perf_counter()
|
||||
last_print_time = start_time
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get action from leader
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Filter to only position data for all joints (both arms)
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
joint_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
# Send action to follower (both arms)
|
||||
if joint_action:
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Measure loop time
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Print stats every 2 seconds
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
min_time = min(loop_times)
|
||||
max_time = max(loop_times)
|
||||
max_hz = 1.0 / min_time if min_time > 0 else 0
|
||||
min_hz = 1.0 / max_time if max_time > 0 else 0
|
||||
|
||||
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
|
||||
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
|
||||
f"Avg loop time: {avg_time*1000:.1f} ms")
|
||||
|
||||
# Reset for next measurement window
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
OpenArms Mini Teleoperation Example
|
||||
|
||||
This script demonstrates teleoperation of an OpenArms follower robot using
|
||||
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
|
||||
|
||||
The OpenArms Mini has:
|
||||
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
|
||||
|
||||
Note on gripper normalization:
|
||||
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
|
||||
- OpenArms follower gripper: degrees (0=closed, -65=open)
|
||||
- This script automatically converts between the two ranges
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
|
||||
# Target control frequency
|
||||
TARGET_FPS = 30
|
||||
|
||||
# Configure the OpenArms follower (Damiao motors on CAN bus)
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can0", # CAN interface for follower left arm
|
||||
port_right="can1", # CAN interface for follower right arm
|
||||
can_interface="socketcan", # Linux SocketCAN
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0, # Safety limit (degrees per step)
|
||||
)
|
||||
|
||||
# Configure the OpenArms Mini leader (Feetech motors on serial)
|
||||
leader_config = OpenArmsMiniConfig(
|
||||
port_right="/dev/ttyACM0", # Serial port for right arm
|
||||
port_left="/dev/ttyACM1", # Serial port for left arm
|
||||
id="openarms_mini",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
print("OpenArms Mini → OpenArms Follower Teleoperation")
|
||||
|
||||
# Initialize devices
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsMini(leader_config)
|
||||
|
||||
# Connect and calibrate follower
|
||||
print("Note: If you have existing calibration, just press ENTER to use it.")
|
||||
follower.connect(calibrate=True)
|
||||
|
||||
# Connect and calibrate leader
|
||||
print("Note: The leader arms will have torque disabled for manual control.")
|
||||
leader.connect(calibrate=True)
|
||||
|
||||
print("\nPress ENTER to start teleoperation...")
|
||||
input()
|
||||
|
||||
print("Press Ctrl+C to stop.\n")
|
||||
|
||||
# All joints for both arms (16 motors total)
|
||||
all_joints = [
|
||||
# Right arm
|
||||
"right_joint_1",
|
||||
"right_joint_2",
|
||||
"right_joint_3",
|
||||
"right_joint_4",
|
||||
"right_joint_5",
|
||||
"right_joint_6",
|
||||
"right_joint_7",
|
||||
"right_gripper",
|
||||
# Left arm
|
||||
"left_joint_1",
|
||||
"left_joint_2",
|
||||
"left_joint_3",
|
||||
"left_joint_4",
|
||||
"left_joint_5",
|
||||
"left_joint_6",
|
||||
"left_joint_7",
|
||||
"left_gripper",
|
||||
]
|
||||
|
||||
# Performance monitoring
|
||||
loop_times = []
|
||||
avg_loop_time = 0.0
|
||||
min_loop_time = float('inf')
|
||||
max_loop_time = 0.0
|
||||
stats_update_interval = 1.0 # Update stats every 1 second
|
||||
last_stats_update = time.perf_counter()
|
||||
|
||||
|
||||
SWAPPED_JOINTS = {
|
||||
"right_joint_6": "right_joint_7",
|
||||
"right_joint_7": "right_joint_6",
|
||||
"left_joint_6": "left_joint_7",
|
||||
"left_joint_7": "left_joint_6",
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get actions and observations
|
||||
leader_action = leader.get_action()
|
||||
follower_obs = follower.get_observation()
|
||||
|
||||
joint_action = {}
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
|
||||
# Determine which follower joint this leader joint controls
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
# Get leader position (default 0 if missing)
|
||||
pos = leader_action.get(leader_key, 0.0)
|
||||
|
||||
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
|
||||
if "gripper" in joint:
|
||||
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
|
||||
# 0 (closed) -> 0°, 100 (open) -> -65°
|
||||
pos = (pos / 100.0) * -65.0
|
||||
|
||||
# Store in action dict for follower
|
||||
joint_action[follower_key] = pos
|
||||
|
||||
follower.send_action(joint_action)
|
||||
|
||||
# Loop timing
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
# Update stats periodically
|
||||
current_time = time.perf_counter()
|
||||
if current_time - last_stats_update >= stats_update_interval:
|
||||
if loop_times:
|
||||
avg_loop_time = sum(loop_times) / len(loop_times)
|
||||
min_loop_time = min(loop_times)
|
||||
max_loop_time = max(loop_times)
|
||||
loop_times = []
|
||||
last_stats_update = current_time
|
||||
|
||||
# Display everything
|
||||
sys.stdout.write("\033[H\033[J") # Clear screen
|
||||
|
||||
# Show timing stats at the top
|
||||
if avg_loop_time > 0:
|
||||
avg_hz = 1.0 / avg_loop_time
|
||||
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
|
||||
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
|
||||
else:
|
||||
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
|
||||
|
||||
# Show joint positions
|
||||
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
|
||||
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
|
||||
print("-" * 52)
|
||||
|
||||
for joint in all_joints:
|
||||
leader_key = f"{joint}.pos"
|
||||
follower_joint = SWAPPED_JOINTS.get(joint, joint)
|
||||
follower_key = f"{follower_joint}.pos"
|
||||
|
||||
leader_pos = leader_action.get(leader_key, 0.0)
|
||||
follower_pos = follower_obs.get(follower_key, 0.0)
|
||||
|
||||
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
|
||||
|
||||
# Smart sleep to maintain target FPS
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping teleoperation...")
|
||||
finally:
|
||||
# Disconnect devices
|
||||
print("Disconnecting devices...")
|
||||
try:
|
||||
follower.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting follower: {e}")
|
||||
|
||||
try:
|
||||
leader.disconnect()
|
||||
except Exception as e:
|
||||
print(f"Error disconnecting leader: {e}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
OpenArms Teleoperation with Gravity + Friction Compensation
|
||||
|
||||
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
|
||||
Follower arms (both LEFT and RIGHT): Mirror leader movements
|
||||
|
||||
Uses the URDF file from the lerobot repository.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
|
||||
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
|
||||
FRICTION_SCALE = 1.0
|
||||
|
||||
def main():
|
||||
"""Main teleoperation loop with gravity compensation"""
|
||||
|
||||
print("=" * 70)
|
||||
print("OpenArms Teleoperation with Gravity Compensation")
|
||||
print("=" * 70)
|
||||
|
||||
# Configuration
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left="can2",
|
||||
port_right="can3",
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
)
|
||||
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
can_interface="socketcan",
|
||||
id="openarms_leader",
|
||||
manual_control=False, # Enable torque control for gravity compensation
|
||||
)
|
||||
|
||||
# Initialize and connect
|
||||
print("\nInitializing devices...")
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
leader = OpenArmsLeader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# URDF is automatically loaded in the leader constructor
|
||||
if leader.pin_robot is None:
|
||||
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
|
||||
|
||||
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
|
||||
print("Press ENTER to start...")
|
||||
input()
|
||||
|
||||
# Enable motors on both leader arms for gravity compensation
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
# Main control loop
|
||||
loop_times = []
|
||||
last_print_time = time.perf_counter()
|
||||
|
||||
# All joints (both arms)
|
||||
all_joints = []
|
||||
for motor in leader.bus_right.motors:
|
||||
all_joints.append(f"right_{motor}")
|
||||
for motor in leader.bus_left.motors:
|
||||
all_joints.append(f"left_{motor}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
# Get leader state
|
||||
leader_action = leader.get_action()
|
||||
|
||||
# Extract positions and velocities in degrees
|
||||
leader_positions_deg = {}
|
||||
leader_velocities_deg_per_sec = {}
|
||||
|
||||
for motor in leader.bus_right.motors:
|
||||
pos_key = f"right_{motor}.pos"
|
||||
vel_key = f"right_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
|
||||
|
||||
for motor in leader.bus_left.motors:
|
||||
pos_key = f"left_{motor}.pos"
|
||||
vel_key = f"left_{motor}.vel"
|
||||
if pos_key in leader_action:
|
||||
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
|
||||
if vel_key in leader_action:
|
||||
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
|
||||
|
||||
# Calculate gravity torques for leader using built-in method
|
||||
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
|
||||
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
|
||||
|
||||
# Calculate friction torques for leader using built-in method
|
||||
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
|
||||
leader_friction_torques_nm = leader._friction_from_velocity(
|
||||
leader_velocities_rad_per_sec,
|
||||
friction_scale=FRICTION_SCALE
|
||||
)
|
||||
|
||||
# Combine gravity + friction torques
|
||||
leader_total_torques_nm = {}
|
||||
for motor_name in leader_gravity_torques_nm:
|
||||
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
|
||||
friction = leader_friction_torques_nm.get(motor_name, 0.0)
|
||||
leader_total_torques_nm[motor_name] = gravity + friction
|
||||
|
||||
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
|
||||
for motor in leader.bus_right.motors:
|
||||
full_name = f"right_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_right._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
|
||||
for motor in leader.bus_left.motors:
|
||||
full_name = f"left_{motor}"
|
||||
position = leader_positions_deg.get(full_name, 0.0)
|
||||
torque = leader_total_torques_nm.get(full_name, 0.0)
|
||||
|
||||
# Get damping gain for stability
|
||||
kd = leader.get_damping_kd(motor)
|
||||
|
||||
leader.bus_left._mit_control(
|
||||
motor=motor,
|
||||
kp=0.0,
|
||||
kd=kd, # Add damping for stability
|
||||
position_degrees=position,
|
||||
velocity_deg_per_sec=0.0,
|
||||
torque=torque,
|
||||
)
|
||||
|
||||
# Send leader positions to follower (both arms)
|
||||
follower_action = {}
|
||||
for joint in all_joints:
|
||||
pos_key = f"{joint}.pos"
|
||||
if pos_key in leader_action:
|
||||
follower_action[pos_key] = leader_action[pos_key]
|
||||
|
||||
if follower_action:
|
||||
follower.send_action(follower_action)
|
||||
|
||||
# Performance monitoring
|
||||
loop_end = time.perf_counter()
|
||||
loop_time = loop_end - loop_start
|
||||
loop_times.append(loop_time)
|
||||
|
||||
if loop_end - last_print_time >= 2.0:
|
||||
if loop_times:
|
||||
avg_time = sum(loop_times) / len(loop_times)
|
||||
current_hz = 1.0 / avg_time if avg_time > 0 else 0
|
||||
|
||||
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
|
||||
|
||||
loop_times = []
|
||||
last_print_time = loop_end
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
try:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
time.sleep(0.1)
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
print("✓ Shutdown complete")
|
||||
except Exception as e:
|
||||
print(f"Shutdown error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
This script:
|
||||
1. Loads a dataset
|
||||
2. Sets all task_index to 0 and task description to "fold"
|
||||
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
|
||||
|
||||
Usage:
|
||||
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
write_info,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
|
||||
# Single unified task
|
||||
UNIFIED_TASK = "fold"
|
||||
|
||||
|
||||
def unify_dataset_tasks(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
push_to_hub: bool = False,
|
||||
) -> None:
|
||||
"""Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
Args:
|
||||
repo_id: Dataset repository ID.
|
||||
root: Optional root path for dataset.
|
||||
push_to_hub: Whether to push the result to HuggingFace Hub.
|
||||
"""
|
||||
input_root = root if root else HF_LEROBOT_HOME / repo_id
|
||||
input_repo_id = repo_id
|
||||
|
||||
logging.info(f"Loading metadata from {repo_id}")
|
||||
|
||||
# Load source metadata
|
||||
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
|
||||
|
||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||
|
||||
# Modify in-place (input_root == output_root supported)
|
||||
data_dir = input_root / DATA_DIR
|
||||
|
||||
# Process data files - set all task_index to 0
|
||||
logging.info("Processing data files (in-place)...")
|
||||
for parquet_file in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Processing data"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["task_index"] = 0 # All tasks unified to index 0
|
||||
df.to_parquet(parquet_file)
|
||||
|
||||
# Process episodes metadata - set all tasks to unified task
|
||||
logging.info("Processing episodes metadata (in-place)...")
|
||||
episodes_dir = input_root / "meta" / "episodes"
|
||||
if episodes_dir.exists():
|
||||
for parquet_file in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
||||
df.to_parquet(parquet_file)
|
||||
else:
|
||||
logging.warning(f"No episodes directory found at {episodes_dir}, skipping")
|
||||
|
||||
# Update tasks.parquet with single task
|
||||
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
||||
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
||||
write_tasks(new_tasks, input_root)
|
||||
|
||||
# Update info.json
|
||||
new_info = src_meta.info.copy()
|
||||
new_info["total_tasks"] = 1
|
||||
write_info(new_info, input_root)
|
||||
|
||||
logging.info(f"Dataset modified in-place at {input_root}")
|
||||
logging.info(f"Task: {UNIFIED_TASK}")
|
||||
|
||||
if push_to_hub:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.info(f"Pushing {input_repo_id} to hub")
|
||||
dataset = LeRobotDataset(input_repo_id, root=input_root)
|
||||
dataset.push_to_hub(private=True)
|
||||
logging.info("Push complete!")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push result to HuggingFace Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
unify_dataset_tasks(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,745 @@
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
|
||||
main {
|
||||
min-height: 100vh;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin: 0 0 1rem 0;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: #666;
|
||||
margin: 0 0 0.5rem 0;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1920px;
|
||||
margin: 0 auto;
|
||||
display: grid;
|
||||
grid-template-columns: minmax(500px, 600px) 1fr;
|
||||
gap: 2rem;
|
||||
align-items: start;
|
||||
}
|
||||
|
||||
/* Left column container */
|
||||
.left-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Right column container */
|
||||
.right-column {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
/* Responsive: Stack on smaller screens */
|
||||
@media (max-width: 1200px) {
|
||||
.container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 1.5rem;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.config-panel {
|
||||
border: 2px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.config-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.config-header:hover {
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.toggle-icon {
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
|
||||
.config-content {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.robot-setup {
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.robot-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
font-weight: 500;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.robot-status.ready {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border: 1px solid #10b981;
|
||||
}
|
||||
|
||||
.robot-status.not-ready {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border: 1px solid #f59e0b;
|
||||
}
|
||||
|
||||
.btn-setup {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-setup:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-setup:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-zero {
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-zero:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
}
|
||||
|
||||
.btn-zero:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.zero-position-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-zero-large {
|
||||
width: 100%;
|
||||
background: #8b5cf6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
|
||||
}
|
||||
|
||||
.btn-zero-large:hover:not(:disabled) {
|
||||
background: #7c3aed;
|
||||
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-zero-large:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-episode-section {
|
||||
margin-top: 1rem;
|
||||
padding-top: 1rem;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.btn-delete {
|
||||
width: 100%;
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.875rem 1.5rem;
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
|
||||
}
|
||||
|
||||
.btn-delete:hover:not(:disabled) {
|
||||
background: #dc2626;
|
||||
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.btn-delete:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.btn-disconnect {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-disconnect:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-refresh {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.4rem 0.8rem;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
|
||||
.btn-refresh:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-refresh:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
border: 2px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
padding: 1rem 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 1.5rem;
|
||||
font-weight: 500;
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
.status-banner.initializing {
|
||||
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
|
||||
color: #1e40af;
|
||||
border-left: 4px solid #3b82f6;
|
||||
}
|
||||
|
||||
.status-banner.encoding {
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
color: #92400e;
|
||||
border-left: 4px solid #f59e0b;
|
||||
}
|
||||
|
||||
.status-banner.uploading {
|
||||
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
|
||||
color: #3730a3;
|
||||
border-left: 4px solid #6366f1;
|
||||
}
|
||||
|
||||
.status-banner.success {
|
||||
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
|
||||
color: #065f46;
|
||||
border-left: 4px solid #10b981;
|
||||
}
|
||||
|
||||
.status-banner.warning {
|
||||
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
|
||||
color: #991b1b;
|
||||
border-left: 4px solid #ef4444;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 3px solid rgba(0, 0, 0, 0.1);
|
||||
border-top-color: currentColor;
|
||||
border-radius: 50%;
|
||||
animation: spin 0.8s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.control-horizontal {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.control-left {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.control-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
input[type="text"]:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #10b981;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.btn-set-task {
|
||||
background: #3b82f6;
|
||||
color: white;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
.btn-set-task:hover:not(:disabled) {
|
||||
background: #2563eb;
|
||||
}
|
||||
|
||||
.btn-set-task:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-start {
|
||||
background: #10b981;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-start:hover:not(:disabled) {
|
||||
background: #059669;
|
||||
}
|
||||
|
||||
.btn-start:disabled {
|
||||
background: #d1d5db;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-stop {
|
||||
background: #ef4444;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-stop:hover {
|
||||
background: #dc2626;
|
||||
}
|
||||
|
||||
.btn-reset {
|
||||
padding: 0.5rem 1rem;
|
||||
background: #6b7280;
|
||||
color: white;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.btn-reset:hover {
|
||||
background: #4b5563;
|
||||
}
|
||||
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.status.recording {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.status.recording.recording-active {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
background: #dc2626;
|
||||
color: white;
|
||||
padding: 1.5rem;
|
||||
border: 4px solid #991b1b;
|
||||
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
|
||||
font-weight: 700;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.status.recording.recording-active .indicator {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: #fef2f2;
|
||||
animation: pulse-strong 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-strong {
|
||||
0%, 100% {
|
||||
opacity: 1;
|
||||
transform: scale(1);
|
||||
}
|
||||
50% {
|
||||
opacity: 0.7;
|
||||
transform: scale(1.1);
|
||||
}
|
||||
}
|
||||
|
||||
.status.recording.recording-active .time-display {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.fps-display {
|
||||
font-size: 1rem;
|
||||
font-weight: 500;
|
||||
opacity: 0.95;
|
||||
}
|
||||
|
||||
.fps-warning {
|
||||
color: #fef2f2;
|
||||
animation: pulse-warning 1s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warning {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.status.recording.recording-active .btn-stop {
|
||||
align-self: stretch;
|
||||
}
|
||||
|
||||
.ramp-up-countdown {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.countdown-box {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 2rem 3rem;
|
||||
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
||||
border: 4px solid #f59e0b;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
min-width: 280px;
|
||||
animation: pulse-warm 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse-warm {
|
||||
0%, 100% {
|
||||
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
|
||||
}
|
||||
50% {
|
||||
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
|
||||
}
|
||||
}
|
||||
|
||||
.countdown-label {
|
||||
font-size: 1rem;
|
||||
color: #92400e;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1.5px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.countdown-value {
|
||||
font-size: 4.5rem;
|
||||
font-weight: 900;
|
||||
color: #d97706;
|
||||
font-family: 'Courier New', monospace;
|
||||
line-height: 1;
|
||||
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.countdown-subtitle {
|
||||
font-size: 0.875rem;
|
||||
color: #78350f;
|
||||
font-weight: 600;
|
||||
font-style: italic;
|
||||
text-align: center;
|
||||
margin-top: 0.5rem;
|
||||
}
|
||||
|
||||
.status.idle {
|
||||
background: #f3f4f6;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: #ef4444;
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.5; }
|
||||
}
|
||||
|
||||
.counter {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 1.5rem;
|
||||
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
|
||||
border-radius: 8px;
|
||||
border: 2px solid #e5e7eb;
|
||||
min-width: 200px;
|
||||
}
|
||||
|
||||
.counter-label {
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.counter-value {
|
||||
font-size: 3rem;
|
||||
font-weight: 700;
|
||||
color: #10b981;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.time-display {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
font-family: 'Courier New', monospace;
|
||||
}
|
||||
|
||||
.error-box {
|
||||
padding: 1rem;
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
border-radius: 4px;
|
||||
border-left: 4px solid #ef4444;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.config-section {
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.config-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.875rem;
|
||||
color: #374151;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
select {
|
||||
padding: 0.5rem;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 0.875rem;
|
||||
background: white;
|
||||
}
|
||||
|
||||
select:disabled {
|
||||
background: #f5f5f5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Camera Layout */
|
||||
.camera-layout {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-base {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera-wrist-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 1.5rem;
|
||||
}
|
||||
|
||||
.camera-wrist {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.camera {
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.camera h3 {
|
||||
padding: 0.75rem;
|
||||
background: #f9fafb;
|
||||
border-bottom: 1px solid #e5e7eb;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.camera img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
display: block;
|
||||
background: #000;
|
||||
min-height: 300px;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.camera-placeholder {
|
||||
text-align: center;
|
||||
padding: 4rem 2rem;
|
||||
background: #f9fafb;
|
||||
border-radius: 4px;
|
||||
border: 2px dashed #d1d5db;
|
||||
}
|
||||
|
||||
.camera-placeholder p {
|
||||
margin: 0.5rem 0;
|
||||
font-size: 1rem;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.camera-placeholder p:first-child {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 500;
|
||||
color: #374151;
|
||||
}
|
||||
|
||||
.hint {
|
||||
margin-top: 0.5rem;
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,857 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import './App.css';
|
||||
|
||||
const API_BASE = 'http://localhost:8000/api';
|
||||
|
||||
function App() {
|
||||
// State
|
||||
const [task, setTask] = useState('');
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isInitializing, setIsInitializing] = useState(false);
|
||||
const [isEncoding, setIsEncoding] = useState(false);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [robotsReady, setRobotsReady] = useState(false);
|
||||
const [elapsedTime, setElapsedTime] = useState(0);
|
||||
const [currentFps, setCurrentFps] = useState(0);
|
||||
const [loopFps, setLoopFps] = useState(0);
|
||||
const [episodeCount, setEpisodeCount] = useState(0);
|
||||
const [error, setError] = useState(null);
|
||||
const [statusMessage, setStatusMessage] = useState('Ready');
|
||||
const [uploadStatus, setUploadStatus] = useState(null);
|
||||
const [rampUpRemaining, setRampUpRemaining] = useState(0);
|
||||
const [movingToZero, setMovingToZero] = useState(false);
|
||||
const [configExpanded, setConfigExpanded] = useState(false);
|
||||
const [latestRepoId, setLatestRepoId] = useState(null);
|
||||
|
||||
// Configuration
|
||||
const [config, setConfig] = useState({
|
||||
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
|
||||
leader_left: 'can0',
|
||||
leader_right: 'can1',
|
||||
follower_left: 'can2',
|
||||
follower_right: 'can3',
|
||||
left_wrist: '/dev/video0',
|
||||
right_wrist: '/dev/video1',
|
||||
base: '/dev/video4'
|
||||
});
|
||||
|
||||
// Available options
|
||||
const [availableCameras, setAvailableCameras] = useState([]);
|
||||
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
|
||||
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
|
||||
|
||||
const statusIntervalRef = useRef(null);
|
||||
const hasInitializedRef = useRef(false);
|
||||
|
||||
const loadConfig = () => {
|
||||
try {
|
||||
const saved = localStorage.getItem('openarms_config');
|
||||
if (saved) {
|
||||
const loadedConfig = JSON.parse(saved);
|
||||
setConfig(prev => ({ ...prev, ...loadedConfig }));
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Load config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const saveConfig = (newConfig) => {
|
||||
try {
|
||||
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
|
||||
} catch (e) {
|
||||
console.error('Save config error:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch status periodically
|
||||
const fetchStatus = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/status`);
|
||||
const data = await response.json();
|
||||
|
||||
setIsRecording(data.is_recording);
|
||||
setIsInitializing(data.is_initializing);
|
||||
setIsEncoding(data.is_encoding);
|
||||
setIsUploading(data.is_uploading);
|
||||
setRobotsReady(data.robots_ready);
|
||||
setElapsedTime(data.elapsed_time);
|
||||
setCurrentFps(data.current_fps || 0);
|
||||
setLoopFps(data.loop_fps || 0);
|
||||
setEpisodeCount(data.episode_count);
|
||||
setError(data.error);
|
||||
setStatusMessage(data.status_message || 'Ready');
|
||||
setUploadStatus(data.upload_status);
|
||||
setRampUpRemaining(data.ramp_up_remaining || 0);
|
||||
setMovingToZero(data.moving_to_zero || false);
|
||||
|
||||
// Track the latest repo_id from the backend
|
||||
if (data.latest_repo_id) {
|
||||
setLatestRepoId(data.latest_repo_id);
|
||||
}
|
||||
|
||||
if (data.config) {
|
||||
// Only merge server config if we don't have a saved config (first load)
|
||||
if (!localStorage.getItem('openarms_config')) {
|
||||
setConfig(prev => {
|
||||
const merged = { ...data.config, ...prev };
|
||||
localStorage.setItem('openarms_config', JSON.stringify(merged));
|
||||
return merged;
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch status:', e);
|
||||
}
|
||||
};
|
||||
|
||||
const setupRobots = async () => {
|
||||
// Show warning to verify camera positions
|
||||
const confirmed = window.confirm(
|
||||
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
|
||||
'📹 Check that cameras are correctly positioned:\n' +
|
||||
' • LEFT wrist camera is actually on the LEFT arm\n' +
|
||||
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
|
||||
' • BASE camera is actually the BASE/overhead camera\n\n' +
|
||||
'Incorrect camera positioning will result in invalid training data!\n\n' +
|
||||
'Click OK to continue with robot setup, or Cancel to review configuration.'
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return; // User cancelled, don't proceed
|
||||
}
|
||||
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/setup`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(config)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to setup robots');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(`Robot setup failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Disconnect robots
|
||||
const disconnectRobots = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
|
||||
setRobotsReady(false);
|
||||
} catch (e) {
|
||||
console.error('Failed to disconnect robots:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover cameras
|
||||
const discoverCameras = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/cameras/discover`);
|
||||
const data = await response.json();
|
||||
const cameras = data.cameras || [];
|
||||
setAvailableCameras(cameras);
|
||||
|
||||
// Get list of valid camera IDs
|
||||
const validCameraIds = cameras.map(cam => String(cam.id));
|
||||
|
||||
// Auto-fix config if current values are invalid or not set
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
// Auto-fix invalid camera config
|
||||
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
|
||||
if (cameras.length >= 1) {
|
||||
updated.left_wrist = String(cameras[0].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
|
||||
if (cameras.length >= 2) {
|
||||
updated.right_wrist = String(cameras[1].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.base || !validCameraIds.includes(config.base)) {
|
||||
if (cameras.length >= 3) {
|
||||
updated.base = String(cameras[2].id);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
|
||||
if (cameras.length === 0) {
|
||||
setError('No cameras detected! Please connect cameras and refresh.');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover cameras:', e);
|
||||
setError(`Camera discovery failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Discover USB ports
|
||||
const discoverUsbPorts = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/usb/discover`);
|
||||
const data = await response.json();
|
||||
const ports = data.ports || [];
|
||||
setAvailableUsbPorts(ports);
|
||||
|
||||
// Auto-fix config if OpenArms Mini is selected and ports are invalid
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
const updated = { ...config };
|
||||
let changed = false;
|
||||
|
||||
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
|
||||
updated.leader_left = ports[0];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
|
||||
updated.leader_right = ports[1];
|
||||
changed = true;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
}
|
||||
}
|
||||
|
||||
if (ports.length === 0) {
|
||||
console.warn('No USB ports detected for OpenArms Mini');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to discover USB ports:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Set task only (for pedal use)
|
||||
const setTaskOnly = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/set-task`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to set task');
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
setStatusMessage(result.message || `Task set: ${task}`);
|
||||
saveConfig(config);
|
||||
|
||||
// Clear success message after 3 seconds
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Start recording
|
||||
const startRecording = async () => {
|
||||
if (!task.trim()) {
|
||||
setError('Please enter a task description');
|
||||
return;
|
||||
}
|
||||
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/start`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ task, ...config })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to start recording');
|
||||
}
|
||||
|
||||
await response.json();
|
||||
saveConfig(config);
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
// Stop recording
|
||||
const stopRecording = async () => {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/stop`, {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to stop recording');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setError(null);
|
||||
// Update latest repo_id after recording
|
||||
if (data.dataset_name) {
|
||||
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
|
||||
}
|
||||
} catch (e) {
|
||||
setError(e.message);
|
||||
}
|
||||
};
|
||||
|
||||
const deleteLatestEpisode = async () => {
|
||||
if (!latestRepoId) {
|
||||
setError('No episode to delete');
|
||||
return;
|
||||
}
|
||||
|
||||
const confirmed = window.confirm(
|
||||
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
|
||||
);
|
||||
|
||||
if (!confirmed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to delete episode');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setLatestRepoId(null);
|
||||
setEpisodeCount(Math.max(0, episodeCount - 1));
|
||||
setStatusMessage(`Deleted: ${data.deleted_repo}`);
|
||||
|
||||
setTimeout(() => {
|
||||
if (!isRecording && !isInitializing) {
|
||||
setStatusMessage('Ready');
|
||||
}
|
||||
}, 3000);
|
||||
} catch (e) {
|
||||
setError(`Delete failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Reset counter
|
||||
const resetCounter = async () => {
|
||||
try {
|
||||
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
|
||||
setEpisodeCount(0);
|
||||
} catch (e) {
|
||||
console.error('Failed to reset counter:', e);
|
||||
}
|
||||
};
|
||||
|
||||
// Move robot to zero position
|
||||
const moveToZero = async () => {
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
|
||||
if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to move to zero position');
|
||||
}
|
||||
await response.json();
|
||||
} catch (e) {
|
||||
setError(`Move to zero failed: ${e.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
// Format time as MM:SS
|
||||
const formatTime = (seconds) => {
|
||||
const mins = Math.floor(seconds / 60);
|
||||
const secs = Math.floor(seconds % 60);
|
||||
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
|
||||
};
|
||||
|
||||
// Update config and save
|
||||
const updateConfig = (key, value) => {
|
||||
const updated = { ...config, [key]: value };
|
||||
setConfig(updated);
|
||||
saveConfig(updated);
|
||||
};
|
||||
|
||||
// Initialize on mount only
|
||||
useEffect(() => {
|
||||
// Prevent double-initialization in development
|
||||
if (hasInitializedRef.current) {
|
||||
return;
|
||||
}
|
||||
hasInitializedRef.current = true;
|
||||
|
||||
loadConfig();
|
||||
discoverCameras();
|
||||
discoverUsbPorts();
|
||||
fetchStatus();
|
||||
statusIntervalRef.current = setInterval(fetchStatus, 1000);
|
||||
|
||||
return () => {
|
||||
if (statusIntervalRef.current) {
|
||||
clearInterval(statusIntervalRef.current);
|
||||
}
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []); // Run only once on mount
|
||||
|
||||
// Discover USB ports when leader type changes to Mini
|
||||
useEffect(() => {
|
||||
if (config.leader_type === 'openarms_mini') {
|
||||
discoverUsbPorts();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [config.leader_type]);
|
||||
|
||||
return (
|
||||
<main>
|
||||
<header>
|
||||
<h1>OpenArms Recording</h1>
|
||||
</header>
|
||||
|
||||
<div className="container">
|
||||
{/* Left Column: Configuration and Recording Control */}
|
||||
<div className="left-column">
|
||||
{/* Configuration Panel */}
|
||||
<section className="panel config-panel">
|
||||
<div
|
||||
className="config-header"
|
||||
onClick={() => setConfigExpanded(!configExpanded)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
|
||||
>
|
||||
<h2>⚙️ Configuration</h2>
|
||||
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
|
||||
</div>
|
||||
|
||||
{configExpanded && (
|
||||
<div className="config-content">
|
||||
{/* Robot Setup */}
|
||||
<div className="config-section">
|
||||
<h3>🤖 Robot Setup</h3>
|
||||
<div className="robot-setup">
|
||||
{robotsReady ? (
|
||||
<div className="robot-status ready">
|
||||
<span>✅ Robots Ready - Recording will start instantly</span>
|
||||
<button onClick={disconnectRobots} className="btn-disconnect">
|
||||
Disconnect Robots
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="robot-status not-ready">
|
||||
<span>⚠️ Robots not initialized - Recording will take ~10 seconds</span>
|
||||
<button
|
||||
onClick={setupRobots}
|
||||
disabled={isRecording || isInitializing}
|
||||
className="btn-setup"
|
||||
>
|
||||
🚀 Setup Robots
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Type Selection */}
|
||||
<div className="config-section">
|
||||
<h3>🎮 Leader Type</h3>
|
||||
<div className="config-grid">
|
||||
<label style={{gridColumn: '1 / -1'}}>
|
||||
Leader Arm Type
|
||||
<select
|
||||
value={config.leader_type}
|
||||
onChange={(e) => updateConfig('leader_type', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
|
||||
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Leader Interfaces (CAN or USB based on type) */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>
|
||||
{config.leader_type === 'openarms_mini'
|
||||
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
|
||||
: 'Leader Interfaces (CAN)'}
|
||||
</h3>
|
||||
{config.leader_type === 'openarms_mini' && (
|
||||
<button
|
||||
onClick={discoverUsbPorts}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Leader Left
|
||||
<select
|
||||
value={config.leader_left}
|
||||
onChange={(e) => updateConfig('leader_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Leader Right
|
||||
<select
|
||||
value={config.leader_right}
|
||||
onChange={(e) => updateConfig('leader_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{config.leader_type === 'openarms_mini' ? (
|
||||
availableUsbPorts.length > 0 ? (
|
||||
availableUsbPorts.map((port) => (
|
||||
<option key={port} value={port}>{port}</option>
|
||||
))
|
||||
) : (
|
||||
<option value="">No USB ports detected</option>
|
||||
)
|
||||
) : (
|
||||
canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))
|
||||
)}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Follower CAN Interfaces */}
|
||||
<div className="config-section">
|
||||
<h3>Follower Interfaces (CAN)</h3>
|
||||
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Follower Left
|
||||
<select
|
||||
value={config.follower_left}
|
||||
onChange={(e) => updateConfig('follower_left', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Follower Right
|
||||
<select
|
||||
value={config.follower_right}
|
||||
onChange={(e) => updateConfig('follower_right', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{canInterfaces.map((iface) => (
|
||||
<option key={iface} value={iface}>{iface}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Camera Configuration */}
|
||||
<div className="config-section">
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
|
||||
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
|
||||
<button
|
||||
onClick={discoverCameras}
|
||||
className="btn-refresh"
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
🔄 Refresh
|
||||
</button>
|
||||
</div>
|
||||
<div className="config-grid">
|
||||
<label>
|
||||
Left Wrist
|
||||
<select
|
||||
value={config.left_wrist}
|
||||
onChange={(e) => updateConfig('left_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Right Wrist
|
||||
<select
|
||||
value={config.right_wrist}
|
||||
onChange={(e) => updateConfig('right_wrist', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<label>
|
||||
Base Camera
|
||||
<select
|
||||
value={config.base}
|
||||
onChange={(e) => updateConfig('base', e.target.value)}
|
||||
disabled={isRecording || robotsReady}
|
||||
>
|
||||
{availableCameras.map((cam) => (
|
||||
<option key={cam.id} value={String(cam.id)}>
|
||||
{cam.name || `Camera @ ${cam.id}`}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
|
||||
{/* Control Panel */}
|
||||
<section className="panel control-panel">
|
||||
<h2>🎬 Recording Control</h2>
|
||||
|
||||
{/* Status Banner - Always show important statuses */}
|
||||
{isInitializing && (
|
||||
<div className="status-banner initializing">
|
||||
<div className="spinner"></div>
|
||||
<span>{statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isEncoding && (
|
||||
<div className="status-banner encoding">
|
||||
<div className="spinner"></div>
|
||||
<span>📹 {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isUploading && (
|
||||
<div className="status-banner uploading">
|
||||
<div className="spinner"></div>
|
||||
<span>☁️ {statusMessage}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
|
||||
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
|
||||
<span>{uploadStatus}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="control-horizontal">
|
||||
{/* Task Input and Status */}
|
||||
<div className="control-left">
|
||||
<div className="input-group">
|
||||
<input
|
||||
type="text"
|
||||
value={task}
|
||||
onChange={(e) => setTask(e.target.value)}
|
||||
placeholder="Task description (e.g., 'pick and place')"
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading}
|
||||
onKeyPress={(e) => {
|
||||
if (e.key === 'Enter' && robotsReady) {
|
||||
setTaskOnly();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={setTaskOnly}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-set-task"
|
||||
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
|
||||
>
|
||||
💾 Set Task
|
||||
</button>
|
||||
<button
|
||||
onClick={startRecording}
|
||||
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
|
||||
className="btn-start"
|
||||
title={!robotsReady ? 'Please setup robots first' : ''}
|
||||
>
|
||||
{isInitializing
|
||||
? '⏳ Initializing...'
|
||||
: isRecording
|
||||
? '⏺ Recording...'
|
||||
: robotsReady
|
||||
? '⏺ Start Recording'
|
||||
: '⏺ Setup Robots First'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Ramp-up Countdown */}
|
||||
{isRecording && rampUpRemaining > 0 && (
|
||||
<div className="ramp-up-countdown">
|
||||
<div className="countdown-box">
|
||||
<div className="countdown-label">⚡ WARMING UP - PID RAMP-UP</div>
|
||||
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
|
||||
<div className="countdown-subtitle">Recording will start automatically...</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recording Status - Only show after ramp-up */}
|
||||
{isRecording && rampUpRemaining <= 0 && (
|
||||
<div className="status recording recording-active">
|
||||
<div className="indicator"></div>
|
||||
<div className="time-display">
|
||||
<span>{formatTime(elapsedTime)}</span>
|
||||
<span className="fps-display">
|
||||
Loop: {loopFps.toFixed(1)} Hz
|
||||
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> ⚠️</span>}
|
||||
</span>
|
||||
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
|
||||
</div>
|
||||
<button onClick={stopRecording} className="btn-stop">
|
||||
⏹ Stop
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Episode Counter */}
|
||||
<div className="control-right">
|
||||
<div className="counter">
|
||||
<div className="counter-label">Episodes Recorded</div>
|
||||
<div className="counter-value">{episodeCount}</div>
|
||||
<button onClick={resetCounter} className="btn-reset">
|
||||
Reset
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Delete Latest Episode Button */}
|
||||
{!isRecording && !isInitializing && latestRepoId && (
|
||||
<div className="delete-episode-section">
|
||||
<button
|
||||
onClick={deleteLatestEpisode}
|
||||
className="btn-delete"
|
||||
title="Delete the latest recorded episode from HuggingFace Hub"
|
||||
>
|
||||
Delete Latest Episode
|
||||
</button>
|
||||
<div className="delete-info">Will delete: {latestRepoId}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Move to Zero Button */}
|
||||
{robotsReady && !isRecording && !isInitializing && (
|
||||
<div className="zero-position-section">
|
||||
<button
|
||||
onClick={moveToZero}
|
||||
disabled={movingToZero}
|
||||
className="btn-zero-large"
|
||||
title="Move both leader and follower robots to zero position (2s)"
|
||||
>
|
||||
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error Display */}
|
||||
{error && (
|
||||
<div className="error-box">
|
||||
⚠️ {error}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
{/* Right Column: Camera Feeds */}
|
||||
<div className="right-column">
|
||||
<section className="panel cameras">
|
||||
<h2>📹 Camera Views</h2>
|
||||
{robotsReady || isRecording || isInitializing ? (
|
||||
<div className="camera-layout">
|
||||
{/* Base camera - full width */}
|
||||
<div className="camera camera-base">
|
||||
<h3>Base Camera</h3>
|
||||
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
|
||||
</div>
|
||||
|
||||
{/* Wrist cameras - side by side */}
|
||||
<div className="camera-wrist-container">
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Left Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
|
||||
</div>
|
||||
|
||||
<div className="camera camera-wrist">
|
||||
<h3>Right Wrist</h3>
|
||||
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="camera-placeholder">
|
||||
<p>📷 Camera feeds will appear when robots are set up</p>
|
||||
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# OpenArms Web Recording Interface
|
||||
|
||||
A web interface for recording OpenArms datasets.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd examples/openarms_web_interface
|
||||
npm install
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
**Start everything with one command:**
|
||||
|
||||
```bash
|
||||
./launch.sh
|
||||
```
|
||||
|
||||
This will:
|
||||
- Start the FastAPI backend on port 8000
|
||||
- Start the React frontend on port 5173
|
||||
- Show live logs from both services
|
||||
|
||||
Then open your browser to: **http://localhost:5173**
|
||||
|
||||
**Stop with:** `Ctrl+C`
|
||||
|
||||
---
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
|
||||
2. Click **"Setup Robots"** to initialize (once at start)
|
||||
3. Enter a **task description**
|
||||
4. Click **"Start Recording"** to begin an episode
|
||||
5. Click **"Stop Recording"** when done
|
||||
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
|
||||
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
|
||||
|
||||
---
|
||||
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>OpenArms Recording Interface</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenArms Web Interface Launcher
|
||||
# Starts Rerun viewer, FastAPI backend, and React frontend
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
BLUE='\033[0;34m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Get script directory
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
|
||||
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
# Function to cleanup on exit
|
||||
cleanup() {
|
||||
echo ""
|
||||
echo -e "${YELLOW}Shutting down services...${NC}"
|
||||
|
||||
# Kill all child processes
|
||||
pkill -P $$ 2>/dev/null || true
|
||||
|
||||
# Kill specific services by port
|
||||
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
|
||||
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
|
||||
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
|
||||
|
||||
echo -e "${GREEN}✓ Services stopped${NC}"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Register cleanup on script exit
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Check if required commands exist
|
||||
command -v rerun >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v python >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'python' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
command -v npm >/dev/null 2>&1 || {
|
||||
echo -e "${RED}✗ Error: 'npm' not found${NC}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Check if node_modules exists
|
||||
if [ ! -d "node_modules" ]; then
|
||||
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
|
||||
npm install
|
||||
echo -e "${GREEN}✓ Dependencies installed${NC}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Starting services...${NC}"
|
||||
echo ""
|
||||
|
||||
# 1. Start FastAPI backend (Rerun will start when recording begins)
|
||||
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use Python from current environment (if lerobot env is active, it will use that)
|
||||
# Otherwise, check if we need to use conda run
|
||||
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
|
||||
# Already in lerobot environment
|
||||
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
|
||||
PYTHON_CMD="python"
|
||||
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
|
||||
# lerobot env exists but not active - use conda run
|
||||
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
|
||||
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
|
||||
else
|
||||
# Fall back to system python
|
||||
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
|
||||
BACKEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $BACKEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start backend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# 2. Start React frontend
|
||||
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
|
||||
cd "$SCRIPT_DIR"
|
||||
npm run dev > /tmp/openarms_frontend.log 2>&1 &
|
||||
FRONTEND_PID=$!
|
||||
sleep 3
|
||||
|
||||
if ps -p $FRONTEND_PID > /dev/null; then
|
||||
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
|
||||
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Failed to start frontend${NC}"
|
||||
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Display status
|
||||
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
|
||||
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
|
||||
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Logs:${NC}"
|
||||
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
|
||||
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
|
||||
echo ""
|
||||
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
|
||||
echo ""
|
||||
|
||||
# Keep script running and wait for any service to exit
|
||||
wait
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import App from './App.jsx'
|
||||
|
||||
createRoot(document.getElementById('root')).render(
|
||||
<App />
|
||||
)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "openarms-web-interface",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"vite": "^6.0.1"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
port: 5173,
|
||||
strictPort: false,
|
||||
host: true,
|
||||
open: false
|
||||
},
|
||||
build: {
|
||||
outDir: 'dist',
|
||||
sourcemap: true
|
||||
}
|
||||
})
|
||||
@@ -0,0 +1,638 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection with Policy Rollout + Human Intervention.
|
||||
|
||||
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
|
||||
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot.
|
||||
|
||||
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
|
||||
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Key RaC Rules:
|
||||
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
|
||||
- Rule 2 (Terminate after Intervention): Episode ends after correction
|
||||
|
||||
The recovery segment (teleoperating back to good state) is recorded as training data -
|
||||
this teaches the policy how to recover from errors.
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessor,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCConfig:
|
||||
robot: RobotConfig
|
||||
dataset: RaCDatasetConfig
|
||||
policy: PreTrainedConfig
|
||||
teleop: TeleoperatorConfig
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
|
||||
"correction_active": False, # 'c' pressed - human controlling, recording correction
|
||||
"in_reset": False, # True during reset period
|
||||
"start_next_episode": False, # Signal to start next episode
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
# During reset: any action key starts next episode
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A" # Left pedal
|
||||
KEY_RIGHT = "KEY_C" # Right pedal
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
print(f"[Pedal] Right=pause/next, Left=take control/start")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
# Only trigger on key down
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
# During reset: either pedal starts next episode
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
print("\n[Pedal] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if code == KEY_RIGHT:
|
||||
# Right pedal: SPACE (pause) when running, → (next) when in correction
|
||||
if events["correction_active"]:
|
||||
print("\n[Pedal] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
|
||||
print(" Press left pedal to take control")
|
||||
events["policy_paused"] = True
|
||||
|
||||
elif code == KEY_LEFT:
|
||||
# Left pedal: START (take control) when paused
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[Pedal] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move all robot joints to zero position."""
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
target_pos = {k: 0.0 for k in current_pos}
|
||||
|
||||
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
|
||||
steps = int(duration_s * fps)
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
|
||||
robot.send_action(interp_pos)
|
||||
time.sleep(1 / fps)
|
||||
print("[RaC] Robot at zero position.")
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
fps: int,
|
||||
control_time_s: float,
|
||||
single_task: str,
|
||||
display_data: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
RaC rollout loop with two-stage intervention:
|
||||
|
||||
1. Policy runs autonomously (recording)
|
||||
2. SPACE: Policy pauses (NOT recording) - robot holds position
|
||||
3. 'c': Human takes control (recording correction)
|
||||
4. →: End episode
|
||||
"""
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
device = get_safe_torch_device(policy.config.device)
|
||||
frame_buffer = []
|
||||
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
last_robot_action = None
|
||||
was_paused = False
|
||||
was_correction_active = False
|
||||
waiting_for_takeover = False
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Detect transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press START to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
|
||||
# Wait for start button before enabling correction mode
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Start pressed - enabling teleop control...")
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
was_correction_active = True
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling - record correction data
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
# Waiting for START - policy stopped, no recording, robot holds position
|
||||
if last_robot_action is not None:
|
||||
robot.send_action(last_robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
# Paused and user acknowledged - hold last position, don't record
|
||||
if last_robot_action is not None:
|
||||
robot.send_action(last_robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
robot_action = last_robot_action
|
||||
|
||||
else:
|
||||
# Normal policy execution - record
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot.send_action(robot_action)
|
||||
last_robot_action = robot_action
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if display_data and robot_action is not None:
|
||||
log_rerun_data(observation=obs, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
):
|
||||
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET - Moving teleop to robot position...")
|
||||
print("=" * 65)
|
||||
|
||||
# Enter reset mode
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
# Move teleop to match robot position to avoid sudden jumps
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
# Stage 1: Wait for user to press start to enable teleoperation
|
||||
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
# Stage 2: Enable teleop and let user move robot to starting position
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - move robot to starting position")
|
||||
print(" Press any key/pedal to start next episode")
|
||||
|
||||
# Wait for user to signal ready for next episode
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
# Exit reset mode and clear flags for next episode
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_collection")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC (Recovery and Correction) Data Collection")
|
||||
print("=" * 65)
|
||||
print(" Policy runs autonomously until you intervene.")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy (robot holds position, no recording)")
|
||||
print(" c - Take control (start correction, recording)")
|
||||
print(" → - End episode (save)")
|
||||
print(" ← - Re-record episode")
|
||||
print(" ESC - Stop session and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
|
||||
|
||||
stats = rac_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
# Reset between episodes
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
rac_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection for OpenArms Robot.
|
||||
|
||||
This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks
|
||||
by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot with OpenArms.
|
||||
|
||||
RaC improves upon standard data collection (BC) and prior human-in-the-loop methods
|
||||
(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors:
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously (teleop is idle/free)
|
||||
2. Press SPACE to pause - teleop moves to match robot position
|
||||
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Key RaC Rules:
|
||||
- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human)
|
||||
- Rule 2 (Terminate after Intervention): Episode ends after correction
|
||||
|
||||
The recovery segment (teleoperating back to good state) is recorded as training data -
|
||||
this teaches the policy how to recover from errors.
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (teleop mirrors robot, no recording)
|
||||
c - Take control (teleop free, recording correction)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection_openarms.py \
|
||||
--robot.type=openarms_follower \
|
||||
--robot.port_right=can0 \
|
||||
--robot.port_left=can1 \
|
||||
--robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=openarms_mini \
|
||||
--teleop.port_right=/dev/ttyUSB0 \
|
||||
--teleop.port_left=/dev/ttyUSB1 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_openarms_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCConfig:
|
||||
robot: RobotConfig
|
||||
dataset: RaCDatasetConfig
|
||||
teleop: TeleoperatorConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot
|
||||
"correction_active": False, # 'c' pressed - human controlling, recording correction
|
||||
"in_reset": False, # True during reset period
|
||||
"start_next_episode": False, # Signal to start next episode
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
# During reset: any action key starts next episode
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A" # Left pedal
|
||||
KEY_RIGHT = "KEY_C" # Right pedal
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
print(f"[Pedal] Right=pause/next, Left=take control/start")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
# Only trigger on key down
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
# During reset: either pedal starts next episode
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
print("\n[Pedal] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
# During episode
|
||||
if code == KEY_RIGHT:
|
||||
# Right pedal: SPACE (pause) when running, → (next) when in correction
|
||||
if events["correction_active"]:
|
||||
print("\n[Pedal] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot")
|
||||
print(" Press left pedal to take control")
|
||||
events["policy_paused"] = True
|
||||
|
||||
elif code == KEY_LEFT:
|
||||
# Left pedal: START (take control) when paused
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[Pedal] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move all robot joints to zero position."""
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
target_pos = {k: 0.0 for k in current_pos}
|
||||
|
||||
print(f"[RaC] Moving robot to zero position ({duration_s}s)...")
|
||||
steps = int(duration_s * fps)
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
|
||||
robot.send_action(interp_pos)
|
||||
time.sleep(1 / fps)
|
||||
print("[RaC] Robot at zero position.")
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
fps: int,
|
||||
control_time_s: float,
|
||||
single_task: str,
|
||||
display_data: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
RaC rollout loop with two-stage intervention:
|
||||
|
||||
1. Policy runs autonomously (recording) - teleop free/idle
|
||||
2. SPACE: Policy pauses, teleop mirrors robot position (NOT recording)
|
||||
3. 'c': Human takes control, teleop torque disabled (recording correction)
|
||||
4. →: End episode
|
||||
|
||||
This allows smooth handoff - teleop tracks robot only when paused.
|
||||
"""
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
device = get_safe_torch_device(policy.config.device)
|
||||
frame_buffer = []
|
||||
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
# Start with teleop torque disabled - only enable when paused to track robot
|
||||
teleop.disable_torque()
|
||||
was_paused = False
|
||||
was_correction_active = False
|
||||
waiting_for_takeover = False
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Detect transition to paused state - smooth move teleop to robot position
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press START to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
|
||||
# Wait for start button before enabling correction mode
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Start pressed - enabling teleop control...")
|
||||
teleop.disable_torque()
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
was_correction_active = True
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling - record correction data
|
||||
robot_action = teleop.get_action()
|
||||
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
# Waiting for START - policy stopped, no recording, robot holds position
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
# Paused and user acknowledged - teleop tracks robot position, don't record
|
||||
robot_action = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
teleop.send_feedback(robot_action)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
else:
|
||||
# Normal policy execution - record (teleop is free/idle)
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
# Ensure teleoperator torque is disabled at end
|
||||
teleop.disable_torque()
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
):
|
||||
"""Reset period where human repositions environment. Two-stage: enable teleop, then start episode."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET - Moving teleop to robot position...")
|
||||
print("=" * 65)
|
||||
|
||||
# Enter reset mode
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
# First move teleop to match robot position to avoid sudden jumps
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
# Stage 1: Wait for user to press start to enable teleoperation
|
||||
print(" Teleop aligned. Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
# Stage 2: Enable teleop and let user move robot to starting position
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - move robot to starting position")
|
||||
print(" Press any key/pedal to start next episode")
|
||||
|
||||
# Wait for user to signal ready for next episode
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
action = teleop.get_action()
|
||||
# Convert gripper from teleop range (0-100) to robot degrees (-65 to 0)
|
||||
for key in action:
|
||||
if "gripper" in key:
|
||||
action[key] = -0.65 * action[key]
|
||||
robot.send_action(action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
# Exit reset mode and clear flags for next episode
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def rac_collect(cfg: RaCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_collection_openarms")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC (Recovery and Correction) Data Collection - OpenArms")
|
||||
print("=" * 65)
|
||||
print(" Policy runs autonomously until you intervene.")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy (teleop tracks robot, no recording)")
|
||||
print(" c - Take control (start correction, recording)")
|
||||
print(" → - End episode (save)")
|
||||
print(" ← - Re-record episode")
|
||||
print(" ESC - Stop session and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps)
|
||||
|
||||
stats = rac_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
# Reset between episodes
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
rac_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,877 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC.
|
||||
|
||||
This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution.
|
||||
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion
|
||||
despite high inference latency by asynchronously generating action chunks.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously with RTC (teleop is idle/free)
|
||||
2. Press SPACE to pause - teleop moves to match robot position
|
||||
3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection_openarms_rtc.py \
|
||||
--robot.port_right=can0 \
|
||||
--robot.port_left=can1 \
|
||||
--teleop.port_right=/dev/ttyUSB0 \
|
||||
--teleop.port_left=/dev/ttyUSB1 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_openarms_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class RaCRTCDatasetConfig:
|
||||
repo_id: str = "lerobot/rac_openarms_rtc"
|
||||
single_task: str = "default task"
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 500
|
||||
reset_time_s: float = 30
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RaCRTCConfig:
|
||||
robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig(
|
||||
port_left="can0",
|
||||
port_right="can1",
|
||||
))
|
||||
teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig(
|
||||
port_left="/dev/ttyUSB1",
|
||||
port_right="/dev/ttyUSB0",
|
||||
))
|
||||
dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig)
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
))
|
||||
|
||||
interpolation: bool = True
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py)
|
||||
# ============================================================================
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self.robot.robot_type
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Keyboard/Pedal Listeners
|
||||
# ============================================================================
|
||||
|
||||
def init_rac_keyboard_listener():
|
||||
"""Initialize keyboard listener with RaC-specific controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logging.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
print("\n[RaC] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position")
|
||||
print(" Press 'c' or START to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, 'char') and key.char == 'c':
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ START pressed - taking control")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[RaC] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
start_pedal_listener(events)
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener thread if evdev is available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes # noqa: F401
|
||||
except ImportError:
|
||||
logging.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A"
|
||||
KEY_RIGHT = "KEY_C"
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize # noqa: F401
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if code == KEY_RIGHT:
|
||||
if events["correction_active"]:
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
events["policy_paused"] = True
|
||||
elif code == KEY_LEFT:
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
events["start_next_episode"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
|
||||
except PermissionError:
|
||||
logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}")
|
||||
except Exception as e:
|
||||
logging.debug(f"[Pedal] Error: {e}")
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for RaC recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RTC Inference Thread (from evaluate_with_rtc.py)
|
||||
# ============================================================================
|
||||
|
||||
def rtc_inference_thread(
|
||||
policy,
|
||||
obs_holder: dict,
|
||||
hw_features: dict,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder: dict,
|
||||
shutdown_event: Event,
|
||||
policy_active: Event,
|
||||
cfg: RaCRTCConfig,
|
||||
):
|
||||
"""Background thread that generates action chunks using RTC."""
|
||||
try:
|
||||
logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========")
|
||||
logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
policy_device = policy.config.device
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
wait_logged = False
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
if not wait_logged:
|
||||
logger.info("[RTC] Waiting for policy_active...")
|
||||
wait_logged = True
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
wait_logged = False
|
||||
|
||||
action_queue = queue_holder["queue"]
|
||||
if action_queue is None:
|
||||
logger.warning("[RTC] queue_holder['queue'] is None!")
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
obs_filtered = obs_holder.get("obs")
|
||||
if obs_filtered is None:
|
||||
logger.warning("[RTC] obs_holder['obs'] is None!")
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
qsize = action_queue.qsize()
|
||||
if qsize <= get_actions_threshold:
|
||||
try:
|
||||
if inference_count == 0:
|
||||
logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}")
|
||||
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.dataset.single_task]
|
||||
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
|
||||
|
||||
inference_count += 1
|
||||
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] Inference error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[RTC] Inference thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] THREAD CRASHED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Rollout Loop
|
||||
# ============================================================================
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rtc_rollout_loop(
|
||||
robot: RobotWrapper,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: RaCRTCConfig,
|
||||
queue_holder: dict,
|
||||
obs_holder: dict, # Main loop writes obs here for RTC thread to read
|
||||
policy_active: Event,
|
||||
hw_features: dict,
|
||||
) -> dict:
|
||||
"""RaC rollout loop with RTC for smooth policy execution."""
|
||||
fps = cfg.dataset.fps
|
||||
single_task = cfg.dataset.single_task
|
||||
control_time_s = cfg.dataset.episode_time_s
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
|
||||
# Reset policy state
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
"paused_frames": 0,
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
teleop.disable_torque()
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Action keys for converting tensor to dict
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
# Interpolation state
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
|
||||
if cfg.interpolation:
|
||||
control_interval = 1.0 / (fps * 2) # 2x rate
|
||||
else:
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
robot_action = {}
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# State transition: entering paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear() # Stop RTC inference
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
print("[RaC] Teleop aligned. Press 'c' to take control.")
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
# Reset interpolation
|
||||
prev_action = None
|
||||
interpolated_actions = []
|
||||
interp_idx = 0
|
||||
|
||||
# Wait for takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Taking control...")
|
||||
teleop.disable_torque()
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Get observation (ONLY the main loop reads from robot!)
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
# Share observation with RTC thread (thread reads, main loop writes)
|
||||
obs_holder["obs"] = obs_filtered
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling
|
||||
robot_action = teleop.get_action()
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
teleop.send_feedback(robot_pos)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
if not policy_active.is_set():
|
||||
policy_active.set()
|
||||
logger.info("[ROLLOUT] Policy activated, waiting for first actions...")
|
||||
|
||||
action_queue = queue_holder["queue"]
|
||||
|
||||
# Get action from queue (with interpolation)
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get() if action_queue else None
|
||||
|
||||
# Log queue status periodically
|
||||
if stats["autonomous_frames"] == 0 and new_action is None:
|
||||
qsize = action_queue.qsize() if action_queue else -1
|
||||
if timestamp < 0.5 or int(timestamp * 10) % 10 == 0:
|
||||
logger.info(f"[ROLLOUT] Waiting for actions... queue_size={qsize}, obs_set={obs_holder.get('obs') is not None}")
|
||||
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if stats["autonomous_frames"] == 0:
|
||||
logger.info(f"[ROLLOUT] Got first action! Starting robot motion.")
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
robot_action = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
robot_action[key] = action_to_send[i].item()
|
||||
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record at original fps
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
if cfg.display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_time = control_interval - dt
|
||||
if sleep_time > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
teleop.disable_torque()
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET")
|
||||
print("=" * 65)
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
print(" Press any key/pedal to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print(" Teleop enabled - press any key/pedal to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
for key in action:
|
||||
if "gripper" in key:
|
||||
action[key] = -0.65 * action[key]
|
||||
robot.send_action(action)
|
||||
dt = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt)
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Entry Point
|
||||
# ============================================================================
|
||||
|
||||
@parser.wrap()
|
||||
def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
"""Main RaC data collection function with RTC."""
|
||||
init_logging()
|
||||
logging.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="rac_rtc_collection_openarms")
|
||||
|
||||
robot_raw = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot_raw.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot_raw.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot_raw.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
# Setup preprocessor/postprocessor
|
||||
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
# Connect robot and wrap for thread safety
|
||||
robot_raw.connect()
|
||||
robot = RobotWrapper(robot_raw)
|
||||
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
# Shared state holders (main loop writes, RTC thread reads)
|
||||
queue_holder = {"queue": ActionQueue(cfg.rtc)}
|
||||
obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs
|
||||
|
||||
# Start RTC inference thread
|
||||
# NOTE: Thread does NOT access robot directly - reads from obs_holder
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
obs_holder, # Thread reads obs from here (set by main loop)
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
name="RTCInference",
|
||||
)
|
||||
rtc_thread.start()
|
||||
logger.info("Started RTC inference thread")
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC Data Collection with RTC")
|
||||
print("=" * 65)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" FPS: {cfg.dataset.fps}")
|
||||
print(f" Interpolation: {cfg.interpolation}")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy")
|
||||
print(" c - Take control")
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
# Fresh action queue per episode (update holder so thread sees it)
|
||||
queue_holder["queue"] = ActionQueue(cfg.rtc)
|
||||
|
||||
logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}")
|
||||
|
||||
stats = rac_rtc_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
queue_holder=queue_holder,
|
||||
obs_holder=obs_holder,
|
||||
policy_active=policy_active,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot_raw.is_connected:
|
||||
robot_raw.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
register_third_party_plugins()
|
||||
rac_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,10 @@
|
||||
from huggingface_hub import HfApi, list_datasets
|
||||
|
||||
api = HfApi()
|
||||
datasets = list_datasets(author="lerobot-data-collection")
|
||||
print('"[', end="")
|
||||
i=0
|
||||
for dataset in datasets:
|
||||
if "three-folds-dataset" in dataset.id:
|
||||
print("'" + dataset.id + "',", end="")
|
||||
print(']"',)
|
||||
|
Before Width: | Height: | Size: 2.9 MiB |
|
Before Width: | Height: | Size: 185 KiB |
|
Before Width: | Height: | Size: 464 KiB |
|
Before Width: | Height: | Size: 72 KiB |
|
Before Width: | Height: | Size: 219 KiB |
|
Before Width: | Height: | Size: 199 KiB |
|
After Width: | Height: | Size: 774 KiB |
|
Before Width: | Height: | Size: 160 KiB After Width: | Height: | Size: 160 KiB |
|
After Width: | Height: | Size: 2.3 MiB |
|
After Width: | Height: | Size: 481 KiB |
|
Before Width: | Height: | Size: 117 KiB |
|
Before Width: | Height: | Size: 151 KiB |
|
Before Width: | Height: | Size: 130 KiB |
|
Before Width: | Height: | Size: 407 KiB |
@@ -96,7 +96,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||
|
||||
# Motors
|
||||
@@ -120,6 +120,13 @@ intelrealsense = [
|
||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
||||
|
||||
# Policies
|
||||
wallx = [
|
||||
"transformers==4.49.0",
|
||||
"peft==0.17.1",
|
||||
"scipy==1.15.3",
|
||||
"torchdiffeq==0.2.5",
|
||||
"qwen_vl_utils==0.0.11"
|
||||
]
|
||||
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
groot = [
|
||||
@@ -133,13 +140,15 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
@@ -158,9 +167,11 @@ all = [
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi]",
|
||||
# "lerobot[wallx]",
|
||||
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -171,6 +182,7 @@ all = [
|
||||
"lerobot[phone]",
|
||||
"lerobot[libero]",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -225,6 +237,7 @@ ignore = [
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["F401", "F403"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
@@ -261,6 +274,7 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -315,9 +329,9 @@ disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.optim.*"
|
||||
ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.model.*"
|
||||
@@ -360,10 +374,84 @@ ignore_errors = false
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[tool.uv]
|
||||
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "pi" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "wallx" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
# pi uses custom branch which conflicts with transformers-dep
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "transformers-dep" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "smolvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "groot" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "xvla" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "sarm" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "hilserl" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "libero" },
|
||||
],
|
||||
[
|
||||
{ extra = "pi" },
|
||||
{ extra = "all" },
|
||||
],
|
||||
]
|
||||
|
||||
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"]
|
||||
|
||||
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
bi_so100_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
tolerance_s: float = 1e-4
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: int = 20_000
|
||||
@@ -64,9 +65,17 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
@@ -130,6 +139,14 @@ class TrainPipelineConfig(HubMixin):
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -136,21 +136,40 @@ def update_meta_data(
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
# Get mappings for this video key
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
src_to_dst = video_idx.get("src_to_dst", {})
|
||||
|
||||
# Apply per-source-file mappings
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
df.at[idx, orig_chunk_col] = dst_chunk
|
||||
df.at[idx, orig_file_col] = dst_file
|
||||
|
||||
# Apply timestamp offset
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
elif src_to_offset:
|
||||
# Fallback: use same destination for all, but apply per-file offsets
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
@@ -268,6 +287,12 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
# Track destination (chunk, file) for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_dst"] = {}
|
||||
# Initialize dst_file_durations if not present
|
||||
# dst_file_durations tracks duration of each destination file
|
||||
if "dst_file_durations" not in videos_idx[key]:
|
||||
videos_idx[key]["dst_file_durations"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
@@ -282,9 +307,13 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -298,14 +327,17 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
# New destination file: offset is 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Track duration of this destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
@@ -313,10 +345,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
@@ -324,16 +357,20 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
# Track duration of this new destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
else:
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
# Append to existing destination file
|
||||
# Offset is the current duration of this destination file
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
current_offset += src_duration
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
|
||||
@@ -98,6 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
@@ -108,6 +109,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
@@ -23,11 +23,13 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import os
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
@@ -1199,40 +1201,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
# num_cameras * num_threads = (total_cpu -1)
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||
future_to_key = {
|
||||
executor.submit(
|
||||
_encode_video_worker,
|
||||
video_key,
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
|
||||
results = {}
|
||||
for future in concurrent.futures.as_completed(future_to_key):
|
||||
video_key = future_to_key[future]
|
||||
try:
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path = results[video_key]
|
||||
ep_metadata.update(
|
||||
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
|
||||
)
|
||||
else:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
|
||||
for video_key, video_path in zip(self.meta.video_keys, video_paths):
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
@@ -1528,6 +1499,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
|
||||
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
|
||||
temp_paths = []
|
||||
img_dirs = []
|
||||
for video_key in video_keys:
|
||||
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
|
||||
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
|
||||
fps = [self.fps]*len(video_keys)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
|
||||
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
|
||||
|
||||
for img_dir in img_dirs:
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
return temp_paths
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
|
||||
@@ -310,7 +310,7 @@ def encode_video_frames(
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
overwrite: bool = True,
|
||||
preset: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
@@ -355,6 +355,9 @@ def encode_video_frames(
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
|
||||
#TEMPORARY FIX
|
||||
video_options["preset"] = "12"
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
|
||||
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
class LiberoEnv(EnvConfig):
|
||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||
fps: int = 30
|
||||
episode_length: int = 520
|
||||
episode_length: int | None = None
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
)
|
||||
control_mode: str = "relative" # or "absolute"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
|
||||
@@ -19,8 +19,10 @@ from typing import Any
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
policy_cfg: PreTrainedConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
|
||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||
preprocessor_steps: list[ProcessorStep] = []
|
||||
postprocessor_steps: list[ProcessorStep] = []
|
||||
if isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
@@ -136,6 +143,8 @@ def make_env(
|
||||
init_states=cfg.init_states,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
control_mode=cfg.control_mode,
|
||||
episode_length=cfg.episode_length,
|
||||
)
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
|
||||
return [0, 0, 0, 0, 0, 0, -1]
|
||||
|
||||
|
||||
OBS_STATE_DIM = 8
|
||||
ACTION_DIM = 7
|
||||
AGENT_POS_LOW = -1000.0
|
||||
AGENT_POS_HIGH = 1000.0
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
|
||||
task_suite: Any,
|
||||
task_id: int,
|
||||
task_suite_name: str,
|
||||
episode_length: int | None = None,
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
obs_type: str = "pixels",
|
||||
render_mode: str = "rgb_array",
|
||||
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
|
||||
episode_index: int = 0,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
|
||||
self.camera_name_mapping = camera_name_mapping
|
||||
self.num_steps_wait = num_steps_wait
|
||||
self.episode_index = episode_index
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
self._max_episode_steps = (
|
||||
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
if self.episode_length is None
|
||||
else self.episode_length
|
||||
)
|
||||
self.control_mode = control_mode
|
||||
images = {}
|
||||
for cam in self.camera_name:
|
||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
|
||||
# Increasing this value can improve determinism and reproducibility across resets.
|
||||
for _ in range(self.num_steps_wait):
|
||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||
|
||||
if self.control_mode == "absolute":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = False
|
||||
elif self.control_mode == "relative":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = True
|
||||
else:
|
||||
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
@@ -341,8 +354,10 @@ def _make_env_fns(
|
||||
task_id: int,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
episode_length: int | None,
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -354,7 +369,9 @@ def _make_env_fns(
|
||||
task_suite_name=suite_name,
|
||||
camera_name=camera_names,
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -374,6 +391,8 @@ def create_libero_envs(
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
control_mode: str = "relative",
|
||||
episode_length: int | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -415,12 +434,14 @@ def create_libero_envs(
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
suite=suite,
|
||||
episode_length=episode_length,
|
||||
suite_name=suite_name,
|
||||
task_id=tid,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
@@ -14,4 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
from .motors_bus import (
|
||||
Motor,
|
||||
MotorCalibration,
|
||||
MotorNormMode,
|
||||
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
|
||||
MotorsBusBase,
|
||||
SerialMotorsBus,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .damiao import DamiaoMotorsBus
|
||||
from .tables import *
|
||||
@@ -0,0 +1,905 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(pepijn): add license of: https://github.com/cmjang/DM_Control_Python?tab=MIT-1-ov-file#readme
|
||||
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import can
|
||||
import numpy as np
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode, MotorsBusBase
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
CAN_CMD_DISABLE,
|
||||
CAN_CMD_ENABLE,
|
||||
CAN_CMD_REFRESH,
|
||||
CAN_CMD_SET_ZERO,
|
||||
CAN_PARAM_ID,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
MotorType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NameOrID = Union[str, int]
|
||||
Value = Union[int, float]
|
||||
|
||||
|
||||
class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
The Damiao implementation for a MotorsBus using CAN bus communication.
|
||||
|
||||
This class uses python-can for CAN bus communication with Damiao motors.
|
||||
For more info, see:
|
||||
- python-can documentation: https://python-can.readthedocs.io/en/stable/
|
||||
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
|
||||
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
|
||||
"""
|
||||
|
||||
# CAN-specific settings
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_baudrate = DEFAULT_BAUDRATE
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
|
||||
# Motor configuration
|
||||
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
can_interface: str = "auto",
|
||||
use_can_fd: bool = True,
|
||||
bitrate: int = 1000000,
|
||||
data_bitrate: int | None = 5000000,
|
||||
):
|
||||
"""
|
||||
Initialize the Damiao motors bus.
|
||||
|
||||
Args:
|
||||
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
|
||||
motors: Dictionary mapping motor names to Motor objects
|
||||
calibration: Optional calibration data
|
||||
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
|
||||
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
self.use_can_fd = use_can_fd
|
||||
self.bitrate = bitrate
|
||||
self.data_bitrate = data_bitrate
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
|
||||
# Map motor names to CAN IDs
|
||||
self._motor_can_ids = {}
|
||||
self._recv_id_to_motor = {}
|
||||
|
||||
# Store motor types and recv IDs
|
||||
self._motor_types = {}
|
||||
for name, motor in self.motors.items():
|
||||
if hasattr(motor, "motor_type"):
|
||||
self._motor_types[name] = motor.motor_type
|
||||
else:
|
||||
# Default to DM4310 if not specified
|
||||
self._motor_types[name] = MotorType.DM4310
|
||||
|
||||
# Map recv_id to motor name for filtering responses
|
||||
if hasattr(motor, "recv_id"):
|
||||
self._recv_id_to_motor[motor.recv_id] = name
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
if self.can_interface == "auto":
|
||||
if self.port.startswith("/dev/"):
|
||||
# Serial device (macOS/Windows)
|
||||
self.can_interface = "slcan"
|
||||
logger.info(f"Auto-detected slcan interface for port {self.port}")
|
||||
else:
|
||||
# Network interface (Linux)
|
||||
self.can_interface = "socketcan"
|
||||
logger.info(f"Auto-detected socketcan interface for port {self.port}")
|
||||
|
||||
# Connect to CAN bus
|
||||
if self.can_interface == "socketcan":
|
||||
# Linux SocketCAN with CAN FD support
|
||||
if self.use_can_fd and self.data_bitrate is not None:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="socketcan",
|
||||
bitrate=self.bitrate,
|
||||
data_bitrate=self.data_bitrate,
|
||||
fd=True
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})")
|
||||
else:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="socketcan",
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with CAN 2.0 (bitrate={self.bitrate})")
|
||||
elif self.can_interface == "slcan":
|
||||
# Serial Line CAN (macOS, Windows, or USB adapters)
|
||||
# Note: SLCAN typically doesn't support CAN FD
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface="slcan",
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
logger.info(f"Connected to {self.port} with SLCAN (bitrate={self.bitrate})")
|
||||
else:
|
||||
# Generic interface (vector, pcan, etc.)
|
||||
if self.use_can_fd and self.data_bitrate is not None:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface=self.can_interface,
|
||||
bitrate=self.bitrate,
|
||||
data_bitrate=self.data_bitrate,
|
||||
fd=True
|
||||
)
|
||||
else:
|
||||
self.canbus = can.interface.Bus(
|
||||
channel=self.port,
|
||||
interface=self.can_interface,
|
||||
bitrate=self.bitrate
|
||||
)
|
||||
|
||||
self._is_connected = True
|
||||
|
||||
if handshake:
|
||||
self._handshake()
|
||||
|
||||
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
|
||||
except Exception as e:
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}")
|
||||
|
||||
def _handshake(self) -> None:
|
||||
"""Verify all motors are present by refreshing their status."""
|
||||
for motor_name in self.motors:
|
||||
self._refresh_motor(motor_name)
|
||||
time.sleep(0.01) # Small delay between motors
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
self.disable_torque()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to disable torque during disconnect: {e}")
|
||||
|
||||
if self.canbus:
|
||||
self.canbus.shutdown()
|
||||
self.canbus = None
|
||||
self._is_connected = False
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
"""Configure all motors with default settings."""
|
||||
# Damiao motors don't require much configuration in MIT mode
|
||||
# Just ensure they're enabled
|
||||
for motor in self.motors:
|
||||
self._enable_motor(motor)
|
||||
time.sleep(0.01)
|
||||
|
||||
def _enable_motor(self, motor: NameOrID) -> None:
|
||||
"""Enable a single motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _disable_motor(self, motor: NameOrID) -> None:
|
||||
"""Disable a single motor."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_DISABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._enable_motor(motor)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(0.01)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
for _ in range(num_retry + 1):
|
||||
try:
|
||||
self._disable_motor(motor)
|
||||
break
|
||||
except Exception as e:
|
||||
if _ == num_retry:
|
||||
raise e
|
||||
time.sleep(0.01)
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""
|
||||
Context manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
|
||||
Examples:
|
||||
>>> with bus.torque_disabled():
|
||||
... # Safe operations here with torque disabled
|
||||
... pass
|
||||
"""
|
||||
self.disable_torque(motors)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque(motors)
|
||||
|
||||
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
|
||||
"""Set current position as zero for selected motors."""
|
||||
motors = self._get_motors_list(motors)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [CAN_CMD_SET_ZERO]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
time.sleep(0.01)
|
||||
|
||||
def _refresh_motor(self, motor: NameOrID) -> Optional[can.Message]:
|
||||
"""Refresh motor status and return the response."""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _recv_motor_response(self, expected_recv_id: Optional[int] = None, timeout: float = 0.001) -> Optional[can.Message]:
|
||||
"""
|
||||
Receive a response from a motor.
|
||||
|
||||
Args:
|
||||
expected_recv_id: If provided, only return messages from this CAN ID
|
||||
timeout: Timeout in seconds (default: 1ms for high-speed operation)
|
||||
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
while time.time() - start_time < timeout:
|
||||
msg = self.canbus.recv(timeout=0.0001) # 100us timeout for fast polling
|
||||
if msg:
|
||||
messages_seen.append(f"0x{msg.arbitration_id:02X}")
|
||||
# If no filter specified, return any message
|
||||
if expected_recv_id is None:
|
||||
return msg
|
||||
# Otherwise, only return if it matches the expected recv_id
|
||||
if msg.arbitration_id == expected_recv_id:
|
||||
return msg
|
||||
else:
|
||||
logger.debug(f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}")
|
||||
|
||||
# Only log warnings if we're in debug mode to reduce overhead
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if messages_seen:
|
||||
logger.debug(f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}")
|
||||
else:
|
||||
logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to receive CAN message: {e}")
|
||||
return None
|
||||
|
||||
def _recv_all_responses(self, expected_recv_ids: list[int], timeout: float = 0.002) -> dict[int, can.Message]:
|
||||
"""
|
||||
Efficiently receive responses from multiple motors at once.
|
||||
Uses the OpenArms pattern: collect all available messages within timeout.
|
||||
|
||||
Args:
|
||||
expected_recv_ids: List of CAN IDs we expect responses from
|
||||
timeout: Total timeout in seconds (default: 2ms)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
msg = self.canbus.recv(timeout=0.0002) # 200us poll timeout (increased from 100us for better reliability)
|
||||
if msg and msg.arbitration_id in expected_set:
|
||||
responses[msg.arbitration_id] = msg
|
||||
if len(responses) == len(expected_recv_ids):
|
||||
break # Got all responses, exit early
|
||||
except Exception as e:
|
||||
logger.debug(f"Error receiving responses: {e}")
|
||||
|
||||
return responses
|
||||
|
||||
def _mit_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
kp: float,
|
||||
kd: float,
|
||||
position_degrees: float,
|
||||
velocity_deg_per_sec: float,
|
||||
torque: float,
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control command to a motor.
|
||||
|
||||
Args:
|
||||
motor: Motor name or ID
|
||||
kp: Position gain
|
||||
kd: Velocity gain
|
||||
position_degrees: Target position (degrees)
|
||||
velocity_deg_per_sec: Target velocity (degrees/s)
|
||||
torque: Target torque (N·m)
|
||||
"""
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians for motor control
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
def _mit_control_batch(
|
||||
self,
|
||||
commands: Dict[NameOrID, Tuple[float, float, float, float, float]],
|
||||
) -> None:
|
||||
"""
|
||||
Send MIT control commands to multiple motors in batch (optimized).
|
||||
Sends all commands first, then collects responses. Much faster than sequential.
|
||||
|
||||
Args:
|
||||
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
|
||||
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
|
||||
"""
|
||||
if not commands:
|
||||
return
|
||||
|
||||
expected_recv_ids = []
|
||||
|
||||
# Step 1: Send all MIT control commands (no waiting)
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(position_degrees)
|
||||
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Encode parameters
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
# Send command
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Track expected response
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
expected_recv_ids.append(recv_id)
|
||||
|
||||
# Step 2: Collect all responses at once
|
||||
self._recv_all_responses(expected_recv_ids, timeout=0.002)
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
x = max(x_min, min(x_max, x)) # Clamp to range
|
||||
span = x_max - x_min
|
||||
data_norm = (x - x_min) / span
|
||||
return int(data_norm * ((1 << bits) - 1))
|
||||
|
||||
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
|
||||
"""Convert unsigned integer from CAN to float."""
|
||||
span = x_max - x_min
|
||||
data_norm = float(x) / ((1 << bits) - 1)
|
||||
return data_norm * span + x_min
|
||||
|
||||
def _decode_motor_state(self, data: bytes, motor_type: MotorType) -> Tuple[float, float, float, int, int]:
|
||||
"""
|
||||
Decode motor state from CAN data.
|
||||
|
||||
Returns:
|
||||
Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos, temp_rotor)
|
||||
"""
|
||||
if len(data) < 8:
|
||||
raise ValueError("Invalid motor state data")
|
||||
|
||||
# Extract encoded values
|
||||
q_uint = (data[1] << 8) | data[2]
|
||||
dq_uint = (data[3] << 4) | (data[4] >> 4)
|
||||
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
|
||||
t_mos = data[6]
|
||||
t_rotor = data[7]
|
||||
|
||||
# Get motor limits
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
|
||||
# Decode to physical values (radians)
|
||||
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
|
||||
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
|
||||
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
|
||||
|
||||
# Convert to degrees
|
||||
position_degrees = np.degrees(position_rad)
|
||||
velocity_deg_per_sec = np.degrees(velocity_rad_per_sec)
|
||||
|
||||
return position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor
|
||||
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
raise ConnectionError(
|
||||
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
|
||||
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
|
||||
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
|
||||
)
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return requested data (already in degrees for position/velocity)
|
||||
if data_name == "Present_Position":
|
||||
value = position_degrees
|
||||
elif data_name == "Present_Velocity":
|
||||
value = velocity_deg_per_sec
|
||||
elif data_name == "Present_Torque":
|
||||
value = torque
|
||||
elif data_name == "Temperature_MOS":
|
||||
value = t_mos
|
||||
elif data_name == "Temperature_Rotor":
|
||||
value = t_rotor
|
||||
else:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
|
||||
# For Damiao, positions are always in degrees, no normalization needed
|
||||
# We keep the normalize parameter for compatibility but don't use it
|
||||
return value
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
value: Value,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> None:
|
||||
"""Write a value to a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Value is expected to be in degrees for positions
|
||||
if data_name == "Goal_Position":
|
||||
# Use MIT control with position in degrees
|
||||
self._mit_control(motor, 10.0, 0.5, value, 0, 0)
|
||||
else:
|
||||
raise ValueError(f"Writing {data_name} not supported in MIT mode")
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> Dict[str, Value]:
|
||||
"""
|
||||
Read the same value from multiple motors simultaneously.
|
||||
Uses batched operations: sends all refresh commands, then collects all responses.
|
||||
This is MUCH faster than sequential reads (OpenArms pattern).
|
||||
"""
|
||||
motors = self._get_motors_list(motors)
|
||||
result = {}
|
||||
|
||||
# Step 1: Send refresh commands to ALL motors first (no waiting)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Step 2: Collect all responses at once (batch receive)
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=0.01) # 10ms total timeout
|
||||
|
||||
# Step 3: Parse responses
|
||||
for motor in motors:
|
||||
try:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
|
||||
if msg is None:
|
||||
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||
result[motor] = 0.0
|
||||
continue
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return requested data
|
||||
if data_name == "Present_Position":
|
||||
value = position_degrees
|
||||
elif data_name == "Present_Velocity":
|
||||
value = velocity_deg_per_sec
|
||||
elif data_name == "Present_Torque":
|
||||
value = torque
|
||||
elif data_name == "Temperature_MOS":
|
||||
value = t_mos
|
||||
elif data_name == "Temperature_Rotor":
|
||||
value = t_rotor
|
||||
else:
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
|
||||
result[motor] = value
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {data_name} from {motor}: {e}")
|
||||
result[motor] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
def sync_read_all_states(
|
||||
self,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
) -> Dict[str, Dict[str, Value]]:
|
||||
"""
|
||||
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
|
||||
This is 3x faster than calling sync_read() three times separately.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
|
||||
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
|
||||
"""
|
||||
motors = self._get_motors_list(motors)
|
||||
result = {}
|
||||
|
||||
# Step 1: Send refresh commands to ALL motors first (with small delays to reduce bus congestion)
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||
|
||||
# Step 2: Collect all responses at once (batch receive)
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||
|
||||
# Step 3: Parse responses and extract ALL state values
|
||||
for motor in motors:
|
||||
try:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
msg = responses.get(recv_id)
|
||||
|
||||
if msg is None:
|
||||
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
|
||||
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||
continue
|
||||
|
||||
motor_type = self._motor_types.get(motor, MotorType.DM4310)
|
||||
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
|
||||
|
||||
# Return all state values in one dict
|
||||
result[motor] = {
|
||||
"position": position_degrees,
|
||||
"velocity": velocity_deg_per_sec,
|
||||
"torque": torque,
|
||||
"temp_mos": t_mos,
|
||||
"temp_rotor": t_rotor,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read state from {motor}: {e}")
|
||||
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
|
||||
|
||||
return result
|
||||
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
values: Dict[str, Value],
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Write different values to multiple motors simultaneously. Positions are always in degrees.
|
||||
Uses batched operations: sends all commands first, then collects responses (OpenArms pattern).
|
||||
"""
|
||||
if data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands first (no waiting)
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
|
||||
|
||||
# Convert degrees to radians
|
||||
position_rad = np.radians(value_degrees)
|
||||
|
||||
# Default gains for position control
|
||||
kp, kd = 10.0, 0.5
|
||||
|
||||
# Get motor limits and encode parameters
|
||||
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
|
||||
kp_uint = self._float_to_uint(kp, 0, 500, 12)
|
||||
kd_uint = self._float_to_uint(kd, 0, 5, 12)
|
||||
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
|
||||
dq_uint = self._float_to_uint(0, -vmax, vmax, 12)
|
||||
tau_uint = self._float_to_uint(0, -tmax, tmax, 12)
|
||||
|
||||
# Pack data
|
||||
data = [0] * 8
|
||||
data[0] = (q_uint >> 8) & 0xFF
|
||||
data[1] = q_uint & 0xFF
|
||||
data[2] = dq_uint >> 4
|
||||
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
|
||||
data[4] = kp_uint & 0xFF
|
||||
data[5] = kd_uint >> 4
|
||||
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
|
||||
data[7] = tau_uint & 0xFF
|
||||
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self.canbus.send(msg)
|
||||
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
|
||||
|
||||
# Step 2: Collect all responses at once
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in values.keys()]
|
||||
self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
|
||||
else:
|
||||
# Fall back to individual writes for other data types
|
||||
for motor, value in values.items():
|
||||
self.write(data_name, motor, value, normalize=normalize, num_retry=num_retry)
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration data from motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Return existing calibration or empty dict
|
||||
return self.calibration if self.calibration else {}
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration data to motors."""
|
||||
# Damiao motors don't store calibration internally
|
||||
# Just cache it in memory
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions.
|
||||
Press Enter to finish.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors.keys())
|
||||
elif isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
|
||||
# Disable torque for manual movement
|
||||
self.disable_torque(motors)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Get initial positions (already in degrees)
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
print("\nMove joints through their full range of motion. Press ENTER when done.")
|
||||
user_pressed_enter = False
|
||||
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
|
||||
for motor in motors:
|
||||
if motor in positions:
|
||||
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
|
||||
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
|
||||
|
||||
if display_values:
|
||||
print("\n" + "=" * 50)
|
||||
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
|
||||
print("-" * 50)
|
||||
for motor in motors:
|
||||
if motor in positions:
|
||||
print(f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}")
|
||||
|
||||
if enter_pressed():
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 4)
|
||||
|
||||
time.sleep(0.05)
|
||||
|
||||
# Re-enable torque
|
||||
self.enable_torque(motors)
|
||||
|
||||
# Validate ranges
|
||||
for motor in motors:
|
||||
if motor in mins and motor in maxes:
|
||||
if abs(maxes[motor] - mins[motor]) < 5.0: # At least 5 degrees of range
|
||||
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
"""Convert motor specification to list of motor names."""
|
||||
if motors is None:
|
||||
return list(self.motors.keys())
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors
|
||||
else:
|
||||
raise TypeError(f"Invalid motors type: {type(motors)}")
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
"""Get CAN ID for a motor."""
|
||||
if isinstance(motor, str):
|
||||
if motor in self.motors:
|
||||
return self.motors[motor].id
|
||||
else:
|
||||
raise ValueError(f"Unknown motor: {motor}")
|
||||
else:
|
||||
return motor
|
||||
|
||||
def _get_motor_name(self, motor: NameOrID) -> str:
|
||||
"""Get motor name from name or ID."""
|
||||
if isinstance(motor, str):
|
||||
return motor
|
||||
else:
|
||||
for name, m in self.motors.items():
|
||||
if m.id == motor:
|
||||
return name
|
||||
raise ValueError(f"Unknown motor ID: {motor}")
|
||||
|
||||
def _get_motor_recv_id(self, motor: NameOrID) -> Optional[int]:
|
||||
"""Get motor recv_id from name or ID."""
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_obj = self.motors.get(motor_name)
|
||||
if motor_obj and hasattr(motor_obj, "recv_id"):
|
||||
return motor_obj.recv_id
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if motors are calibrated."""
|
||||
return bool(self.calibration)
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
DM3507 = 0
|
||||
DM4310 = 1
|
||||
DM4310_48V = 2
|
||||
DM4340 = 3
|
||||
DM4340_48V = 4
|
||||
DM6006 = 5
|
||||
DM8006 = 6
|
||||
DM8009 = 7
|
||||
DM10010L = 8
|
||||
DM10010 = 9
|
||||
DMH3510 = 10
|
||||
DMH6215 = 11
|
||||
DMG6220 = 12
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 1
|
||||
POS_VEL = 2
|
||||
VEL = 3
|
||||
TORQUE_POS = 4
|
||||
|
||||
# Motor variable IDs (RID)
|
||||
class MotorVariable(IntEnum):
|
||||
UV_VALUE = 0
|
||||
KT_VALUE = 1
|
||||
OT_VALUE = 2
|
||||
OC_VALUE = 3
|
||||
ACC = 4
|
||||
DEC = 5
|
||||
MAX_SPD = 6
|
||||
MST_ID = 7
|
||||
ESC_ID = 8
|
||||
TIMEOUT = 9
|
||||
CTRL_MODE = 10
|
||||
DAMP = 11
|
||||
INERTIA = 12
|
||||
HW_VER = 13
|
||||
SW_VER = 14
|
||||
SN = 15
|
||||
NPP = 16
|
||||
RS = 17
|
||||
LS = 18
|
||||
FLUX = 19
|
||||
GR = 20
|
||||
PMAX = 21
|
||||
VMAX = 22
|
||||
TMAX = 23
|
||||
I_BW = 24
|
||||
KP_ASR = 25
|
||||
KI_ASR = 26
|
||||
KP_APR = 27
|
||||
KI_APR = 28
|
||||
OV_VALUE = 29
|
||||
GREF = 30
|
||||
DETA = 31
|
||||
V_BW = 32
|
||||
IQ_C1 = 33
|
||||
VL_C1 = 34
|
||||
CAN_BR = 35
|
||||
SUB_VER = 36
|
||||
U_OFF = 50
|
||||
V_OFF = 51
|
||||
K1 = 52
|
||||
K2 = 53
|
||||
M_OFF = 54
|
||||
DIR = 55
|
||||
P_M = 80
|
||||
XOUT = 81
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS = {
|
||||
MotorType.DM3507: (12.5, 30, 10),
|
||||
MotorType.DM4310: (12.5, 30, 10),
|
||||
MotorType.DM4310_48V: (12.5, 50, 10),
|
||||
MotorType.DM4340: (12.5, 8, 28),
|
||||
MotorType.DM4340_48V: (12.5, 10, 28),
|
||||
MotorType.DM6006: (12.5, 45, 20),
|
||||
MotorType.DM8006: (12.5, 45, 40),
|
||||
MotorType.DM8009: (12.5, 45, 54),
|
||||
MotorType.DM10010L: (12.5, 25, 200),
|
||||
MotorType.DM10010: (12.5, 20, 200),
|
||||
MotorType.DMH3510: (12.5, 280, 1),
|
||||
MotorType.DMH6215: (12.5, 45, 10),
|
||||
MotorType.DMG6220: (12.5, 45, 10),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.DM3507: "dm3507",
|
||||
MotorType.DM4310: "dm4310",
|
||||
MotorType.DM4310_48V: "dm4310_48v",
|
||||
MotorType.DM4340: "dm4340",
|
||||
MotorType.DM4340_48V: "dm4340_48v",
|
||||
MotorType.DM6006: "dm6006",
|
||||
MotorType.DM8006: "dm8006",
|
||||
MotorType.DM8009: "dm8009",
|
||||
MotorType.DM10010L: "dm10010l",
|
||||
MotorType.DM10010: "dm10010",
|
||||
MotorType.DMH3510: "dmh3510",
|
||||
MotorType.DMH6215: "dmh6215",
|
||||
MotorType.DMG6220: "dmg6220",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"dm3507": 65536,
|
||||
"dm4310": 65536,
|
||||
"dm4310_48v": 65536,
|
||||
"dm4340": 65536,
|
||||
"dm4340_48v": 65536,
|
||||
"dm6006": 65536,
|
||||
"dm8006": 65536,
|
||||
"dm8009": 65536,
|
||||
"dm10010l": 65536,
|
||||
"dm10010": 65536,
|
||||
"dmh3510": 65536,
|
||||
"dmh6215": 65536,
|
||||
"dmg6220": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Damiao motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
125000, # 0: 125 kbps
|
||||
200000, # 1: 200 kbps
|
||||
250000, # 2: 250 kbps
|
||||
500000, # 3: 500 kbps
|
||||
1000000, # 4: 1 mbps (default for OpenArms)
|
||||
2000000, # 5: 2 mbps
|
||||
2500000, # 6: 2.5 mbps
|
||||
3200000, # 7: 3.2 mbps
|
||||
4000000, # 8: 4 mbps
|
||||
5000000, # 9: 5 mbps
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
# Data that should be normalized
|
||||
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||
|
||||
# OpenArms specific configurations
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms has 7 DOF per arm (14 total for dual arm)
|
||||
OPENARMS_ARM_MOTOR_IDS = {
|
||||
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
|
||||
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
|
||||
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
|
||||
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
|
||||
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
|
||||
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
|
||||
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
|
||||
}
|
||||
|
||||
OPENARMS_GRIPPER_MOTOR_IDS = {
|
||||
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
|
||||
}
|
||||
|
||||
# Default motor types for OpenArms
|
||||
OPENARMS_DEFAULT_MOTOR_TYPES = {
|
||||
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
|
||||
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
|
||||
"joint_3": MotorType.DM4340, # Shoulder rotation
|
||||
"joint_4": MotorType.DM4340, # Elbow flex
|
||||
"joint_5": MotorType.DM4310, # Wrist roll
|
||||
"joint_6": MotorType.DM4310, # Wrist pitch
|
||||
"joint_7": MotorType.DM4310, # Wrist rotation
|
||||
"gripper": MotorType.DM4310, # Gripper
|
||||
}
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_REFRESH = 0xCC
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
@@ -24,7 +24,7 @@ from enum import Enum
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
@@ -100,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
class DynamixelMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
|
||||
@@ -19,7 +19,7 @@ from pprint import pformat
|
||||
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
@@ -96,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
class FeetechMotorsBus(SerialMotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
@@ -165,7 +165,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
self._assert_same_firmware()
|
||||
#self._assert_same_firmware()
|
||||
|
||||
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
|
||||
if self.protocol_version == 0:
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
@@ -41,6 +43,92 @@ Value: TypeAlias = int | float
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MotorsBusBase(abc.ABC):
|
||||
"""
|
||||
Base class for all motor bus implementations.
|
||||
|
||||
This is a minimal interface that all motor buses must implement, regardless of their
|
||||
communication protocol (serial, CAN, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Establish connection to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Disconnect from the motors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, data_name: str, motor: str, *, normalize: bool = True, num_retry: int = 0) -> Value:
|
||||
"""Read a value from a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
"""Write a value to a single motor."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_read(
|
||||
self, data_name: str, motors: str | list[str] | None = None, *, normalize: bool = True
|
||||
) -> dict[str, Value]:
|
||||
"""Read a value from multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
values: Value | dict[str, Value],
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Disable torque on selected motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
"""Read calibration parameters from the motors."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
|
||||
"""Write calibration parameters to the motors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
@@ -203,15 +291,15 @@ class GroupSyncWrite(Protocol):
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
There are currently two implementations of this abstract class:
|
||||
There are currently two implementations of this class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other types of bus.
|
||||
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
@@ -1212,3 +1300,7 @@ class MotorsBus(abc.ABC):
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
MotorsBus = SerialMotorsBus
|
||||
|
||||