Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-13 15:49:53 +00:00. Compare view: 30 commits.
@@ -0,0 +1,86 @@
# LeRobot — Claude Code Instructions

You are a senior robotics ML engineer reviewing code for **LeRobot**, a PyTorch framework for real-world robot learning.
Apply these principles to every PR review, fix, or task.

---

## Core Abstractions

These are the load-bearing types. Handle them with care — breaking changes here affect every user.

| Type | Location | Role |
| ---------------- | ---------------------------- | ------------------------------------------------------------ |
| `LeRobotDataset` | `src/lerobot/datasets/` | Streaming replay buffer; HF Hub integration |
| `Policy` | `src/lerobot/policies/` | Base class for all learning agents (ACT, Diffusion, SARM, …) |
| `Robot` | `src/lerobot/robots/` | Hardware abstraction; carries `_output_pipeline` |
| `Teleoperator` | `src/lerobot/teleoperators/` | Leader-side hardware abstraction; carries `_output_pipeline` |
| `Env` | `src/lerobot/envs/` | Gym-like robotics environments |
| `Processor` | `src/lerobot/processor/` | Data transformation pipelines attached to robots/teleops |

**Never break their public APIs without a migration note and explicit user approval.**
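For orientation, a minimal usage sketch of the most commonly touched type (import path per the table above; the constructor signature is illustrative and may differ from the current source):

```python
# Sketch only — verify against the current source tree before relying on it.
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Streams episodes from the HF Hub; local caching is handled internally.
dataset = LeRobotDataset("lerobot/pusht")
sample = dataset[0]  # dict of tensors: observation.*, action, timestamps
```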
---

## Engineering Principles

### Code quality

- Explicit over magic — no hidden control flow, no implicit state.
- No deep inheritance trees. Prefer composition.
- No decorative comment separators (`===`, `---`, etc.).
- Add comments only where the logic is non-obvious.
- No over-engineering. YAGNI applies strictly.

### Type safety

- All new and modified Python code must be fully typed (PEP 484) — see the sketch after this list.
- `mypy --strict` must pass on changed files.
- Do not widen or weaken existing type signatures.
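As a reference point for the level of annotation expected, a fully typed helper in the style `mypy --strict` accepts (hypothetical function, for illustration only):

```python
import torch

def normalize_actions(actions: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize actions; mean/std must broadcast against the action shape."""
    return (actions - mean) / std
```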
### Backwards compatibility

- Public API changes require migration notes.
- Additive changes are preferred over modifications.
- `so100_follower` / `so101_follower` are aliases — never let changes bleed from one into the other unintentionally.

### HF ecosystem

- Use `push_to_hub()`, HF Hub dataset streaming, and `evaluate` scripts.
- Dataset changes must preserve streaming compatibility.
- Prefer reusing HF primitives over rolling custom solutions.

---

## PR Review Checklist

Before approving or marking P1 issues resolved, verify:

- [ ] `pre-commit run -a` would pass (ruff, mypy, typos, zizmor, bandit)
- [ ] All new/modified code is typed and passes `mypy --strict`
- [ ] New features have unit tests; no silent behavioral changes
- [ ] Public APIs of `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env` are unchanged (or a migration note is present)
- [ ] HF Hub streaming still works for dataset changes
- [ ] No unnecessary abstractions introduced
- [ ] No breaking changes to training scripts (`lerobot-train`, `lerobot-eval`, `lerobot-record`)

---

## ML-Specific Checks

Flag these as **P1** if found:

- **Data leakage**: train and val/test splits must be constructed before any normalization or augmentation that uses train statistics.
- **Loss function errors**: verify reduction mode (`mean` vs `sum`), correct masking, correct shape alignment (see the sketch after this list).
- **Gradient flow**: new modules must have gradients flowing (check `requires_grad`; no accidentally detached tensors in the loss path).
- **Distributed training**: operations on tensors must be DDP-safe; no in-place ops on parameters; batch norm needs `SyncBatchNorm` if used.
- **Memory leaks**: no accumulation of tensors outside the training loop; `optimizer.zero_grad()` called correctly.
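To make the masking check concrete, this is the pattern to look for in padded-sequence losses (a sketch, not LeRobot's actual implementation): the reduction must divide by the number of valid elements, not the full tensor size.

```python
import torch
import torch.nn.functional as F

def masked_l1_loss(pred: torch.Tensor, target: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    """pred/target: (B, T, D); pad_mask: (B, T), True where the timestep is valid."""
    per_elem = F.l1_loss(pred, target, reduction="none")  # (B, T, D)
    mask = pad_mask.unsqueeze(-1)  # (B, T, 1), broadcasts over D
    # Divide by the count of valid elements, not per_elem.numel(),
    # so padded timesteps cannot silently shrink the loss.
    denom = (mask.sum() * pred.shape[-1]).clamp(min=1)
    return (per_elem * mask).sum() / denom
```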
---

## What to Skip

- Don't flag style nitpicks on unchanged surrounding code.
- Don't propose refactors outside the PR's scope.
- Don't add docstrings or comments to code the PR didn't touch.
- Don't suggest speculative future features (YAGNI).
@@ -0,0 +1,49 @@
name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize, ready_for_review, reopened]

jobs:
  claude-review:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: read
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true
          prompt: |
            Read `.github/CLAUDE.md` for lerobot-specific conventions, then review this PR.
            Provide structured, actionable feedback.

            Focus areas (in priority order):
            1. **Correctness**: Logic errors, off-by-ones, wrong tensor shapes, incorrect loss functions
            2. **Type safety**: All new/modified Python code must pass `mypy --strict`; check for missing annotations
            3. **Backwards compatibility**: Does this break `LeRobotDataset`, `Policy`, `Robot`, `Teleoperator`, `Env`, or `Processor` public APIs?
            4. **Tests**: New features must have tests; no silent behavioral changes
            5. **Code style**: Explicit over magic, no unnecessary abstractions, no decorative comments
            6. **HF integration**: Dataset streaming, `push_to_hub`, HF Hub compatibility preserved?
            7. **pre-commit**: Would `pre-commit run -a` pass? (ruff, mypy, typos, zizmor)

            Format findings as P1 (must fix) / P2 (should fix) / P3 (nice to have).
            Skip P3 if the PR is already high quality.
          claude_args: '--model claude-opus-4-6'
          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://code.claude.com/docs/en/cli-reference for available options
@@ -0,0 +1,58 @@
name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened, assigned]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    if: |
      (github.event_name == 'issue_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'pull_request_review' &&
        contains(github.event.review.body, '@claude') &&
        (github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR')) ||
      (github.event_name == 'issues' &&
        (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
        (github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR'))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: write
      id-token: write
      actions: read
    env:
      FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1
          persist-credentials: false

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@26ddc358fe3befff50c5ec2f80304c90c763f6f8 # v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          use_sticky_comment: true

          # This is an optional setting that allows Claude to read CI results on PRs
          additional_permissions: |
            actions: read

          claude_args: '--system-prompt "Read .github/CLAUDE.md for lerobot-specific conventions before responding."'
          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://code.claude.com/docs/en/cli-reference for available options
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow handles nightly testing & docker images publishing.
name: Nightly
# This workflow handles Docker image publishing & testing.
name: Docker Publish & Test
permissions:
  contents: read

@@ -39,8 +39,8 @@ concurrency:

jobs:
  # This job builds a CPU image for testing & distribution
  build-docker-cpu-nightly:
    name: Build CPU Docker for Nightly
  build-docker-cpu:
    name: Build CPU Docker
    runs-on:
      group: aws-general-8-plus
    if: github.repository == 'huggingface/lerobot'

@@ -74,8 +74,8 @@ jobs:
          tags: ${{ env.DOCKER_IMAGE_NAME_CPU }}

  # This job builds a GPU image for testing & distribution
  build-docker-gpu-nightly:
    name: Build GPU Docker for Nightly
  build-docker-gpu:
    name: Build GPU Docker
    runs-on:
      group: aws-general-8-plus
    if: github.repository == 'huggingface/lerobot'

@@ -109,9 +109,9 @@ jobs:
          tags: ${{ env.DOCKER_IMAGE_NAME_GPU }}

  # This job runs the E2E tests + pytest with all extras in the CPU image
  nightly-cpu-tests:
    name: Nightly CPU Tests
    needs: [build-docker-cpu-nightly]
  cpu-tests:
    name: CPU Tests
    needs: [build-docker-cpu]
    runs-on:
      group: aws-g6-4xlarge-plus
    env:

@@ -121,7 +121,7 @@ jobs:
      TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    container:
      image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      image: ${{ needs.build-docker-cpu.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      options: --shm-size "16gb"
      credentials:
        username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}

@@ -142,9 +142,9 @@ jobs:
        run: make test-end-to-end

  # This job runs the E2E tests + pytest with all extras in the GPU image
  nightly-gpu-tests:
    name: Nightly GPU Tests
    needs: [build-docker-gpu-nightly]
  gpu-tests:
    name: GPU Tests
    needs: [build-docker-gpu]
    runs-on:
      group: aws-g6-4xlarge-plus
    env:

@@ -154,7 +154,7 @@ jobs:
      TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    container:
      image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      image: ${{ needs.build-docker-gpu.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      options: --gpus all --shm-size "16gb"
      credentials:
        username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}

@@ -175,9 +175,9 @@ jobs:
        run: make test-end-to-end

  # This job runs multi-GPU training tests with 4 GPUs
  nightly-multi-gpu-tests:
    name: Nightly Multi-GPU Tests
    needs: [build-docker-gpu-nightly]
  multi-gpu-tests:
    name: Multi-GPU Tests
    needs: [build-docker-gpu]
    runs-on:
      group: aws-g4dn-12xlarge # Instance with 4 GPUs
    env:

@@ -188,7 +188,7 @@ jobs:
      CUDA_VISIBLE_DEVICES: "0,1,2,3"
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    container:
      image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      image: ${{ needs.build-docker-gpu.outputs.image_tag }} # zizmor: ignore[unpinned-images]
      options: --gpus all --shm-size "16gb"
      credentials:
        username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
@@ -33,7 +33,7 @@ jobs:
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success' &&
      github.repository == 'huggingface/lerobot'
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
    with:
      package_name: lerobot
    secrets:

@@ -55,7 +55,7 @@ jobs:
      github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
    with:
      commit_sha: ${{ github.sha }}
      package: lerobot

@@ -78,7 +78,7 @@ jobs:
    permissions:
      contents: read
      pull-requests: write
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
@@ -27,6 +27,7 @@ on:
      - "tests/**"
      - ".github/workflows/**"
      - "pyproject.toml"
      - "uv.lock"
      - "Makefile"
  push:
    branches:

@@ -36,6 +37,7 @@ on:
      - "tests/**"
      - ".github/workflows/**"
      - "pyproject.toml"
      - "uv.lock"
      - "Makefile"

permissions:

@@ -63,7 +65,7 @@ jobs:
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    steps:
      - uses: actions/checkout@v6
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
          lfs: true

@@ -81,14 +83,14 @@ jobs:
            libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev

      - name: Setup uv and Python
        uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install lerobot with test extras
        run: uv sync --extra "test"
        run: uv sync --locked --extra "test"

      - name: Login to Hugging Face
        if: env.HF_USER_TOKEN != ''

@@ -29,6 +29,7 @@ on:
      - "tests/**"
      - ".github/workflows/**"
      - "pyproject.toml"
      - "uv.lock"
      - "Makefile"

permissions:

@@ -62,7 +63,7 @@ jobs:
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
    steps:
      - uses: actions/checkout@v6
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          lfs: true
          persist-credentials: false

@@ -79,14 +80,14 @@ jobs:
            speech-dispatcher libgeos-dev portaudio19-dev

      - name: Setup uv and Python
        uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Install lerobot with all extras
        run: uv sync --extra all # TODO(Steven): Make flash-attn optional
        run: uv sync --locked --extra all # TODO(Steven): Make flash-attn optional

      - name: Login to Hugging Face
        if: env.HF_USER_TOKEN != ''
@@ -136,21 +137,21 @@ jobs:
          sudo apt-get update
          sudo apt-get install git-lfs
          git lfs install
      - uses: actions/checkout@v6
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          lfs: true
          persist-credentials: false
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
        with:
          cache-binary: false
      - name: Login to Docker Hub
        uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
        with:
          username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
          password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
      - name: Build and push Docker image
        uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
        uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
        with:
          context: .
          file: ./docker/Dockerfile.internal
+139 -37
@@ -12,38 +12,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow handles full testing with unboud dependencies versions.
name: Unbound Dependency Tests
# This workflow tests the project against the latest upstream dependencies
# (within pyproject.toml constraints) and opens a PR to update uv.lock
# if the tests pass and the lockfile has changed.
name: Latest Dependency Tests

on:
  # Allows running this workflow manually from the Actions tab
  workflow_dispatch:

  # Run on the 1st and 15th of every month at 09:00 UTC
  # schedule:
  #   - cron: '0 2 1,15 * *'

permissions:
  contents: read
  # Runs at 03:00 UTC
  schedule:
    - cron: "0 3 * * *"

# Sets up the environment variables
env:
  UV_VERSION: "0.8.0"
  PYTHON_VERSION: "3.12"
  DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
  DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:latest-deps

# Ensures that only the latest action is built, canceling older runs.
# Ensures that only the latest run is active, canceling older runs.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  group: ${{ github.workflow }}
  cancel-in-progress: true

jobs:

  # This job runs the E2E tests + pytest with all unbound extras
  full-tests:
    name: Full Unbound Tests
  # This job upgrades the lockfile and checks if dependencies have changed
  upgrade-lock:
    name: Upgrade Lockfile
    runs-on: ubuntu-latest
    if: github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
    outputs:
      changed: ${{ steps.diff.outputs.changed }}
    steps:
      - uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup uv and Python
        uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        with:
          version: ${{ env.UV_VERSION }}
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Upgrade uv.lock
        run: uv lock --upgrade

      - name: Check for changes
        id: diff
        run: |
          if git diff --quiet uv.lock; then
            echo "changed=false" >> "$GITHUB_OUTPUT"
            echo "uv.lock is up to date — no dependency changes."
          else
            echo "changed=true" >> "$GITHUB_OUTPUT"
            echo "uv.lock has changed — running tests."
          fi

      - name: Upload updated lockfile
        if: steps.diff.outputs.changed == 'true'
        uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
        with:
          name: uv-lock
          path: uv.lock

  # This job runs the full test suite with the upgraded dependencies
  cpu-tests:
    name: CPU Tests (Latest Deps)
    needs: [upgrade-lock]
    if: needs.upgrade-lock.outputs.changed == 'true'
    runs-on: ubuntu-latest
    permissions:
      contents: read
    env:
      MUJOCO_GL: egl
      HF_HOME: /mnt/cache/.cache/huggingface

@@ -55,6 +98,11 @@ jobs:
          lfs: true
          persist-credentials: false

      - name: Download updated lockfile
        uses: actions/download-artifact@v4 # zizmor: ignore[unpinned-uses]
        with:
          name: uv-lock

      # NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
      # (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
      - name: Setup /mnt storage

@@ -73,34 +121,32 @@ jobs:
          version: ${{ env.UV_VERSION }}
          python-version: ${{ env.PYTHON_VERSION }}

      - name: Unbound dependencies
        run: |
          sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml
          echo "Dependencies unbound:" && cat pyproject.toml

      - name: Install lerobot with all extras
        run: uv sync --extra all # TODO(Steven): Make flash-attn optional
        run: uv sync --locked --extra all # TODO(Steven): Make flash-attn optional

      - name: Login to Hugging Face
        if: env.HF_USER_TOKEN != ''
        run: |
          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
          uv run hf auth whoami

      - name: Run pytest (all extras)
        run: uv run pytest tests -vv
        run: uv run pytest tests -vv --maxfail=10

      - name: Run end-to-end tests
        run: uv run make test-end-to-end

  # This job builds a GPU enabled image for testing
  # This job builds a GPU-enabled Docker image with the upgraded dependencies
  build-and-push-docker:
    name: Build and Push Docker
    needs: [upgrade-lock]
    if: needs.upgrade-lock.outputs.changed == 'true'
    permissions:
      contents: read
    runs-on:
      group: aws-general-8-plus
    if: github.repository == 'huggingface/lerobot'
    outputs:
      image_tag: ${{ env.DOCKER_IMAGE_NAME }}
    env:
      GITHUB_REF: ${{ github.ref }}
    steps:
      - name: Install Git LFS
        run: |

@@ -111,6 +157,12 @@ jobs:
        with:
          lfs: true
          persist-credentials: false

      - name: Download updated lockfile
        uses: actions/download-artifact@v4 # zizmor: ignore[unpinned-uses]
        with:
          name: uv-lock

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
        with:

@@ -127,14 +179,13 @@ jobs:
          file: ./docker/Dockerfile.internal
          push: true
          tags: ${{ env.DOCKER_IMAGE_NAME }}
          build-args: |
            UNBOUND_DEPS=true

  # This job runs pytest with all unbound extras in a GPU enabled host
  # It runs everytime a test image is created
  # This job runs pytest with all extras on a GPU-enabled host
  gpu-tests:
    name: GPU Unbound Tests
    name: GPU Tests (Latest Deps)
    needs: [build-and-push-docker]
    permissions:
      contents: read
    runs-on:
      group: aws-g6-4xlarge-plus
    env:

@@ -159,17 +210,69 @@ jobs:
        run: |
          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
          hf auth whoami
      - name: Fix ptxas permissions
        run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
      - name: Run pytest on GPU
        run: pytest tests -vv
        run: pytest tests -vv --maxfail=10
      - name: Run end-to-end tests
        run: make test-end-to-end

  # This job deletes the test image recently created
  # It runs everytime after the gpu-tests have finished
  delete-unbound-image:
    name: Delete Unbound Image
  # This job creates or updates a PR with the upgraded lockfile
  open-pr:
    name: Open PR
    needs: [cpu-tests, gpu-tests, upgrade-lock]
    if: success() && needs.upgrade-lock.outputs.changed == 'true'
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write
    env:
      GH_TOKEN: ${{ secrets.UPDATE_LOCK_TOKEN }}
    steps:
      - uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Download updated lockfile
        uses: actions/download-artifact@v4 # zizmor: ignore[unpinned-uses]
        with:
          name: uv-lock

      - name: Create or update PR
        run: |
          set -euo pipefail
          BRANCH="auto/update-uv-lock"

          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git"

          git checkout -B "$BRANCH"
          git add uv.lock
          git commit -m "chore(dependencies): update uv.lock"
          git push --force origin "$BRANCH"

          # Create PR only if one doesn't already exist for this branch
          EXISTING_PR=$(gh pr list --head "$BRANCH" --state open --json number --jq '.[0].number')
          if [ -z "$EXISTING_PR" ]; then
            gh pr create \
              --title "chore(dependencies): update uv.lock" \
              --body "Automated update of \`uv.lock\` after successful latest dependency tests (CPU + GPU).

          This PR upgrades all dependencies to their latest versions within the ranges specified in \`pyproject.toml\`." \
              --head "$BRANCH" \
              --base main
          else
            echo "PR #$EXISTING_PR already exists, branch has been updated."
          fi

  # This job deletes the temporary Docker image after tests complete
  cleanup-docker:
    name: Cleanup Docker Image
    needs: [gpu-tests, build-and-push-docker]
    if: always() && needs.build-and-push-docker.result == 'success'
    permissions:
      contents: read
    runs-on: ubuntu-latest
    steps:
      - name: Get Docker Hub Token and Delete Image

@@ -180,8 +283,7 @@ jobs:
          IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }}
        run: |
          IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1)
          IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2)

          IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2-)
          echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"

          TOKEN=$(curl -s -H "Content-Type: application/json" \
@@ -43,16 +43,16 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v6
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v6
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
        with:
          python-version: '3.12'

      - name: Run pre-commit hooks
        uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
        uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        with:
          extra_args: --all-files --show-diff-on-failure --color=always

@@ -38,12 +38,12 @@ jobs:

    steps:
      - name: Checkout code
        uses: actions/checkout@v6
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v6
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
        with:
          python-version: '3.12'

@@ -104,7 +104,7 @@ jobs:
      - name: Publish to TestPyPI for pre-releases
        # True for tags like 'v0.2.0-rc1'
        if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
        uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
        uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
        with:
          repository-url: https://test.pypi.org/legacy/
          verbose: true

@@ -112,7 +112,7 @@ jobs:

      - name: Publish to PyPI
        if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
        uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
        uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
        with:
          verbose: true
          print-hash: true

@@ -127,7 +127,7 @@ jobs:
    env:
      MUJOCO_GL: egl
    steps:
      - uses: actions/checkout@v6
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          lfs: true
          persist-credentials: false

@@ -137,7 +137,7 @@ jobs:
          git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
          speech-dispatcher libgeos-dev portaudio19-dev
      - name: Setup uv and Python
        uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
        uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6
        with:
          enable-cache: true # zizmor: ignore[cache-poisoning]
          version: ${{ env.UV_VERSION }}

@@ -43,12 +43,12 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses]
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@v3.90.0 # zizmor: ignore[unpinned-uses]
        uses: trufflesecurity/trufflehog@eafb8c5f6a06175141c27f17bcc17941853d0047 # v3.90.0
        with:
          extra_args: --only-verified
@@ -25,7 +25,6 @@ node_modules/

# Lock files
poetry.lock
uv.lock
Pipfile.lock

### Build & Distribution ###

@@ -4,7 +4,8 @@

<div align="center">

[](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain)
[](https://github.com/huggingface/lerobot/actions/workflows/latest_deps_tests.yml?query=branch%3Amain)
[](https://github.com/huggingface/lerobot/actions/workflows/docker_publish.yml?query=branch%3Amain)
[](https://www.python.org/downloads/)
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
[](https://pypi.org/project/lerobot/)
@@ -1,219 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from soundfile import read

from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
from lerobot.microphones.utils import (
    async_microphones_start_recording,
    async_microphones_stop_recording,
    make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
    precise_sleep,
)


def main(
    microphones_configs: dict[str, MicrophoneConfig],
    audio_chunks_number: int,
    audio_chunks_duration: float,
    repetitions: int,
    multiprocessing: bool = False,
):
    recording_dir = Path("outputs/audio_benchmark")
    recording_dir.mkdir(parents=True, exist_ok=True)

    # Create microphones
    microphones = make_microphones_from_configs(microphones_configs)

    # Connect microphones
    for microphone in microphones.values():
        microphone.connect()

    all_audio_chunks = []
    for i in range(repetitions):
        print(f"Repetition {i + 1}/{repetitions}...")

        # Create audio chunks
        audio_chunks = {}
        for microphone_key in microphones:
            audio_chunks.update({microphone_key: []})

        # Start recording
        async_microphones_start_recording(
            microphones,
            output_files=[
                recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
            ],
            multiprocessing=multiprocessing,
        )

        # Record audio chunks
        for j in range(audio_chunks_number):
            precise_sleep(audio_chunks_duration)

            for microphone_key, microphone in microphones.items():
                audio_chunk = microphone.read()
                print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
                audio_chunks[microphone_key].append(audio_chunk)

        # Stop recording
        async_microphones_stop_recording(microphones)

        for microphone_key in microphones:
            audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)

        all_audio_chunks.append(audio_chunks)

    # Disconnect microphones
    for microphone in microphones.values():
        microphone.disconnect()

    # Compute statistics
    cmap = plt.get_cmap("tab10")
    _, ax = plt.subplots(nrows=repetitions, ncols=len(microphones))
    chunk_length = np.zeros((repetitions, len(microphones)))
    record_length = np.zeros((repetitions, len(microphones)))
    for i in range(repetitions):
        for j, (microphone_key, microphone) in enumerate(microphones.items()):
            # Get recorded audio chunks
            recorded_audio_chunks = all_audio_chunks[i][microphone_key]

            # Load recorded file
            recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
            if recorded_data.ndim == 1:
                recorded_data = np.expand_dims(recorded_data, axis=1)

            record_length[i, j] = recorded_data.shape[0]
            chunk_length[i, j] = recorded_audio_chunks.shape[0]

            for k, (chunk_data, record_data) in enumerate(
                zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
            ):
                # Plot audio chunks and recorded data
                ax[i, j].plot(
                    np.arange(0, len(chunk_data)) / microphone.sample_rate,
                    chunk_data,
                    label=f"audio chunks - channel {k}",
                    color=cmap(2 * k),
                )
                ax[i, j].plot(
                    np.arange(0, len(record_data)) / microphone.sample_rate,
                    record_data,
                    label=f"recorded data - channel {k}",
                    linestyle="dashed",
                    color=cmap(2 * k + 1),
                )

                # Plot absolute difference (errors should be located at the end of the recordings)
                if recorded_data.shape[0] - recorded_audio_chunks.shape[0] > 0:
                    chunk_data = np.append(
                        chunk_data, np.zeros(int(recorded_data.shape[0] - recorded_audio_chunks.shape[0]))
                    )
                else:
                    record_data = np.append(
                        record_data, np.zeros(int(-recorded_data.shape[0] + recorded_audio_chunks.shape[0]))
                    )
                ax[i, j].plot(
                    np.arange(0, len(record_data)) / microphone.sample_rate,
                    np.abs(chunk_data - record_data),
                    label=f"differences - channel {k}",
                    color="red",
                    linestyle="dotted",
                )
            ax[i, j].set_title(f"{microphone_key} - repetition {i}")
            ax[i, j].legend()

    plt.show()

    # Print statistics
    differences = record_length - chunk_length
    for i, (microphone_key, microphone) in enumerate(microphones.items()):
        print(
            f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
        )
        print(
            f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
        )
        print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
        print(
            f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--microphones_indices",
        type=int,
        nargs="+",
        default=[microphone["index"] for microphone in PortAudioMicrophone.find_microphones()],
    )
    parser.add_argument(
        "--microphones_sample_rate",
        type=float,
        nargs="+",
        default=[None] * len(PortAudioMicrophone.find_microphones()),
    )
    parser.add_argument(
        "--microphones_channels",
        type=int,
        nargs="+",
        default=[None] * len(PortAudioMicrophone.find_microphones()),
    )
    parser.add_argument("--audio_chunks_number", type=int, default=2)
    parser.add_argument(
        "--audio_chunks_duration",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--repetitions",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--multiprocessing",
        action="store_true",
    )

    args = vars(parser.parse_args())

    args["microphones_configs"] = {}
    for index, sample_rate, channels in zip(
        args["microphones_indices"],
        args["microphones_sample_rate"],
        args["microphones_channels"],
        strict=False,
    ):
        microphone_config = PortAudioMicrophoneConfig(
            microphone_index=index,
            sample_rate=sample_rate,
            channels=channels,
        )
        args["microphones_configs"].update({f"microphone_{index}": microphone_config})
    args.pop("microphones_indices")
    args.pop("microphones_sample_rate")
    args.pop("microphones_channels")

    main(**args)
@@ -1,137 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import numpy as np
import soundfile as sf

from lerobot.microphones.configs import MicrophoneConfig
from lerobot.microphones.touchlab import TouchLabSensorConfig
from lerobot.microphones.utils import (
    async_microphones_start_recording,
    async_microphones_stop_recording,
    make_microphones_from_configs,
)
from lerobot.utils.robot_utils import (
    precise_sleep,
)


def main(
    sensors_configs: dict[str, MicrophoneConfig],
    multiprocessing: bool = False,
):
    recording_dir = Path("outputs/tactile_benchmark")
    recording_dir.mkdir(parents=True, exist_ok=True)

    # Create microphones
    sensors = make_microphones_from_configs(sensors_configs)

    # Connect microphones
    for sensor in sensors.values():
        sensor.connect()

    # Create audio chunks
    data_chunks = {}
    for sensor_key in sensors:
        data_chunks.update({sensor_key: []})

    # Start recording
    async_microphones_start_recording(
        sensors,
        output_files=[recording_dir / f"{sensor_key}_recording.wav" for sensor_key in sensors],
        multiprocessing=multiprocessing,
    )

    # Record audio chunks
    precise_sleep(10.0)

    for sensor_key, sensor in sensors.items():
        data_chunk = sensor.read()
        print(f"{sensor_key} - samples {data_chunk.shape[0]}")
        data_chunks[sensor_key].append(data_chunk)

    # Stop recording
    async_microphones_stop_recording(sensors)

    for sensor_key in sensors:
        data_chunks[sensor_key] = np.concatenate(data_chunks[sensor_key], axis=0)

    # Disconnect microphones
    for sensor in sensors.values():
        sensor.disconnect()

    for sensor_key in sensors:
        data, sample_rate = sf.read(recording_dir / f"{sensor_key}_recording.wav")
        print(f"{sensor_key} - samples {data.shape[0]}")
        print(f"{sensor_key} - sample rate {sample_rate}")
        print(f"{sensor_key} - data {data}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sensors_ports",
        type=str,
        nargs="+",
    )
    parser.add_argument(
        "--sensors_baud_rate",
        type=int,
        nargs="+",
    )
    parser.add_argument(
        "--sensors_sample_rate",
        type=int,
        nargs="+",
    )
    parser.add_argument(
        "--sensors_channels",
        type=int,
        nargs="+",
    )
    parser.add_argument(
        "--multiprocessing",
        action="store_true",
    )

    args = vars(parser.parse_args())

    args["sensors_configs"] = {}
    for port, baud_rate, sample_rate, channels in zip(
        args["sensors_ports"],
        args["sensors_baud_rate"],
        args["sensors_sample_rate"],
        args["sensors_channels"],
        strict=False,
    ):
        if isinstance(channels, int):
            channels = [channels]
        sensor_config = TouchLabSensorConfig(
            sensor_port=port,
            baud_rate=baud_rate,
            sample_rate=sample_rate,
            channels=channels,
        )
        args["sensors_configs"].update({f"sensor_{port}": sensor_config})
    args.pop("sensors_ports")
    args.pop("sensors_baud_rate")
    args.pop("sensors_sample_rate")
    args.pop("sensors_channels")

    main(**args)
@@ -73,17 +73,10 @@ ENV HOME=/home/user_lerobot \
RUN uv venv --python python${PYTHON_VERSION}

# Install Python dependencies for caching
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/

ARG UNBOUND_DEPS=false

RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
        sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
        echo "Dependencies unbound:" && cat pyproject.toml; \
    fi

RUN uv pip install --no-cache ".[all]"
RUN uv sync --locked --extra all --no-cache

RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas


@@ -61,17 +61,10 @@ ENV HOME=/home/user_lerobot \
RUN uv venv

# Install Python dependencies for caching
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/

ARG UNBOUND_DEPS=false

RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
        sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
        echo "Dependencies unbound:" && cat pyproject.toml; \
    fi

RUN uv pip install --no-cache ".[all]"
RUN uv sync --locked --extra all --no-cache

# Copy the rest of the application code
# Make sure to have the git-LFS files for testing
@@ -0,0 +1,77 @@
# Docker

This directory contains Dockerfiles for running LeRobot in containerized environments. Both images are **built nightly from `main`** and published to Docker Hub with the full environment pre-baked — no dependency setup required.

## Pre-built Images

```bash
# CPU-only image (based on Dockerfile.user)
docker pull huggingface/lerobot-cpu:latest

# GPU image with CUDA support (based on Dockerfile.internal)
docker pull huggingface/lerobot-gpu:latest
```

## Quick Start

The fastest way to start training is to pull the GPU image and run `lerobot-train` directly. This is the same environment used for all of our CI, so it is a well-tested, batteries-included setup.

```bash
docker run -it --rm --gpus all --shm-size 16gb huggingface/lerobot-gpu:latest

# inside the container:
lerobot-train --policy.type=act --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human
```

## Dockerfiles

### `Dockerfile.user` (CPU)

A lightweight image based on `python:3.12-slim`. Includes all Python dependencies and system libraries but does not include CUDA — there is no GPU support. Useful for exploring the codebase, running scripts, or working with robots, but not practical for training.

### `Dockerfile.internal` (GPU)

A CUDA-enabled image based on `nvidia/cuda`. This is the image for training — mostly used for internal interactions with the GPU cluster.

## Usage

### Running a pre-built image

```bash
# CPU
docker run -it --rm huggingface/lerobot-cpu:latest

# GPU
docker run -it --rm --gpus all --shm-size 16gb huggingface/lerobot-gpu:latest
```

### Building locally

From the repo root:

```bash
# CPU
docker build -f docker/Dockerfile.user -t lerobot-user .
docker run -it --rm lerobot-user

# GPU
docker build -f docker/Dockerfile.internal -t lerobot-internal .
docker run -it --rm --gpus all --shm-size 16gb lerobot-internal
```

### Multi-GPU training

To select specific GPUs, set `CUDA_VISIBLE_DEVICES` when launching the container:

```bash
# Use 4 GPUs
docker run -it --rm --gpus all --shm-size 16gb \
  -e CUDA_VISIBLE_DEVICES=0,1,2,3 \
  huggingface/lerobot-gpu:latest
```

### USB device access (e.g. robots, cameras)

```bash
docker run -it --device=/dev/ -v /dev/:/dev/ --rm huggingface/lerobot-cpu:latest
```
@@ -17,6 +17,8 @@
        title: Train RL in Simulation
      - local: multi_gpu_training
        title: Multi GPU training
      - local: hil_data_collection
        title: Human In the Loop Data Collection
      - local: peft_training
        title: Training with PEFT (e.g., LoRA)
      - local: rename_map

@@ -69,13 +71,17 @@
        title: Environments from the Hub
      - local: envhub_leisaac
        title: Control & Train Robots in Sim (LeIsaac)
    title: "Simulation"
  - sections:
      - local: adding_benchmarks
        title: Adding a New Benchmark
      - local: libero
        title: LIBERO
      - local: metaworld
        title: Meta-World
      - local: envhub_isaaclab_arena
        title: NVIDIA IsaacLab Arena Environments
      - local: libero
        title: Using Libero
      - local: metaworld
        title: Using MetaWorld
    title: "Simulation"
    title: "Benchmarks"
  - sections:
      - local: introduction_processors
        title: Introduction to Robot Processors
@@ -0,0 +1,320 @@
# Adding a New Benchmark

This guide walks you through adding a new simulation benchmark to LeRobot. Follow the steps in order and use the existing benchmarks as templates.

A benchmark in LeRobot is a set of [Gymnasium](https://gymnasium.farama.org/) environments that wrap a third-party simulator (like LIBERO or Meta-World) behind a standard `gym.Env` interface. The `lerobot-eval` CLI then runs evaluation uniformly across all benchmarks.

## Existing benchmarks at a glance

Before diving in, here is what is already integrated:

| Benchmark | Env file | Config class | Tasks | Action dim | Processor |
| -------------- | ------------------- | ------------------ | ------------------- | ------------ | ---------------------------- |
| LIBERO | `envs/libero.py` | `LiberoEnv` | 130 across 5 suites | 7 | `LiberoProcessorStep` |
| Meta-World | `envs/metaworld.py` | `MetaworldEnv` | 50 (MT50) | 4 | None |
| IsaacLab Arena | Hub-hosted | `IsaaclabArenaEnv` | Configurable | Configurable | `IsaaclabArenaProcessorStep` |

Use `src/lerobot/envs/libero.py` and `src/lerobot/envs/metaworld.py` as reference implementations.

## How it all fits together

### Data flow

During evaluation, data moves through four stages:

```
1. gym.Env        ──→ raw observations (numpy dicts)

2. Preprocessing  ──→ standard LeRobot keys + task description
                      (preprocess_observation, add_envs_task in envs/utils.py)

3. Processors     ──→ env-specific then policy-specific transforms
                      (env_preprocessor, policy_preprocessor)

4. Policy         ──→ select_action() ──→ action tensor

then reverse: policy_postprocessor → env_postprocessor → numpy action → env.step()
```

Most benchmarks only need to care about stage 1 (producing observations in the right format) and optionally stage 3 (if env-specific transforms are needed).

### Environment structure

`make_env()` returns a nested dict of vectorized environments:

```python
dict[str, dict[int, gym.vector.VectorEnv]]
#    ^suite    ^task_id
```

A single-task env (e.g. PushT) looks like `{"pusht": {0: vec_env}}`.
A multi-task benchmark (e.g. LIBERO) looks like `{"libero_spatial": {0: vec0, 1: vec1, ...}, ...}`.
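For example, indexing into the returned structure might look like this (the `make_env` call signature here is illustrative, not exact — check `lerobot.envs` for the real one):

```python
envs = make_env(env_cfg, n_envs=2)   # hypothetical arguments
vec_env = envs["libero_spatial"][0]  # first task of the suite
obs, info = vec_env.reset(seed=0)
```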
### How evaluation runs

All benchmarks are evaluated the same way by `lerobot-eval`:

1. `make_env()` builds the nested `{suite: {task_id: VectorEnv}}` dict.
2. `eval_policy_all()` iterates over every suite and task.
3. For each task, it runs `n_episodes` rollouts via `rollout()`.
4. Results are aggregated hierarchically: episode, task, suite, overall.
5. Metrics include `pc_success` (success rate), `avg_sum_reward`, and `avg_max_reward`.

The critical piece: your env must return `info["is_success"]` on every `step()` call. This is how the eval loop knows whether a task was completed.

## What your environment must provide

LeRobot does not enforce a strict observation schema. Instead it relies on a set of conventions that all benchmarks follow.

### Env attributes

Your `gym.Env` must set these attributes:

| Attribute | Type | Why |
| -------------------- | ----- | ---------------------------------------------------- |
| `_max_episode_steps` | `int` | `rollout()` uses this to cap episode length |
| `task_description` | `str` | Passed to VLA policies as a language instruction |
| `task` | `str` | Fallback identifier if `task_description` is not set |

### Success reporting

Your `step()` and `reset()` must include `"is_success"` in the `info` dict:

```python
info = {"is_success": True}  # or False
return observation, reward, terminated, truncated, info
```

### Observations

The simplest approach is to map your simulator's outputs to the standard keys that `preprocess_observation()` already understands. Do this inside your `gym.Env` (e.g. in a `_format_raw_obs()` helper):

| Your env should output | LeRobot maps it to | What it is |
| ------------------------- | -------------------------- | ------------------------------------- |
| `"pixels"` (single array) | `observation.image` | Single camera image, HWC uint8 |
| `"pixels"` (dict) | `observation.images.<cam>` | Multiple cameras, each HWC uint8 |
| `"agent_pos"` | `observation.state` | Proprioceptive state vector |
| `"environment_state"` | `observation.env_state` | Full environment state (e.g. PushT) |
| `"robot_state"` | `observation.robot_state` | Nested robot state dict (e.g. LIBERO) |

If your simulator uses different key names, you have two options:

1. **Recommended:** Rename them to the standard keys inside your `gym.Env` wrapper (see the sketch after this list).
2. **Alternative:** Write an env processor to transform observations after `preprocess_observation()` runs (see step 4 below).
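A sketch of option 1, assuming a simulator that returns `rgb` and `qpos` keys (hypothetical names — substitute whatever your simulator actually emits):

```python
import numpy as np

def _format_raw_obs(self, raw_obs: dict) -> dict:
    """Rename simulator-specific keys to the standard LeRobot keys."""
    return {
        "pixels": raw_obs["rgb"].astype(np.uint8),        # HWC uint8 image
        "agent_pos": raw_obs["qpos"].astype(np.float32),  # proprioceptive state
    }
```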
### Actions

Actions are continuous numpy arrays in a `gym.spaces.Box`. The dimensionality depends on your benchmark (7 for LIBERO, 4 for Meta-World, etc.). Policies adapt to different action dimensions through their `input_features` / `output_features` config.

### Feature declaration

Each `EnvConfig` subclass declares two dicts that tell the policy what to expect (see the sketch after this list):

- `features` — maps feature names to `PolicyFeature(type, shape)` (e.g. action dim, image shape).
- `features_map` — maps raw observation keys to LeRobot convention keys (e.g. `"agent_pos"` to `"observation.state"`).
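For a concrete feel, a sketch of these two dicts for a hypothetical single-camera benchmark with a 7-dim action space. The import path and enum members are assumptions to verify against the current source; the codebase typically uses named constants (`ACTION`, `OBS_STATE`, `OBS_IMAGE`) instead of the literal strings shown here:

```python
# Sketch — assumes this import path; check the current source tree.
from lerobot.configs.types import FeatureType, PolicyFeature

features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(8,)),
    "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
features_map = {
    "action": "action",
    "agent_pos": "observation.state",
    "pixels": "observation.image",
}
```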
|
||||
## Step by step
|
||||
|
||||
<Tip>
|
||||
At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
|
||||
subclass** with a `create_envs()` override. Everything else is optional or
|
||||
documentation. No changes to `factory.py` are needed.
|
||||
</Tip>
|
||||
|
||||
### Checklist
|
||||
|
||||
| File | Required | Why |
|
||||
| ---------------------------------------- | -------- | ------------------------------------------------------------ |
|
||||
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
|
||||
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
|
||||
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
|
||||
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
|
||||
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
|
||||
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
|
||||
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
|
||||
|
||||
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
|
||||
|
||||
Create a `gym.Env` subclass that wraps the third-party simulator:
|
||||
|
||||
```python
|
||||
class MyBenchmarkEnv(gym.Env):
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": <fps>}
|
||||
|
||||
def __init__(self, task_suite, task_id, ...):
|
||||
super().__init__()
|
||||
self.task = <task_name_string>
|
||||
self.task_description = <natural_language_instruction>
|
||||
self._max_episode_steps = <max_steps>
|
||||
self.observation_space = spaces.Dict({...})
|
||||
self.action_space = spaces.Box(low=..., high=..., shape=(...,), dtype=np.float32)
|
||||
|
||||
def reset(self, seed=None, **kwargs):
|
||||
... # return (observation, info) — info must contain {"is_success": False}
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
... # return (obs, reward, terminated, truncated, info) — info must contain {"is_success": <bool>}
|
||||
|
||||
def render(self):
|
||||
... # return RGB image as numpy array
|
||||
|
||||
def close(self):
|
||||
...
|
||||
```
|
||||
|
||||
Also provide a factory function that returns the nested dict structure:
|
||||
|
||||
```python
|
||||
def create_mybenchmark_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict | None = None,
|
||||
env_cls: type | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create {suite_name: {task_id: VectorEnv}} for MyBenchmark."""
|
||||
...
|
||||
```
|
||||
|
||||
See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs()` (difficulty-grouped tasks) for reference.
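
A minimal sketch of such a factory body, assuming gymnasium's `SyncVectorEnv` and a hypothetical hard-coded list of task ids (real implementations look these up from the benchmark's task suite):

```python
from typing import Any

import gymnasium as gym

def create_mybenchmark_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict | None = None,
    env_cls: type | None = None,
) -> dict[str, dict[int, Any]]:
    """Create {suite_name: {task_id: VectorEnv}} for MyBenchmark (sketch)."""
    gym_kwargs = gym_kwargs or {}
    env_cls = env_cls or gym.vector.SyncVectorEnv
    task_ids = [0, 1]  # hypothetical — query your task suite instead
    return {
        task: {
            task_id: env_cls(
                [lambda tid=task_id: MyBenchmarkEnv(task_suite=task, task_id=tid, **gym_kwargs)
                 for _ in range(n_envs)]
            )
            for task_id in task_ids
        }
    }
```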

### 2. The config (`src/lerobot/envs/configs.py`)

Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:

- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.

```python
@EnvConfig.register_subclass("<benchmark_name>")
@dataclass
class MyBenchmarkEnvConfig(EnvConfig):
    task: str = "<default_task>"
    fps: int = <fps>
    obs_type: str = "pixels_agent_pos"

    features: dict[str, PolicyFeature] = field(default_factory=lambda: {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(<action_dim>,)),
    })
    features_map: dict[str, str] = field(default_factory=lambda: {
        ACTION: ACTION,
        "agent_pos": OBS_STATE,
        "pixels": OBS_IMAGE,
    })

    def __post_init__(self):
        ...  # populate features based on obs_type

    @property
    def gym_kwargs(self) -> dict:
        return {"obs_type": self.obs_type, "render_mode": self.render_mode}

    def create_envs(self, n_envs: int, use_async_envs: bool = False):
        """Override for multi-task benchmarks or custom env creation."""
        from lerobot.envs.<benchmark> import create_<benchmark>_envs
        return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)

    def get_env_processors(self):
        """Override if your benchmark needs observation/action transforms."""
        from lerobot.processor.pipeline import PolicyProcessorPipeline
        from lerobot.processor.env_processor import MyBenchmarkProcessorStep

        return (
            PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
            PolicyProcessorPipeline(steps=[]),
        )
```

Key points:

- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
- `features` tells the policy what the environment produces.
- `features_map` maps raw observation keys to LeRobot convention keys.
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.

### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)

Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):

```python
@dataclass
@ProcessorStepRegistry.register(name="<benchmark>_processor")
class MyBenchmarkProcessorStep(ObservationProcessorStep):
    def _process_observation(self, observation):
        processed = observation.copy()
        # your transforms here
        return processed

    def transform_features(self, features):
        return features  # update if shapes change

    def observation(self, observation):
        return self._process_observation(observation)
```

See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).

### 4. Dependencies (`pyproject.toml`)

Add a new optional-dependency group:

```toml
mybenchmark = ["my-benchmark-pkg==1.2.3", "lerobot[scipy-dep]"]
```

Pinning rules:

- **Always pin** benchmark packages to exact versions for reproducibility (e.g. `metaworld==3.0.0`).
- **Add platform markers** when needed (e.g. `; sys_platform == 'linux'`).
- **Pin fragile transitive deps** if known (e.g. `gymnasium==1.1.0` for Meta-World).
- **Document constraints** in your benchmark doc page.

Users install with:

```bash
pip install -e ".[mybenchmark]"
```

### 5. Documentation (`docs/source/<benchmark>.mdx`)

Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.

### 6. Table of contents (`docs/source/_toctree.yml`)

Add your benchmark to the "Benchmarks" section:

```yaml
- sections:
    - local: libero
      title: LIBERO
    - local: metaworld
      title: Meta-World
    - local: envhub_isaaclab_arena
      title: NVIDIA IsaacLab Arena Environments
    - local: <your_benchmark>
      title: <Your Benchmark Name>
  title: "Benchmarks"
```

## Verifying your integration

After completing the steps above, confirm that everything works:

1. **Install** — `pip install -e ".[mybenchmark]"` and verify the dependency group installs cleanly.
2. **Smoke test env creation** — call `make_env()` with your config in Python, check that the returned dict has the expected `{suite: {task_id: VectorEnv}}` shape, and that `reset()` returns observations with the right keys (see the sketch after this list).
3. **Run a full eval** — `lerobot-eval --env.type=<name> --env.task=<task> --eval.n_episodes=1 --eval.batch_size=1 --policy.path=<any_compatible_policy>` to exercise the full pipeline end-to-end.
4. **Check success detection** — verify that `info["is_success"]` flips to `True` when the task is actually completed. This is what the eval loop uses to compute success rates.
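
A minimal smoke test for step 2, using only the `create_envs()` method documented above (the config class name is the hypothetical one from step 2):

```python
from lerobot.envs.configs import MyBenchmarkEnvConfig  # your registered config

cfg = MyBenchmarkEnvConfig()
envs = cfg.create_envs(n_envs=1)  # expected shape: {suite: {task_id: VectorEnv}}

for suite, tasks in envs.items():
    for task_id, vec_env in tasks.items():
        obs, info = vec_env.reset(seed=0)
        print(suite, task_id, list(obs.keys()))  # expect "pixels", "agent_pos", ...
        vec_env.close()
```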

## Writing a benchmark doc page

Each benchmark `.mdx` page should include:

- **Title and description** — 1-2 paragraphs on what the benchmark tests and why it matters.
- **Links** — paper, GitHub repo, project website (if available).
- **Overview image or GIF.**
- **Available tasks** — table of task suites with counts and brief descriptions.
- **Installation** — `pip install -e ".[<benchmark>]"` plus any extra steps (env vars, system packages).
- **Evaluation** — recommended `lerobot-eval` command with `n_episodes` and `batch_size` for reproducible results. Include single-task and multi-task examples if applicable.
- **Policy inputs and outputs** — observation keys with shapes, action space description.
- **Recommended evaluation episodes** — how many episodes per task is standard.
- **Training** — example `lerobot-train` command.
- **Reproducing published results** — link to pretrained model, eval command, results table (if available).

See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for complete examples.

@@ -90,11 +90,17 @@ The same policy can work with different environment processors, and the same env

```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
    env_cfg=libero_cfg,
    policy_cfg=smolvla_cfg,
)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)

# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(
    env_cfg=libero_cfg,
    policy_cfg=act_cfg,
)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```

@@ -151,7 +157,7 @@ observation = {

### Factory Function

The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:

```python
from lerobot.envs.factory import make_env_pre_post_processors
@@ -159,47 +165,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv

# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)

# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
```

### How It Works

Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
processor pipelines. The base class returns identity (no-op) processors by default.

```python
# In your EnvConfig subclass:
def get_env_processors(self):
    from lerobot.processor.pipeline import PolicyProcessorPipeline

    return (
        PolicyProcessorPipeline(steps=[MyProcessorStep()]),
        PolicyProcessorPipeline(steps=[]),
    )
```

The factory function `make_env_pre_post_processors` simply delegates to this method,
with a special case for `XVLAConfig` policies, which override the env processors entirely.
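
In other words, the factory body reduces to something like this sketch (the two-argument signature shown earlier is assumed; the `XVLAConfig` special case mentioned above is omitted):

```python
def make_env_pre_post_processors(env_cfg, policy_cfg=None):
    # Delegate processor construction to the environment config itself.
    # policy_cfg is consulted only for the omitted XVLAConfig special case.
    return env_cfg.get_env_processors()
```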

### Integration in Evaluation

In `lerobot_eval.py`, the environment processors are created once and used throughout:

@@ -219,7 +209,10 @@ def eval_main(cfg: EvalPipelineConfig):
    )

    # Create environment processors (NEW!)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(
        env_cfg=cfg.env,
        policy_cfg=cfg.policy,
    )

    # Run evaluation with both processor types
    eval_policy_all(

@@ -323,21 +316,22 @@ class MyEnvProcessorStep(ObservationProcessorStep):
        return processed
```

### 2. Update Your `EnvConfig` Subclass

```python
# In src/lerobot/envs/configs.py
@EnvConfig.register_subclass("myenv")
@dataclass
class MyEnvConfig(EnvConfig):
    # ... task/features/gym kwargs ...

    def get_env_processors(self):
        from lerobot.processor.pipeline import PolicyProcessorPipeline

        return (
            PolicyProcessorPipeline(steps=[MyEnvProcessorStep()]),
            PolicyProcessorPipeline(steps=[]),
        )
```

### 3. Use in Evaluation

@@ -131,4 +131,4 @@ lerobot-record \

## License

This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow the **Apache 2.0 License**.

@@ -0,0 +1,269 @@

# Human-In-the-Loop Data Collection

Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data (recovery movements and corrections) is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.

---

## Why Human-In-the-Loop?

Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:

- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset

This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.

---

## How It Works

During a HIL session, the human operator follows this loop within each episode:

1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent; the robot holds its position
3. **Take control** and teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy**; the policy resumes autonomous execution
5. Repeat steps 2–4 as many times as needed during the episode
6. **End the episode** when the task is complete, save, and move on to the next rollout

Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset is required just because an intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.

This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Each round targets the current policy's failure modes.

```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos)                                            │
│        ↓                                                                │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1   │
│        ↓                                                                │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2       │
│        ↓                                                                │
│ ... (repeat until satisfactory performance)                             │
└─────────────────────────────────────────────────────────────────────────┘
```

---

## Hardware Requirements

### Teleoperator Requirements

The HIL scripts in `examples/hil` require **teleoperators with active motors** that can:

- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)

**Compatible teleoperators in the current `examples/hil` scripts:**

- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm

> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.

---

## Script

A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:

| Mode                     | Flag                 | Models                |
| ------------------------ | -------------------- | --------------------- |
| Standard (default)       | _(no flag needed)_   | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA   |

---

## Step-by-Step Guide

### Step 1: Pre-train a Base Policy

First, train a policy on your demonstration dataset:

```bash
python src/lerobot/scripts/lerobot_train.py \
  --dataset.repo_id=your-username/demo-dataset \
  --policy.type=pi0 \
  --output_dir=outputs/pretrain \
  --batch_size=32 \
  --steps=50000
```

### Step 2: Collect HIL Data

**Standard inference (ACT, Diffusion Policy):**

```bash
python examples/hil/hil_data_collection.py \
  --robot.type=bi_openarm_follower \
  --robot.left_arm_config.port=can1 \
  --robot.left_arm_config.side=left \
  --robot.right_arm_config.port=can0 \
  --robot.right_arm_config.side=right \
  --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
  --teleop.type=openarm_mini \
  --teleop.port_left=/dev/ttyACM0 \
  --teleop.port_right=/dev/ttyACM1 \
  --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
  --dataset.repo_id=your-username/hil-dataset \
  --dataset.single_task="Fold the T-shirt properly" \
  --dataset.fps=30 \
  --dataset.episode_time_s=1000 \
  --dataset.num_episodes=50 \
  --interpolation_multiplier=2
```

**With RTC for large models (Pi0, Pi0.5, SmolVLA):**

For models with high inference latency, enable RTC for smooth execution:

```bash
python examples/hil/hil_data_collection.py \
  --rtc.enabled=true \
  --rtc.execution_horizon=20 \
  --rtc.max_guidance_weight=5.0 \
  --rtc.prefix_attention_schedule=LINEAR \
  --robot.type=bi_openarm_follower \
  --robot.left_arm_config.port=can1 \
  --robot.left_arm_config.side=left \
  --robot.right_arm_config.port=can0 \
  --robot.right_arm_config.side=right \
  --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
  --teleop.type=openarm_mini \
  --teleop.port_left=/dev/ttyACM0 \
  --teleop.port_right=/dev/ttyACM1 \
  --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
  --dataset.repo_id=your-username/hil-rtc-dataset \
  --dataset.single_task="Fold the T-shirt properly" \
  --dataset.fps=30 \
  --dataset.episode_time_s=1000 \
  --dataset.num_episodes=50 \
  --interpolation_multiplier=3
```

**Controls (Conceptual):**

The interaction model is:

- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed

Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.

**The HIL Protocol:**

1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
   - Policy stops
   - Teleoperator moves to match the robot position (torque enabled)
   - No frames are recorded during the pause
3. Trigger the **takeover input** to take control
   - Teleoperator torque is disabled, free to move
   - **Recovery**: Teleoperate the robot back to a good state
   - **Correction**: Correct the behavior
   - All movements are recorded
4. Trigger the **return-to-policy input**
   - Policy resumes autonomous execution from the current state
   - You can intervene again at any time (repeat steps 2–4)
5. End and save the episode when the task is complete (or the episode time limit is reached)
6. **Reset**: Teleop moves to the robot position; you can move the robot to the starting position
7. Start the next episode

**Foot Pedal Setup (Linux):**

If using a USB foot pedal (PCsensor FootSwitch), ensure access:

```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```

### Step 3: Fine-tune the Policy

Fine-tune on the **combined** dataset (`demo-dataset` + `hil-dataset` merged together):

```bash
python src/lerobot/scripts/lerobot_train.py \
  --dataset.repo_id=your-username/hil-dataset \
  --policy.type=pi0 \
  --policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
  --output_dir=outputs/hil_finetune \
  --steps=20000
```

Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.

---

## Tips for Effective HIL Collection

### When to Intervene

Intervene when you see:

- The robot about to make an irreversible mistake
- The robot hesitating or showing uncertain behavior
- The robot deviating from the expected trajectory

### Recovery: Teleoperating Back to a Good State

During recovery, teleoperate the robot back to a state where:

- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data

### Quality of Corrections

During correction:

- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements

---

## Related Work

This HIL data collection approach builds on ideas from interactive imitation learning:

- **DAgger** (Ross et al., 2011) introduced the core idea: instead of only training on expert demonstrations, query the expert for corrections on states the _learner_ visits. This breaks the compounding-error cycle of standard behavioral cloning by iteratively collecting on-policy data.

- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.

- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.

- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.

```bibtex
@article{ross2011dagger,
  title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
  author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
  journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
  year={2011}
}

@article{kelly2019hgdagger,
  title={HG-DAgger: Interactive Imitation Learning with Human Experts},
  author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
  journal={arXiv preprint arXiv:1810.02890},
  year={2019}
}

@article{hu2025rac,
  title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
  author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
  journal={arXiv preprint arXiv:2509.07953},
  year={2025}
}

@article{pi2025recap,
  title={π0.6: a VLA That Learns From Experience},
  author={Physical Intelligence},
  year={2025}
}
```

@@ -1,6 +1,6 @@

# Installation

This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >= 3.12 and support for PyTorch >= 2.10, then skip ahead to [Environment Setup](#step-2-environment-setup).

## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)

@@ -20,7 +20,7 @@ Create a virtual environment with Python 3.12:

conda create -y -n lerobot python=3.12
```
</hfoption>
<hfoption id="uv (PyTorch >= 2.10 only)">
```bash
uv python install 3.12
uv venv --python 3.12
```
@@ -32,48 +32,87 @@ uv venv --python 3.12

Then activate your virtual environment; you have to do this each time you open a shell to use lerobot:

<!-- prettier-ignore-start -->

<hfoptions id="activate_venv">
<hfoption id="conda">

```bash
conda activate lerobot
```

> [!NOTE]
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to also install `evdev`:
>
> ```bash
> conda install evdev -c conda-forge
> ```

</hfoption>
<hfoption id="uv (PyTorch >= 2.10 only)">

```bash
# Linux/macOS
source .venv/bin/activate
# Windows PowerShell
.venv\Scripts\activate
```

> [!NOTE]
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to also install `evdev`:
>
> ```bash
> sudo apt install libevdev-dev
> uv pip install evdev
> ```

</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->

### Install `ffmpeg` (for video decoding)

LeRobot uses [TorchCodec](https://github.com/meta-pytorch/torchcodec) for video decoding by default, which requires `ffmpeg`.

> [!NOTE]
> **Platform support:** TorchCodec is **not available** on macOS Intel (x86_64), Linux ARM (aarch64, arm64, armv7l), or Windows with PyTorch < 2.8. On these platforms, LeRobot automatically falls back to `pyav` — so you do not need to install `ffmpeg` and can skip to Step 3.

If your platform supports TorchCodec, install `ffmpeg` using one of the methods below:

<!-- prettier-ignore-start -->

<hfoptions id="install_ffmpeg">
<hfoption id="conda (any PyTorch version)">

Install `ffmpeg` in your conda environment. This works with **all PyTorch versions** and is **required for PyTorch < 2.10**:

```bash
conda install ffmpeg -c conda-forge
```

> [!TIP]
> This usually installs `ffmpeg 8.X` with the `libsvtav1` encoder. If you run into issues (e.g. `libsvtav1` missing — check with `ffmpeg -encoders` — or a version mismatch with `torchcodec`), you can explicitly install `ffmpeg 7.1.1` using:
>
> ```bash
> conda install ffmpeg=7.1.1 -c conda-forge
> ```

</hfoption>
<hfoption id="uv (PyTorch >= 2.10 only)">

Starting with **PyTorch >= 2.10** (TorchCodec ≥ 0.10), TorchCodec can dynamically link to a system-wide `ffmpeg` installation. This is useful when using `uv` or other non-`conda` environment managers:

```bash
# Ubuntu/Debian
sudo apt install ffmpeg

# macOS (Apple Silicon)
brew install ffmpeg
```

> [!IMPORTANT]
> If you are using `uv`, you have to install `ffmpeg` system-wide (outside of the virtual environment), relying on `torchcodec`'s ability to dynamically link to the system `ffmpeg`.
> System-wide `ffmpeg` is **only supported with PyTorch >= 2.10** (TorchCodec ≥ 0.10). For older PyTorch versions, you **must** use `conda install ffmpeg -c conda-forge` instead.

</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->

## Step 3: Install LeRobot 🤗

@@ -1,36 +1,61 @@

# LIBERO

LIBERO is a benchmark designed to study **lifelong robot learning** — the idea that robots need to keep learning and adapting with their users over time, not just be pretrained once. It provides a set of standardized manipulation tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other's work.

- Paper: [Benchmarking Knowledge Transfer for Lifelong Robot Learning](https://arxiv.org/abs/2306.03310)
- GitHub: [Lifelong-Robot-Learning/LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO)
- Project website: [libero-project.github.io](https://libero-project.github.io)

![libero](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/libero.png)

## Available tasks

LIBERO includes **five task suites** covering **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios:

| Suite          | CLI name         | Tasks | Description                                        |
| -------------- | ---------------- | ----- | -------------------------------------------------- |
| LIBERO-Spatial | `libero_spatial` | 10    | Tasks requiring reasoning about spatial relations  |
| LIBERO-Object  | `libero_object`  | 10    | Tasks centered on manipulating different objects   |
| LIBERO-Goal    | `libero_goal`    | 10    | Goal-conditioned tasks with changing targets       |
| LIBERO-90      | `libero_90`      | 90    | Short-horizon tasks from the LIBERO-100 collection |
| LIBERO-Long    | `libero_10`      | 10    | Long-horizon tasks from the LIBERO-100 collection  |

## Installation

After following the LeRobot installation instructions:

```bash
pip install -e ".[libero]"
```

<Tip>
LIBERO requires Linux (`sys_platform == 'linux'`). LeRobot uses MuJoCo for simulation — set the rendering backend before training or evaluation:

```bash
export MUJOCO_GL=egl  # for headless servers (HPC, cloud)
```

</Tip>

## Evaluation

### Default evaluation (recommended)

Evaluate across the four standard suites (10 episodes per task):

```bash
lerobot-eval \
  --policy.path="your-policy-id" \
  --env.type=libero \
  --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
  --eval.batch_size=1 \
  --eval.n_episodes=10 \
  --env.max_parallel_tasks=1
```

### Single-suite evaluation

Evaluate on one LIBERO suite:

```bash
lerobot-eval \
@@ -42,15 +67,13 @@ lerobot-eval \
```

- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` restricts to specific task indices (`[0]`, `[1,2,3]`, etc.). Omit to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run per task.

### Multi-suite evaluation

Benchmark a policy across multiple suites at once by passing a comma-separated list:

```bash
lerobot-eval \
@@ -61,50 +84,49 @@ lerobot-eval \
  --eval.n_episodes=2
```

### Control mode

LIBERO supports two control modes — `relative` (default) and `absolute`. Different VLA checkpoints are trained with different action parameterizations, so make sure the mode matches your policy:

```bash
--env.control_mode=relative  # or "absolute"
```

### Policy inputs and outputs

**Observations:**

- `observation.state` — 8-dim proprioceptive features (eef position, axis-angle orientation, gripper qpos)
- `observation.images.image` — main camera view (`agentview_image`), HWC uint8
- `observation.images.image2` — wrist camera view (`robot0_eye_in_hand_image`), HWC uint8

<Tip warning={true}>
LeRobot enforces the `.images.*` prefix for visual features. Ensure your
policy config `input_features` use the same naming keys, and that your dataset
metadata keys follow this convention. If your data contains different keys,
you must rename the observations to match what the policy expects, since
naming keys are encoded inside the normalization statistics layer.
</Tip>
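
A two-line sketch of such a renaming, assuming a hypothetical non-conforming key `observation.camera0` (substitute your actual key):

```python
# Rename a non-conforming key to the `.images.*` convention the policy expects.
obs["observation.images.image"] = obs.pop("observation.camera0")
```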

**Actions:**

- Continuous control in `Box(-1, 1, shape=(7,))` — 6D end-effector delta + 1D gripper

### Recommended evaluation episodes

For reproducible benchmarking, use **10 episodes per task** across all four standard suites (Spatial, Object, Goal, Long). This gives 400 total episodes and matches the protocol used for published results.

## Training

### Dataset

We provide a preprocessed LIBERO dataset fully compatible with LeRobot:

- [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)

For reference, the original dataset published by Physical Intelligence:

- [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)

### Example training command

@@ -121,52 +143,39 @@ lerobot-train \
  --batch_size=4 \
  --eval.batch_size=1 \
  --eval.n_episodes=1 \
  --eval_freq=1000
```

## Reproducing published results

We reproduce the results of Pi0.5 on the LIBERO benchmark. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with a batch size of 256 on 8 H100 GPUs, using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).

The finetuned model: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)

### Evaluation command

```bash
lerobot-eval \
  --output_dir=./eval_logs/ \
  --env.type=libero \
  --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
  --eval.batch_size=1 \
  --eval.n_episodes=10 \
  --policy.path=pi05_libero_finetuned \
  --policy.n_action_steps=10 \
  --env.max_parallel_tasks=1
```

We set `n_action_steps=10`, matching the original OpenPI implementation.

### Results

| Model               | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average  |
| ------------------- | -------------- | ------------- | ----------- | --------- | -------- |
| **Pi0.5 (LeRobot)** | 97.0           | 99.0          | 98.0        | 96.0      | **97.5** |

These results are consistent with the [original results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:

| Model              | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average   |
| ------------------ | -------------- | ------------- | ----------- | --------- | --------- |
| **Pi0.5 (OpenPI)** | 98.8           | 98.2          | 98.0        | 92.4      | **96.85** |

@@ -1,32 +1,111 @@

# Meta-World

Meta-World is an open-source simulation benchmark for **multi-task and meta reinforcement learning** in continuous-control robotic manipulation. It bundles 50 diverse manipulation tasks using everyday objects and a common tabletop Sawyer arm, providing a standardized playground to test whether algorithms can learn many different tasks and generalize quickly to new ones.

- Paper: [Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning](https://arxiv.org/abs/1910.10897)
- GitHub: [Farama-Foundation/Metaworld](https://github.com/Farama-Foundation/Metaworld)
- Project website: [metaworld.farama.org](https://metaworld.farama.org)

![metaworld](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/metaworld.gif)

## Available tasks

Meta-World provides 50 tasks organized into difficulty groups. In LeRobot, you can evaluate on individual tasks, difficulty groups, or the full MT50 suite:

| Group      | CLI name             | Tasks | Description                                             |
| ---------- | -------------------- | ----- | ------------------------------------------------------- |
| Easy       | `easy`               | 28    | Tasks with simple dynamics and single-step goals        |
| Medium     | `medium`             | 11    | Tasks requiring multi-step reasoning                    |
| Hard       | `hard`               | 6     | Tasks with complex contacts and precise manipulation    |
| Very Hard  | `very_hard`          | 5     | The most challenging tasks in the suite                 |
| MT50 (all) | Comma-separated list | 50    | All 50 tasks — the most challenging multi-task setting  |

You can also pass individual task names directly (e.g., `assembly-v3`, `dial-turn-v3`).

We provide a LeRobot-ready dataset for Meta-World MT50 on the HF Hub: [lerobot/metaworld_mt50](https://huggingface.co/datasets/lerobot/metaworld_mt50). This dataset is formatted for the MT50 evaluation that uses all 50 tasks with fixed object/goal positions and one-hot task vectors for consistency.

## Installation

After following the LeRobot installation instructions:

```bash
pip install -e ".[metaworld]"
```

<Tip warning={true}>
If you encounter an `AssertionError: ['human', 'rgb_array', 'depth_array']` when running Meta-World environments, this is a mismatch between Meta-World and your Gymnasium version. Fix it with:

```bash
pip install "gymnasium==1.1.0"
```

</Tip>

## Evaluation

### Default evaluation (recommended)

Evaluate on the medium difficulty split (a good balance of coverage and compute):

```bash
lerobot-eval \
  --policy.path="your-policy-id" \
  --env.type=metaworld \
  --env.task=medium \
  --eval.batch_size=1 \
  --eval.n_episodes=10
```

### Single-task evaluation

Evaluate on a specific task:

```bash
lerobot-eval \
  --policy.path="your-policy-id" \
  --env.type=metaworld \
  --env.task=assembly-v3 \
  --eval.batch_size=1 \
  --eval.n_episodes=10
```

### Multi-task evaluation

Evaluate across multiple tasks or difficulty groups:

```bash
lerobot-eval \
  --policy.path="your-policy-id" \
  --env.type=metaworld \
  --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
  --eval.batch_size=1 \
  --eval.n_episodes=10
```

- `--env.task` accepts explicit task lists (comma-separated) or difficulty groups (e.g., `easy`, `medium`, `hard`, `very_hard`).
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run per task.

### Policy inputs and outputs

**Observations:**

- `observation.image` — single camera view (`corner2`), 480x480 HWC uint8
- `observation.state` — 4-dim proprioceptive state (end-effector position + gripper)

**Actions:**

- Continuous control in `Box(-1, 1, shape=(4,))` — 3D end-effector delta + 1D gripper

### Recommended evaluation episodes

For reproducible benchmarking, use **10 episodes per task**. For the full MT50 suite this gives 500 total episodes. If you care about generalization, run on the full MT50 — it is intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.

## Training

### Example training command

Train a SmolVLA policy on a subset of Meta-World tasks:

```bash
lerobot-train \
@@ -44,37 +123,8 @@ lerobot-train \
  --eval_freq=1000
```

## Practical tips

- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context — see the sketch after this list.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
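
A minimal sketch of the one-hot convention (50 slots for MT50; the task index used here is hypothetical, and how tasks are indexed is up to your dataset):

```python
import numpy as np

num_tasks = 50  # MT50
task_id = 7     # hypothetical index of the current task

task_vec = np.zeros(num_tasks, dtype=np.float32)
task_vec[task_id] = 1.0  # one-hot task conditioning given to the policy
```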

@@ -331,6 +331,54 @@ lerobot-train \
  --wandb.project=multitask_dit
```

## LIBERO Results

```bash
python -m lerobot.scripts.lerobot_train \
  --dataset.repo_id=HuggingFaceVLA/libero \
  --policy.type=multi_task_dit \
  --policy.push_to_hub=false \
  --output_dir="./outputs/multitask_dit_libero" \
  --job_name="multitask-dit-libero" \
  --wandb.enable=true \
  --wandb.project=multitask_dit_libero \
  --dataset.image_transforms.enable=true \
  --dataset.image_transforms.max_num_transforms=4 \
  --dataset.image_transforms.tfs='{"brightness":{"type":"ColorJitter","kwargs":{"brightness":[0.75,1.25]}},"contrast":{"type":"ColorJitter","kwargs":{"contrast":[0.6,1.4]}},"saturation":{"type":"ColorJitter","kwargs":{"saturation":[0.8,1.2]}},"hue":{"type":"ColorJitter","kwargs":{"hue":[-0.05,0.05]}},"sharpness":{"type":"SharpnessJitter","kwargs":{"sharpness":[0.6,1.4]}},"rotation":{"type":"RandomRotation","kwargs":{"degrees":[-5,5]}},"translation":{"type":"RandomAffine","kwargs":{"degrees":0,"translate":[0.1,0.1]}}}' \
  --dataset.video_backend=torchcodec \
  --policy.use_amp=true \
  --policy.horizon=48 \
  --policy.n_obs_steps=2 \
  --policy.use_rope=true \
  --policy.use_positional_encoding=false \
  --policy.hidden_dim=768 \
  --policy.num_layers=8 \
  --policy.num_heads=12 \
  --policy.dropout=0.1 \
  --policy.timestep_embed_dim=256 \
  --policy.objective=diffusion \
  --policy.optimizer_lr=3e-4 \
  --policy.optimizer_weight_decay=0 \
  --policy.scheduler_warmup_steps=0 \
  --policy.vision_encoder_name=openai/clip-vit-base-patch16 \
  --policy.image_resize_shape=[256,256] \
  --policy.image_crop_is_random=true \
  --policy.text_encoder_name=openai/clip-vit-base-patch16 \
  --policy.vision_encoder_lr_multiplier=0.1 \
  --policy.device=cuda \
  --num_workers=8 \
  --save_freq=4000 \
  --log_freq=100 \
  --steps=100000 \
  --batch_size=320
```

Results:

| LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------------- | ------------- | ----------- | --------- | ------- |
| 87.0           | 98.2          | 93.8        | 83.2      | 90.6    |

## References

For more details on the technical implementation and architecture, see:
@@ -0,0 +1,91 @@

# π₀.₅ (pi05)

This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.

---

## Model Overview

| Feature              | π₀                                                      | π₀.₅                                      |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning    | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS               | Not used                                                | Used in action expert                     |
| Tokenizer Length     | 48 tokens                                               | 200 tokens                                |
| Discrete State Input | False (uses `state_proj` layer)                         | True                                      |
| Parameter Count      | Higher (includes state embedding)                       | Lower (no state embedding)                |

---

## Relative Actions

π₀.₅ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., the gripper) are kept absolute.
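
A minimal sketch of the two transforms, assuming plain NumPy arrays and substring matching for the excluded joints (this illustrates the math above; it is not the library's internal implementation):

```python
import numpy as np

def to_relative(action, state, joint_names, exclude=("gripper",)):
    """Preprocessing: relative = action - state, except for excluded joints."""
    keep_relative = np.array([not any(e in n for e in exclude) for n in joint_names])
    return np.where(keep_relative, action - state, action)

def to_absolute(relative, state, joint_names, exclude=("gripper",)):
    """Postprocessing: absolute = relative + state, inverting the transform above."""
    keep_relative = np.array([not any(e in n for e in exclude) for n in joint_names])
    return np.where(keep_relative, relative + state, relative)
```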

### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi05 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training

---

## Citation

If you use this work, please cite both **OpenPI** and the π₀.₅ paper:

```bibtex
@misc{openpi2024,
  author       = {Physical Intelligence Lab},
  title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
  year         = {2024},
  publisher    = {GitHub},
  howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
  license      = {Apache-2.0}
}

@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
  title         = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
  author        = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
  year          = {2025},
  eprint        = {2504.16054},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2504.16054},
}
```

---

## License

This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
@@ -0,0 +1,108 @@

# π₀ (pi0)

This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model for general robot control**.

---

## Model Overview

| Feature              | π₀                                                      | π₀.₅                                      |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning    | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS               | Not used                                                | Used in action expert                     |
| Tokenizer Length     | 48 tokens                                               | 200 tokens                                |
| Discrete State Input | False (uses `state_proj` layer)                         | True                                      |
| Parameter Count      | Higher (includes state embedding)                       | Lower (no state embedding)                |

---

## Relative Actions

π₀ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., the gripper) are kept absolute.

### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi0 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training

### Recomputing stats for an existing dataset

If you want to precompute relative action stats offline, use `recompute_stats` from
`lerobot.datasets.dataset_tools`:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats

dataset = LeRobotDataset("your_org/your_dataset")
dataset = recompute_stats(
    dataset,
    relative_action=True,
    relative_exclude_joints=["gripper"],
)
```

---

## Citation

If you use this work, please cite both **OpenPI** and the π₀ paper:

```bibtex
@misc{openpi2024,
  author       = {Physical Intelligence Lab},
  title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
  year         = {2024},
  publisher    = {GitHub},
  howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
  license      = {Apache-2.0}
}

@misc{black2024pi0visionlanguageactionflowmodel,
  title         = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
  author        = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
  year          = {2024},
  eprint        = {2410.24164},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2410.24164},
}
```

---

## License

This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
@@ -0,0 +1,38 @@

# Real-Time Chunking (RTC)

This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.

**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies, including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
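
As a rough sketch of switching RTC on at inference time (the `--rtc.*` flag names come from the RTC demo configuration in this changeset; the entry-point path and policy id below are placeholders, not a confirmed CLI):

```bash
# Hypothetical entry point; only the --rtc.* flags are taken from this repo.
python examples/rtc_demo.py \
--policy.path=your-policy-id \
--fps=30 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0
```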

---

## Citation

If you use Real-Time Chunking in your work, please cite:

```bibtex
@misc{openpi2024,
  author       = {Physical Intelligence Lab},
  title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
  year         = {2024},
  publisher    = {GitHub},
  howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
  license      = {Apache-2.0}
}

@misc{black2025realtimeexecutionactionchunking,
  title         = {Real-Time Execution of Action Chunking Flow Policies},
  author        = {Kevin Black and Manuel Y. Galliker and Sergey Levine},
  year          = {2025},
  eprint        = {2506.07339},
  archivePrefix = {arXiv},
  primaryClass  = {cs.RO},
  url           = {https://arxiv.org/abs/2506.07339},
}
```

---

## License

This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
@@ -0,0 +1,14 @@

## Paper

https://arxiv.org/abs/2509.25358

## Citation

```bibtex
@article{chen2025sarm,
  title   = {SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
  author  = {Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
  journal = {arXiv preprint arXiv:2509.25358},
  year    = {2025}
}
```
@@ -0,0 +1,680 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.

Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.

Usage:
    python examples/dataset/create_progress_videos.py \
        --repo-id lerobot-data-collection/level2_final_quality3 \
        --episode 1100

    python examples/dataset/create_progress_videos.py \
        --repo-id lerobot-data-collection/level2_final_quality3 \
        --episode 1100 \
        --camera-key observation.images.top \
        --output-dir ./my_videos \
        --gif
"""

from __future__ import annotations

import argparse
import json
import logging
import subprocess
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download

GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6
REF_ALPHA = 0.45
FILL_ALPHA = 0.55
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55


def download_episode_metadata(repo_id: str, episode: int) -> Path:
    """Download only the metadata and sarm_progress files for a dataset.

    Args:
        repo_id: HuggingFace dataset repository ID.
        episode: Episode index (used for logging only; all meta is fetched).

    Returns:
        Local cache path for the downloaded snapshot.
    """
    logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
    local_path = Path(
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            allow_patterns=["meta/**", "sarm_progress.parquet"],
            ignore_patterns=["*.mp4"],
        )
    )
    return local_path


def load_episode_meta(local_path: Path, episode: int, camera_key: str | None) -> dict:
    """Read info.json and episode parquet to resolve fps, video path, and timestamps.

    Args:
        local_path: Local cache directory containing meta/.
        episode: Episode index to look up.
        camera_key: Camera observation key (e.g. "observation.images.base").
            If None, the first available video key is used.

    Returns:
        Dict with keys: fps, camera, video_rel, chunk_index, file_index,
        from_ts, to_ts, task_name.
    """
    info = json.loads((local_path / "meta" / "info.json").read_text())
    fps = info["fps"]
    features = info["features"]

    video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
    if not video_keys:
        raise RuntimeError("No video keys found in dataset features")

    if camera_key is not None:
        if camera_key not in video_keys:
            raise RuntimeError(f"camera_key='{camera_key}' not found. Available: {video_keys}")
        selected_camera = camera_key
    else:
        selected_camera = video_keys[0]
    logging.info("  fps=%d camera='%s' all_cams=%s", fps, selected_camera, video_keys)

    episode_rows = []
    for parquet_file in sorted((local_path / "meta" / "episodes").glob("**/*.parquet")):
        episode_rows.append(pd.read_parquet(parquet_file))
    episode_df = pd.concat(episode_rows, ignore_index=True)
    row = episode_df[episode_df["episode_index"] == episode]
    if row.empty:
        raise RuntimeError(f"Episode {episode} not found in episode metadata")
    row = row.iloc[0]

    chunk_col = f"videos/{selected_camera}/chunk_index"
    file_col = f"videos/{selected_camera}/file_index"
    ts_from_col = f"videos/{selected_camera}/from_timestamp"
    ts_to_col = f"videos/{selected_camera}/to_timestamp"

    if chunk_col not in row.index:
        chunk_col = f"{selected_camera}/chunk_index"
        file_col = f"{selected_camera}/file_index"
        ts_from_col = f"{selected_camera}/from_timestamp"
        ts_to_col = f"{selected_camera}/to_timestamp"
        if chunk_col not in row.index:
            raise RuntimeError(
                f"Cannot find video metadata columns for {selected_camera}.\nAvailable: {list(row.index)}"
            )

    chunk_index = int(row[chunk_col])
    file_index = int(row[file_col])
    from_timestamp = float(row[ts_from_col])
    to_timestamp = float(row[ts_to_col])

    video_template = info.get(
        "video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
    )
    video_rel = video_template.format(
        video_key=selected_camera,
        chunk_index=chunk_index,
        file_index=file_index,
    )

    task_name = _resolve_task_name(row, local_path)

    return {
        "fps": fps,
        "camera": selected_camera,
        "video_rel": video_rel,
        "chunk_index": chunk_index,
        "file_index": file_index,
        "from_ts": from_timestamp,
        "to_ts": to_timestamp,
        "task_name": task_name,
    }


def _resolve_task_name(row: pd.Series, local_path: Path) -> str:
    """Best-effort extraction of the task name for an episode row.

    Args:
        row: Single-episode row from the episodes parquet.
        local_path: Dataset cache root.

    Returns:
        Task name string, or empty string if unavailable.
    """
    try:
        if "tasks" in row.index and row["tasks"] is not None:
            tasks_val = row["tasks"]
            if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
                return str(tasks_val[0])
            return str(tasks_val).strip("[]'")

        tasks_parquet = local_path / "meta" / "tasks.parquet"
        if tasks_parquet.exists():
            tasks_df = pd.read_parquet(tasks_parquet)
            task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
            match = tasks_df[tasks_df["task_index"] == task_idx]
            if not match.empty:
                return str(match.index[0])
    except Exception as exc:
        logging.warning("Could not load task name: %s", exc)
    return ""


def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
    """Download the specific video file if not already cached.

    Args:
        repo_id: HuggingFace dataset repository ID.
        local_path: Local cache directory.
        video_rel: Relative path to the video file within the dataset.

    Returns:
        Absolute path to the downloaded video file.
    """
    video_path = local_path / video_rel
    if video_path.exists():
        logging.info("  Video already cached: %s", video_path)
        return video_path
    logging.info("[2/4] Downloading video file %s ...", video_rel)
    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=str(local_path),
        allow_patterns=[video_rel],
    )
    if not video_path.exists():
        raise RuntimeError(f"Video not found after download: {video_path}")
    return video_path


def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
    """Load sarm_progress values for an episode.

    Args:
        local_path: Dataset cache root.
        episode: Episode index.

    Returns:
        Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
    """
    parquet_path = local_path / "sarm_progress.parquet"
    if not parquet_path.exists():
        logging.warning("sarm_progress.parquet not found")
        return None
    df = pd.read_parquet(parquet_path)
    logging.info("  sarm_progress.parquet columns: %s", list(df.columns))
    episode_df = df[df["episode_index"] == episode].copy()
    if episode_df.empty:
        logging.warning("No sarm_progress rows for episode %d", episode)
        return None
    episode_df = episode_df.sort_values("frame_index")

    if "progress_dense" in episode_df.columns and episode_df["progress_dense"].notna().any():
        progress_column = "progress_dense"
    elif "progress_sparse" in episode_df.columns:
        progress_column = "progress_sparse"
    else:
        progress_columns = [c for c in episode_df.columns if "progress" in c.lower()]
        if not progress_columns:
            return None
        progress_column = progress_columns[0]

    logging.info("  Using progress column: '%s'", progress_column)
    return episode_df[["frame_index", progress_column]].rename(columns={progress_column: "progress"}).values


def _precompute_pixel_coords(
    progress_data: np.ndarray,
    num_frames: int,
    frame_width: int,
    frame_height: int,
) -> np.ndarray:
    """Map progress samples to pixel coordinates for overlay drawing.

    Args:
        progress_data: (N, 2) array of (frame_index, progress).
        num_frames: Total number of video frames.
        frame_width: Video width in pixels.
        frame_height: Video height in pixels.

    Returns:
        (N, 2) array of (x, y) pixel coordinates.
    """
    frame_indices = progress_data[:, 0].astype(float)
    progress_values = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)

    y_top = int(frame_height * GRAPH_Y_TOP_FRAC)
    y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
    graph_height = y_bot - y_top

    x_coords = (frame_indices / (num_frames - 1) * (frame_width - 1)).astype(int)
    y_coords = (y_bot - progress_values * graph_height).astype(int)

    return np.stack([x_coords, y_coords], axis=1)


def _progress_color(normalized_position: float) -> tuple[int, int, int]:
    """Interpolate BGR color from red to green based on position in [0, 1].

    Args:
        normalized_position: Value in [0, 1] indicating how far along the episode.

    Returns:
        BGR color tuple.
    """
    red = int(255 * (1.0 - normalized_position))
    green = int(255 * normalized_position)
    return (0, green, red)


def _prerender_fill_polygon(
    pixel_coords: np.ndarray,
    frame_width: int,
    frame_height: int,
) -> np.ndarray:
    """Pre-render the grey fill polygon under the progress curve as a BGRA image.

    Args:
        pixel_coords: (N, 2) array of (x, y) pixel coordinates.
        frame_width: Video width in pixels.
        frame_height: Video height in pixels.

    Returns:
        BGRA image array of shape (frame_height, frame_width, 4).
    """
    y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
    fill_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
    polygon = np.concatenate(
        [
            pixel_coords,
            [[pixel_coords[-1][0], y_bot], [pixel_coords[0][0], y_bot]],
        ],
        axis=0,
    ).astype(np.int32)
    cv2.fillPoly(fill_image, [polygon], color=(128, 128, 128, int(255 * FILL_ALPHA)))
    return fill_image


def _alpha_composite_region(base: np.ndarray, overlay_bgra: np.ndarray, x_limit: int) -> None:
    """Blend BGRA overlay onto BGR base in-place, up to x_limit columns.

    Args:
        base: BGR frame to draw on (modified in-place).
        overlay_bgra: BGRA overlay image.
        x_limit: Only blend columns [0, x_limit).
    """
    if x_limit <= 0:
        return
    region_base = base[:, :x_limit]
    region_overlay = overlay_bgra[:, :x_limit]
    alpha = region_overlay[:, :, 3:4].astype(np.float32) / 255.0
    region_base[:] = np.clip(
        region_overlay[:, :, :3].astype(np.float32) * alpha + region_base.astype(np.float32) * (1.0 - alpha),
        0,
        255,
    ).astype(np.uint8)


def _draw_text_outlined(
    frame: np.ndarray,
    text: str,
    position: tuple[int, int],
    font_scale: float,
    thickness: int = 1,
) -> None:
    """Draw white text with a dark outline for readability on any background.

    Args:
        frame: BGR image to draw on (modified in-place).
        text: String to render.
        position: (x, y) bottom-left corner of the text.
        font_scale: OpenCV font scale.
        thickness: Text stroke thickness.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, text, position, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
    cv2.putText(frame, text, position, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)


def composite_progress_video(
    video_path: Path,
    from_timestamp: float,
    to_timestamp: float,
    progress_data: np.ndarray,
    output_path: Path,
    fps: float,
    task_name: str = "",
) -> Path:
    """Read episode frames by seeking into the source video, draw progress overlay, write output.

    Uses cv2.CAP_PROP_POS_MSEC to seek directly into the source video,
    eliminating the need for an intermediate clip file.

    Args:
        video_path: Path to the full source video file.
        from_timestamp: Start timestamp of the episode in seconds.
        to_timestamp: End timestamp of the episode in seconds.
        progress_data: (N, 2) array of (frame_index, progress).
        output_path: Path to write the output MP4.
        fps: Frames per second for the output video.
        task_name: Optional task name to display at the top of the video.

    Returns:
        Path to the written output file (MP4).
    """
    capture = cv2.VideoCapture(str(video_path))
    try:
        capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000)

        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        duration_seconds = to_timestamp - from_timestamp
        num_frames = int(round(duration_seconds * fps))

        logging.info(
            "  Video: %dx%d, %d frames @ %.1f fps (%.2fs)",
            frame_width,
            frame_height,
            num_frames,
            fps,
            duration_seconds,
        )

        pixel_coords = _precompute_pixel_coords(progress_data, num_frames, frame_width, frame_height)
        y_ref = int(frame_height * GRAPH_Y_TOP_FRAC)

        fill_image = _prerender_fill_polygon(pixel_coords, frame_width, frame_height)

        ref_line_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
        cv2.line(
            ref_line_image,
            (0, y_ref),
            (frame_width - 1, y_ref),
            (200, 200, 200, int(255 * REF_ALPHA)),
            1,
            cv2.LINE_AA,
        )

        frame_indices = progress_data[:, 0].astype(int)
        progress_values = progress_data[:, 1].astype(float)

        logging.info("[3/4] Compositing %d frames ...", num_frames)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))

        for frame_idx in range(num_frames):
            ret, frame = capture.read()
            if not ret:
                break

            drawn_count = int(np.searchsorted(frame_indices, frame_idx, side="right"))
            x_current = (
                int(pixel_coords[min(drawn_count, len(pixel_coords)) - 1][0]) + 1 if drawn_count > 0 else 0
            )

            _alpha_composite_region(frame, ref_line_image, frame_width)
            _alpha_composite_region(frame, fill_image, x_current)

            if drawn_count >= 2:
                time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
                line_color = _progress_color(time_position)
                points = pixel_coords[:drawn_count].reshape(-1, 1, 2).astype(np.int32)
                cv2.polylines(
                    frame,
                    [points],
                    isClosed=False,
                    color=(255, 255, 255),
                    thickness=SHADOW_THICKNESS,
                    lineType=cv2.LINE_AA,
                )
                cv2.polylines(
                    frame,
                    [points],
                    isClosed=False,
                    color=line_color,
                    thickness=LINE_THICKNESS,
                    lineType=cv2.LINE_AA,
                )

            if drawn_count > 0:
                score = float(progress_values[min(drawn_count, len(progress_values)) - 1])
                score_text = f"{score:.2f}"
                (text_width, _), _ = cv2.getTextSize(
                    score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2
                )
                score_x = frame_width - text_width - 12
                score_y = frame_height - 12
                time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
                score_color = _progress_color(time_position)
                cv2.putText(
                    frame,
                    score_text,
                    (score_x, score_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    SCORE_FONT_SCALE,
                    (0, 0, 0),
                    4,
                    cv2.LINE_AA,
                )
                cv2.putText(
                    frame,
                    score_text,
                    (score_x, score_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    SCORE_FONT_SCALE,
                    score_color,
                    2,
                    cv2.LINE_AA,
                )

            if task_name:
                (text_width, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
                task_x = max((frame_width - text_width) // 2, 4)
                _draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE)

            writer.write(frame)
            if frame_idx % 100 == 0:
                logging.info("  Frame %d/%d ...", frame_idx, num_frames)

        writer.release()
    finally:
        capture.release()

    logging.info("  MP4 written: %s", output_path)
    return output_path


def convert_mp4_to_gif(mp4_path: Path) -> Path:
    """Convert an MP4 to an optimized GIF using ffmpeg palette generation.

    Args:
        mp4_path: Path to the source MP4 file.

    Returns:
        Path to the generated GIF file.
    """
    capture = cv2.VideoCapture(str(mp4_path))
    frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    capture.release()

    gif_path = mp4_path.with_suffix(".gif")
    palette_path = mp4_path.parent / "_palette.png"

    logging.info("[4/4] Converting to GIF ...")
    result_palette = subprocess.run(  # nosec B607
        [
            "ffmpeg",
            "-y",
            "-i",
            str(mp4_path),
            "-vf",
            f"fps=10,scale={frame_width}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
            "-update",
            "1",
            str(palette_path),
        ],
        capture_output=True,
        text=True,
    )
    if result_palette.returncode != 0:
        logging.warning("palettegen failed:\n%s", result_palette.stderr[-500:])

    result_gif = subprocess.run(  # nosec B607
        [
            "ffmpeg",
            "-y",
            "-i",
            str(mp4_path),
            "-i",
            str(palette_path),
            "-filter_complex",
            f"fps=10,scale={frame_width}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
            str(gif_path),
        ],
        capture_output=True,
        text=True,
    )
    if result_gif.returncode != 0:
        logging.warning("GIF encode failed:\n%s", result_gif.stderr[-500:])

    palette_path.unlink(missing_ok=True)
    logging.info("  GIF written: %s", gif_path)
    return gif_path


def process_dataset(
    repo_id: str,
    episode: int,
    camera_key: str | None,
    output_dir: Path,
    create_gif: bool = False,
) -> Path | None:
    """Full pipeline: download, extract metadata, composite progress, write output.

    Args:
        repo_id: HuggingFace dataset repository ID.
        episode: Episode index.
        camera_key: Camera key to use, or None for auto-selection.
        output_dir: Directory to write output files.
        create_gif: If True, also generate a GIF from the MP4.

    Returns:
        Path to the final output file, or None on failure.
    """
    safe_name = repo_id.replace("/", "_")
    logging.info("Processing: %s | episode %d", repo_id, episode)

    local_path = download_episode_metadata(repo_id, episode)
    logging.info("  Local cache: %s", local_path)

    episode_meta = load_episode_meta(local_path, episode, camera_key)
    logging.info("  Episode meta: %s", episode_meta)

    video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])

    progress_data = load_progress_data(local_path, episode)
    if progress_data is None:
        logging.error("Could not load sarm_progress data. Skipping overlay.")
        return None

    logging.info("  Progress frames: %d", len(progress_data))

    output_path = output_dir / f"{safe_name}_ep{episode}_progress.mp4"
    final_path = composite_progress_video(
        video_path=video_path,
        from_timestamp=episode_meta["from_ts"],
        to_timestamp=episode_meta["to_ts"],
        progress_data=progress_data,
        output_path=output_path,
        fps=episode_meta["fps"],
        task_name=episode_meta.get("task_name", ""),
    )

    if create_gif:
        final_path = convert_mp4_to_gif(final_path)

    logging.info("Done: %s", final_path)
    return final_path


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g. 'lerobot-data-collection/level2_final_quality3').",
    )
    parser.add_argument(
        "--episode",
        type=int,
        required=True,
        help="Episode index to visualize.",
    )
    parser.add_argument(
        "--camera-key",
        type=str,
        default=None,
        help="Camera observation key (e.g. 'observation.images.base'). Auto-selects first camera if omitted.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("progress_videos"),
        help="Directory to write output files (default: ./progress_videos).",
    )
    parser.add_argument(
        "--gif",
        action="store_true",
        help="Also generate a GIF from the MP4 output.",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    result = process_dataset(
        repo_id=args.repo_id,
        episode=args.episode,
        camera_key=args.camera_key,
        output_dir=args.output_dir,
        create_gif=args.gif,
    )

    if result:
        logging.info("Output: %s", result)


if __name__ == "__main__":
    main()
File diff suppressed because it is too large
@@ -0,0 +1,228 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utilities for Human-in-the-Loop data collection scripts."""

import logging
import time
from dataclasses import dataclass, field
from pathlib import Path

from lerobot.processor import (
    IdentityProcessorStep,
    RobotAction,
    RobotObservation,
    RobotProcessorPipeline,
)
from lerobot.processor.converters import (
    observation_to_transition,
    robot_action_observation_to_transition,
    transition_to_observation,
    transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep

logger = logging.getLogger(__name__)


@dataclass
class HILDatasetConfig:
    repo_id: str
    single_task: str
    root: str | Path | None = None
    fps: int = 30
    episode_time_s: float = 120
    num_episodes: int = 50
    video: bool = True
    push_to_hub: bool = True
    private: bool = False
    tags: list[str] | None = None
    num_image_writer_processes: int = 0
    num_image_writer_threads_per_camera: int = 4
    video_encoding_batch_size: int = 1
    vcodec: str = "auto"
    streaming_encoding: bool = True
    encoder_queue_maxsize: int = 30
    encoder_threads: int | None = None
    rename_map: dict[str, str] = field(default_factory=dict)


def teleop_has_motor_control(teleop: Teleoperator) -> bool:
    """Check if teleoperator has motor control capabilities."""
    return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))


def teleop_disable_torque(teleop: Teleoperator) -> None:
    """Disable teleop torque if supported."""
    if hasattr(teleop, "disable_torque"):
        teleop.disable_torque()


def teleop_enable_torque(teleop: Teleoperator) -> None:
    """Enable teleop torque if supported."""
    if hasattr(teleop, "enable_torque"):
        teleop.enable_torque()


def teleop_smooth_move_to(
    teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
    """Smoothly move teleop to target position if motor control is available."""
    if not teleop_has_motor_control(teleop):
        logger.warning("Teleop does not support motor control - cannot mirror robot position")
        return

    teleop_enable_torque(teleop)
    current = teleop.get_action()
    steps = max(int(duration_s * fps), 1)

    for step in range(steps + 1):
        t = step / steps
        # Linear interpolation from the current pose towards the target pose.
        interp = {}
        for k in current:
            if k in target_pos:
                interp[k] = current[k] * (1 - t) + target_pos[k] * t
            else:
                interp[k] = current[k]
        teleop.write_goal_positions(interp)
        time.sleep(1 / fps)


def init_keyboard_listener():
    """Initialize keyboard listener with HIL controls."""
    events = {
        "exit_early": False,
        "rerecord_episode": False,
        "stop_recording": False,
        "policy_paused": False,
        "correction_active": False,
        "resume_policy": False,
        "in_reset": False,
        "start_next_episode": False,
    }

    if is_headless():
        logger.warning("Headless environment - keyboard controls unavailable")
        return None, events

    from pynput import keyboard

    def on_press(key):
        try:
            if events["in_reset"]:
                if key in [keyboard.Key.space, keyboard.Key.right]:
                    logger.info("[HIL] Starting next episode...")
                    events["start_next_episode"] = True
                elif hasattr(key, "char") and key.char == "c":
                    events["start_next_episode"] = True
                elif key == keyboard.Key.esc:
                    logger.info("[HIL] ESC - Stop recording, pushing to hub...")
                    events["stop_recording"] = True
                    events["start_next_episode"] = True
            else:
                if key == keyboard.Key.space:
                    if not events["policy_paused"] and not events["correction_active"]:
                        logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
                        events["policy_paused"] = True
                elif hasattr(key, "char") and key.char == "c":
                    if events["policy_paused"] and not events["correction_active"]:
                        logger.info("[HIL] Taking control...")
                        events["start_next_episode"] = True
                elif hasattr(key, "char") and key.char == "p":
                    if events["policy_paused"] or events["correction_active"]:
                        logger.info("[HIL] Resuming policy...")
                        events["resume_policy"] = True
                elif key == keyboard.Key.right:
                    logger.info("[HIL] End episode")
                    events["exit_early"] = True
                elif key == keyboard.Key.left:
                    logger.info("[HIL] Re-record episode")
                    events["rerecord_episode"] = True
                    events["exit_early"] = True
                elif key == keyboard.Key.esc:
                    logger.info("[HIL] ESC - Stop recording...")
                    events["stop_recording"] = True
                    events["exit_early"] = True
        except Exception as e:
            logger.info(f"Key error: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events


def make_identity_processors() -> tuple[
    RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
    RobotProcessorPipeline[RobotObservation, RobotObservation],
]:
    """Create identity processors for recording."""
    teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[IdentityProcessorStep()],
        to_transition=robot_action_observation_to_transition,
        to_output=transition_to_robot_action,
    )
    obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[IdentityProcessorStep()],
        to_transition=observation_to_transition,
        to_output=transition_to_observation,
    )
    return teleop_proc, obs_proc


def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int) -> None:
    """Reset period where human repositions environment."""
    logger.info("[HIL] RESET")

    events["in_reset"] = True
    events["start_next_episode"] = False

    obs = robot.get_observation()
    robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
    teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)

    logger.info("Press SPACE/RIGHT/'c' to enable teleoperation")
    while not events["start_next_episode"] and not events["stop_recording"]:
        precise_sleep(0.05)

    if events["stop_recording"]:
        return

    events["start_next_episode"] = False
    teleop_disable_torque(teleop)
    logger.info("Teleop enabled - press SPACE/RIGHT/'c' to start the episode")

    while not events["start_next_episode"] and not events["stop_recording"]:
        loop_start = time.perf_counter()
        action = teleop.get_action()
        robot.send_action(action)
        precise_sleep(1 / fps - (time.perf_counter() - loop_start))

    events["in_reset"] = False
    events["start_next_episode"] = False
    events["exit_early"] = False
    events["policy_paused"] = False
    events["correction_active"] = False
    events["resume_policy"] = False


def print_controls(rtc: bool = False) -> None:
    """Print control instructions."""
    mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
    logger.info(
        "%s\n Controls:\n"
        " SPACE - Pause policy\n"
        " c - Take control\n"
        " p - Resume policy after pause/correction\n"
        " → - End episode\n"
        " ← - Re-record episode\n"
        " ESC - Stop and push to hub",
        mode,
    )
@@ -43,13 +43,12 @@ def main():
keyboard.connect()

# Init rerun viewer
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
init_rerun(session_name="lekiwi_teleop")

if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
    raise ValueError("Robot or teleop is not connected!")

print("Starting teleop loop...")
start = time.perf_counter()
while True:
    t0 = time.perf_counter()

@@ -70,7 +69,7 @@ def main():
_ = robot.send_action(action)

# Visualize
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
log_rerun_data(observation=observation, action=action)

precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

@@ -90,13 +90,12 @@ def main():
teleop_device.connect()

# Init rerun viewer
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
init_rerun(session_name="phone_so100_teleop")

if not robot.is_connected or not teleop_device.is_connected:
    raise ValueError("Robot or teleop is not connected!")

print("Starting teleop loop. Move your phone to teleoperate the robot...")
start = time.perf_counter()
while True:
    t0 = time.perf_counter()

@@ -113,7 +112,7 @@ def main():
_ = robot.send_action(joint_action)

# Visualize
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
log_rerun_data(observation=phone_obs, action=joint_action)

precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

@@ -69,15 +69,20 @@ Usage:
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.right_arm_config.port=can0 \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \

@@ -104,9 +109,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
    NormalizerProcessorStep,
    RelativeActionsProcessorStep,

@@ -181,6 +184,7 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0  # Duration to run the demo (seconds)
fps: float = 10.0  # Action execution frequency (Hz)
interpolation_multiplier: int = 1  # Control rate multiplier (1=off, 2=2x, 3=3x)

# Compute device
device: str | None = None  # Device to run on (cuda, cpu, auto)

@@ -461,20 +465,23 @@ def actor_control(
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]

action_count = 0
action_interval = 1.0 / cfg.fps
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)

while not shutdown_event.is_set():
    start_time = time.perf_counter()

    # Try to get an action from the queue with timeout
    action = action_queue.get()
    if interpolator.needs_new_action():
        new_action = action_queue.get()
        if new_action is not None:
            interpolator.add(new_action.cpu())

    action = interpolator.get()
    if action is not None:
        action = action.cpu()
        action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
        action_processed = robot_action_processor((action_dict, None))
        robot.send_action(action_processed)

        action_count += 1

    dt_s = time.perf_counter() - start_time

@@ -95,10 +95,9 @@ def main():
leader.connect()

# Init rerun viewer
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
init_rerun(session_name="so100_so100_EE_teleop")

print("Starting teleop loop...")
start = time.perf_counter()
while True:
    t0 = time.perf_counter()

@@ -118,9 +117,7 @@ def main():
_ = follower.send_action(follower_joints_act)

# Visualize
log_rerun_data(
    observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
)
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)

precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -71,9 +71,9 @@ dependencies = [
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"packaging>=24.2,<26.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0",
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0,<0.26.0",
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
@@ -164,7 +164,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -199,7 +198,6 @@ all = [
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[audio]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
"lerobot[video_benchmark]",
|
||||
|
||||
@@ -29,7 +29,6 @@ Example:
|
||||
print(lerobot.available_policies_per_env)
|
||||
print(lerobot.available_robots)
|
||||
print(lerobot.available_cameras)
|
||||
print(lerobot.available_microphones)
|
||||
print(lerobot.available_motors)
|
||||
```
|
||||
|
||||
@@ -175,12 +174,6 @@ available_cameras = [
|
||||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available microphones from `lerobot/microphones`
|
||||
available_microphones = [
|
||||
"portaudio",
|
||||
"touchlab",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
|
||||
@@ -49,8 +49,6 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
|
||||
@@ -151,12 +151,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||
if not self.input_features:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if not self.output_features:
|
||||
|
||||
@@ -20,7 +20,6 @@ from enum import Enum
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
AUDIO = "AUDIO"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
@@ -35,8 +35,6 @@ from lerobot.datasets.io_utils import (
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -45,7 +43,7 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s


def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -114,7 +112,6 @@ def update_meta_data(
meta_idx,
data_idx,
videos_idx,
audios_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.

@@ -130,7 +127,7 @@ def update_meta_data(
meta_idx: Dictionary containing current metadata chunk and file indices.
data_idx: Dictionary containing current data chunk and file indices.
videos_idx: Dictionary containing current video indices and timestamps.
audios_idx: Dictionary containing current audio indices and timestamps.

Returns:
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
@@ -228,36 +225,6 @@ def update_meta_data(
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])

for key, audio_idx in audios_idx.items():
# Store original audio file indices before updating
orig_chunk_col = f"audio/{key}/chunk_index"
orig_file_col = f"audio/{key}/file_index"
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()

# Update chunk and file indices to point to destination
df[orig_chunk_col] = audio_idx["chunk"]
df[orig_file_col] = audio_idx["file"]

# Apply per-source-file timestamp offsets
src_to_offset = audio_idx.get("src_to_offset", {})
if src_to_offset:
# Apply offset based on original source file
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"audio/{key}/from_timestamp"] += offset
df.at[idx, f"audio/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[f"audio/{key}/from_timestamp"] = (
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
)
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]

# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])

df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
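
The per-source offset bookkeeping above is the subtle part of this hunk: each row must be shifted by the duration already present in the *destination* file that its source file was appended to. A minimal sketch of the remapping, using plain dicts in place of the real episodes DataFrame (all values hypothetical):

```python
# Toy illustration of the src_to_offset remapping above.
rows = [
    {"chunk": 0, "file": 0, "from_ts": 0.0, "to_ts": 4.0},
    {"chunk": 0, "file": 1, "from_ts": 0.0, "to_ts": 6.0},
]
# Offsets recorded while copying/concatenating source files into the destination:
# the second source file landed after 4 s of already-present audio.
src_to_offset = {(0, 0): 0.0, (0, 1): 4.0}

for row in rows:
    offset = src_to_offset.get((row["chunk"], row["file"]), 0)
    row["from_ts"] += offset
    row["to_ts"] += offset

print(rows[1])  # {'chunk': 0, 'file': 1, 'from_ts': 4.0, 'to_ts': 10.0}
```
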
@@ -272,7 +239,6 @@ def aggregate_datasets(
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
audio_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -290,7 +256,6 @@ def aggregate_datasets(
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
logging.info("Start aggregate_datasets")
@@ -299,8 +264,6 @@ def aggregate_datasets(
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if audio_files_size_in_mb is None:
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE

@@ -313,7 +276,6 @@ def aggregate_datasets(
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]

dst_meta = LeRobotDatasetMetadata.create(
repo_id=aggr_repo_id,
@@ -325,7 +287,6 @@
chunks_size=chunk_size,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
audio_files_size_in_mb=audio_files_size_in_mb,
)

logging.info("Find all tasks")
@@ -339,18 +300,14 @@
videos_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
}
audios_idx = {
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
}

dst_meta.episodes = {}

for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)

meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)

# Clear the src_to_dst mapping after processing each source dataset
# to avoid interference between different source datasets
@@ -418,7 +375,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
file_index=file_idx,
)

src_duration = get_media_duration_in_s(src_path, media_type="video")
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)

if not dst_path.exists():
@@ -457,7 +414,7 @@
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
concatenate_media_files(
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
@@ -472,101 +429,6 @@
return videos_idx


def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
"""Aggregates audio files from a source dataset into the destination dataset.

Handles audio file concatenation and rotation based on file size limits.
Creates new audio files when size limits are exceeded.

Args:
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
audios_idx: Dictionary tracking audio chunk and file indices.
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)

Returns:
dict: Updated audios_idx with current chunk and file indices.
"""
for key in audios_idx:
audios_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
audios_idx[key]["src_to_offset"] = {}

for key, audio_idx in audios_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
for chunk, file in zip(
src_meta.episodes[f"audio/{key}/chunk_index"],
src_meta.episodes[f"audio/{key}/file_index"],
strict=False,
)
}
unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)

chunk_idx = audio_idx["chunk"]
file_idx = audio_idx["file"]
current_offset = audio_idx["latest_duration"]

for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=src_chunk_idx,
file_index=src_file_idx,
)

dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)

src_duration = get_media_duration_in_s(src_path, media_type="audio")

if not dst_path.exists():
# Store offset before incrementing
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
audios_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue

# Check file sizes before appending
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)

if dst_size + src_size >= audio_files_size_in_mb:
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
audio_key=key,
chunk_index=chunk_idx,
file_index=file_idx,
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Reset offset for next file
current_offset = src_duration
else:
# Append to existing audio file - use current accumulated offset
audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_media_files(
[dst_path, src_path],
dst_path,
)
current_offset += src_duration

audios_idx[key]["episode_duration"] += src_duration

audios_idx[key]["chunk"] = chunk_idx
audios_idx[key]["file"] = file_idx

return audios_idx


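The copy/rotate/append decision in `aggregate_audio` is easiest to see in isolation. A minimal sketch of the three branches, assuming a hypothetical 100 MB size budget and made-up file sizes:

```python
# Hedged sketch of the three branches in aggregate_audio above.
audio_files_size_in_mb = 100

def plan_append(dst_exists: bool, dst_size_mb: float, src_size_mb: float) -> str:
    if not dst_exists:
        return "copy"    # first source file: copy it as the new destination
    if dst_size_mb + src_size_mb >= audio_files_size_in_mb:
        return "rotate"  # would exceed the budget: start a fresh destination file
    return "append"      # concatenate the source onto the current destination

print(plan_append(False, 0, 30))   # copy
print(plan_append(True, 80, 30))   # rotate
print(plan_append(True, 40, 30))   # append
```
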
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.

@@ -639,7 +501,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
return data_idx


def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.

Reads source metadata files, updates all indices and timestamps,
@@ -651,7 +513,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx: Dictionary tracking metadata chunk and file indices.
data_idx: Dictionary tracking data chunk and file indices.
videos_idx: Dictionary tracking video indices and timestamps.
audios_idx: Dictionary tracking audio indices and timestamps.

Returns:
dict: Updated meta_idx with current chunk and file indices.
@@ -675,7 +536,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
meta_idx,
data_idx,
videos_idx,
audios_idx,
)

meta_idx, _ = append_or_create_parquet_file(
@@ -692,8 +552,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audio
# Increment latest_duration by the total duration added from this source dataset
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
for k in audios_idx:
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]

return meta_idx



@@ -1,275 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

import av
import torch
import torchaudio
import torchcodec
from numpy import ceil

CHANNELS_LAYOUTS_MAPPING = {
1: "mono",
2: "stereo",
3: "2.1",
4: "3.1",
5: "4.1",
6: "5.1",
7: "6.1",
8: "7.1",
16: "hexadecagonal",
24: "22.2",
}


def decode_audio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
backend: str | None = "torchcodec",
) -> torch.Tensor:
"""
Decodes audio using the specified backend.
Args:
audio_path (Path): Path to the audio file.
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
duration (float): Duration of the audio chunks in seconds.
start_time_s (float, optional): Offset added to each query timestamp. Defaults to 0.0.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".

Returns:
torch.Tensor: Decoded audio chunks.

Currently supports torchcodec and torchaudio.
"""
if backend == "torchcodec":
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
elif backend == "torchaudio":
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
else:
raise ValueError(f"Unsupported audio backend: {backend}")


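For context, a hedged usage sketch of this dispatcher (the path and timestamps are hypothetical; each returned chunk covers the window `[ts - duration, ts]` for one query timestamp):

```python
# Hypothetical usage of the decode_audio dispatcher defined above.
chunks = decode_audio(
    "audio/chunk-000/file-000.m4a",  # hypothetical dataset-relative path
    timestamps=[0.5, 1.0, 1.5],
    duration=0.5,
    backend="torchcodec",
)
print(chunks.shape)  # (num_timestamps, channels, samples)
```
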
def decode_audio_torchcodec(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
audio_sample_rate = audio_decoder.metadata.sample_rate
audio_channels = audio_decoder.metadata.num_channels
# TODO(CarolinePascal) : assert ts < total record duration

audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
current_audio_chunk = audio_decoder.get_samples_played_in_range(
start_seconds=max(0.0, ts - duration), stop_seconds=ts
)

current_audio_chunk_data = current_audio_chunk.data

# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)

if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)

audio_chunks.append(current_audio_chunk_data)

audio_chunks = torch.stack(audio_chunks)

assert len(timestamps) == len(audio_chunks)
return audio_chunks


def decode_audio_torchaudio(
audio_path: Path | str,
timestamps: list[float],
duration: float,
start_time_s: float | None = 0.0,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
# TODO(CarolinePascal) : add channels selection
audio_path = str(audio_path)

reader = torchaudio.io.StreamReader(src=audio_path)
audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
# TODO(CarolinePascal) : assert ts < total record duration

# TODO(CarolinePascal) : sort timestamps ?

reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sample_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)

audio_chunks = []
timestamps = [
timestamp + start_time_s for timestamp in timestamps
] # Add an offset of start_time_s to each timestamp
for ts in timestamps:
reader.seek(max(0.0, ts - duration)) # Default to closest audio sample. Needs to be non-negative !
status = reader.fill_buffer()
if status != 0:
# Should not happen, but just in case
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")

current_audio_chunk = reader.pop_chunks()[0]
current_audio_chunk_data = current_audio_chunk.t() # Channel first format

# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Remove the superfluous last samples of the audio chunk
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)

if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)

audio_chunks.append(current_audio_chunk_data)

audio_chunks = torch.stack(audio_chunks)

assert len(timestamps) == len(audio_chunks)
return audio_chunks


def encode_audio(
input_path: Path | str,
output_path: Path | str,
codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and constant (file size control) / variable (quality control) bitrate options
bit_rate: int | None = None,
sample_rate: int | None = None,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
) -> None:
"""Encodes an audio file using ffmpeg."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=overwrite)

# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
logging.getLogger("libav").setLevel(log_level)

# Open input file
with av.open(str(input_path), "r") as input:
input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded

# Define sub-sampling options
if sample_rate is None:
sample_rate = input_stream.rate

# Create and open output file (overwrite by default)
with av.open(str(output_path), "w") as output:
output_stream = output.add_stream(
codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
)

if bit_rate is not None:
output_stream.bit_rate = bit_rate

# Loop through input WAV packets and encode them
for input_frame in input.decode(
input_stream
): # This step handles both demuxing and decoding under the hood
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)

# Flush the encoder
packet = output_stream.encode()
if packet:
output.mux(packet)

# Reset logging level
if log_level is not None:
av.logging.restore_default_callback()

if not output_path.exists():
raise OSError(f"Audio encoding did not work. File not found: {output_path}.")


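A hedged usage sketch of this encoder (file names are hypothetical; this is how the writer below converts raw WAV recordings to AAC/m4a):

```python
# Hypothetical usage of encode_audio: compress a raw WAV recording to AAC/m4a.
encode_audio(
    "episode_000_mic.wav",  # hypothetical raw recording path
    "episode_000_mic.m4a",
    codec="aac",
    sample_rate=16000,
    overwrite=True,
)
```
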
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)

# Getting audio stream information
audio_info = {}
with av.open(str(video_path), "r") as audio_file:
try:
audio_stream = audio_file.streams.audio[0]
except IndexError:
# Reset logging level
av.logging.restore_default_callback()
return {"has_audio": False}

audio_info["audio.channels"] = audio_stream.channels
audio_info["audio.codec"] = audio_stream.codec.canonical_name
# In an ideal lossless case : bit depth x sample rate x channels = bit rate.
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
audio_info["audio.bit_rate"] = audio_stream.bit_rate
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
# In an ideal lossless case : fixed number of bits per sample.
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
audio_info["audio.bit_depth"] = audio_stream.format.bits
audio_info["audio.channel_layout"] = audio_stream.layout.name
audio_info["has_audio"] = True

# Reset logging level
av.logging.restore_default_callback()

return audio_info
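
For reference, `get_audio_info` returns a flat dictionary. A plausible result for a mono AAC stream might look like the sketch below (values are illustrative, not read from a real file):

```python
# Illustrative output of get_audio_info for a mono AAC stream.
audio_info = {
    "has_audio": True,
    "audio.channels": 1,
    "audio.codec": "aac",
    "audio.bit_rate": 69000,
    "audio.sample_rate": 16000,
    "audio.bit_depth": 32,  # depends on the sample format PyAV reports (e.g. fltp)
    "audio.channel_layout": "mono",
}
```
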
@@ -19,7 +19,8 @@ import logging

import numpy as np

from lerobot.datasets.io_utils import load_audio_from_path, load_image_as_numpy
from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.utils.constants import ACTION, OBS_STATE

DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]

@@ -249,20 +250,6 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images


def sample_audio_from_path(audio_path: str) -> np.ndarray:
"""Samples audio data from an audio recording stored in a WAV file."""
data = load_audio_from_path(audio_path)
sampled_indices = sample_indices(len(data))

return data[sampled_indices]


def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
"""Samples audio data from an audio recording stored in a numpy array."""
sampled_indices = sample_indices(len(data))
return data[sampled_indices]


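A hedged sketch of what these samplers do to a `(num_frames, channels)` array when computing statistics; `sample_indices` is assumed to pick a bounded number of evenly spaced rows, so a stand-in is used here:

```python
# Hedged sketch of the subsampling done by sample_audio_from_data above.
import numpy as np

data = np.random.randn(48000, 1).astype(np.float32)  # 1 s of mono audio at 48 kHz
# Stand-in for sample_indices: ~100 evenly spaced row indices.
sampled_indices = np.linspace(0, len(data) - 1, num=100, dtype=int)
print(data[sampled_indices].shape)  # (100, 1)
```
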
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
@@ -530,13 +517,6 @@ def compute_episode_stats(
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
elif features[key]["dtype"] == "audio":
try:
ep_ft_array = sample_audio_from_path(data[0])
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
ep_ft_array = sample_audio_from_data(data)
axes_to_reduce = 0
keepdims = True
else:
ep_ft_array = data
axes_to_reduce = 0

@@ -23,7 +23,6 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download

from lerobot.datasets.audio_utils import get_audio_info
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
from lerobot.datasets.io_utils import (
@@ -41,7 +40,6 @@ from lerobot.datasets.io_utils import (
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
INFO_PATH,
check_version_compatibility,
flatten_dict,
@@ -271,32 +269,6 @@ class LeRobotDatasetMetadata:
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)

def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
"""Return the relative audio file path for the given episode and audio key.

Args:
ep_index: Zero-based episode index.
audio_key: Feature key identifying the audio stream
(e.g. ``'observation.audio.microphone'``).

Returns:
Path to the audio file containing this episode's audio.

Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"audio/{audio_key}/chunk_index"]
file_idx = ep[f"audio/{audio_key}/file_index"]
fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)

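For intuition, a hedged sketch of the path formatting done here; the template below is hypothetical (the real `DEFAULT_AUDIO_PATH` is defined in `lerobot.datasets.utils` and may differ):

```python
# Hypothetical template and result; the real DEFAULT_AUDIO_PATH may differ.
DEFAULT_AUDIO_PATH = "audio/{audio_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.m4a"
fpath = DEFAULT_AUDIO_PATH.format(
    audio_key="observation.audio.microphone", chunk_index=0, file_index=2
)
print(fpath)  # audio/observation.audio.microphone/chunk-000/file-002.m4a
```
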
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
@@ -307,11 +279,6 @@ class LeRobotDatasetMetadata:
"""Formattable string for the video files."""
return self.info["video_path"]

@property
def audio_path(self) -> str | None:
"""Formattable string for the audio files."""
return self.info["audio_path"]

@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
@@ -342,11 +309,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

@property
def audio_keys(self) -> list[str]:
"""Keys to access audio modalities."""
return [key for key, ft in self.features.items() if ft["dtype"] == "audio"]

@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -387,11 +349,6 @@ class LeRobotDatasetMetadata:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]

@property
def audio_files_size_in_mb(self) -> int:
"""Max size of audio file in mega bytes."""
return self.info["audio_files_size_in_mb"]

def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
@@ -558,27 +515,11 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)

def update_audio_info(self, audio_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode audio, implicitly assuming that all audio files have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if audio_key is not None and audio_key not in self.audio_keys:
raise ValueError(f"Audio key {audio_key} not found in dataset")

audio_keys = [audio_key] if audio_key is not None else self.audio_keys
for key in audio_keys:
if not self.features[key].get("info", None):
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_audio_info(audio_path)
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION

def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.

@@ -590,7 +531,6 @@ class LeRobotDatasetMetadata:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
@@ -607,11 +547,6 @@ class LeRobotDatasetMetadata:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb

if audio_files_size_in_mb is not None:
if audio_files_size_in_mb <= 0:
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb

# Update the info file on disk
write_info(self.info, self.root)

@@ -619,13 +554,12 @@ class LeRobotDatasetMetadata:
"""Get current chunk and file size settings.

Returns:
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
"audio_files_size_in_mb": self.audio_files_size_in_mb,
}

def __repr__(self):
@@ -652,7 +586,6 @@ class LeRobotDatasetMetadata:
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Create metadata for a new LeRobot dataset from scratch.

@@ -703,7 +636,6 @@ class LeRobotDatasetMetadata:
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
audio_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError(

@@ -21,14 +21,12 @@ from pathlib import Path
import datasets
import torch

from lerobot.datasets.audio_utils import decode_audio
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
)
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
load_nested_dataset,
@@ -132,7 +130,7 @@ class DatasetReader:
return hf_dataset

def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False

@@ -156,13 +154,6 @@ class DatasetReader:
if not video_path.exists():
return False

if len(self._meta.audio_keys) > 0:
for ep_idx in requested_episodes:
for audio_key in self._meta.audio_keys:
audio_path = self.root / self._meta.get_compressed_audio_file_path(ep_idx, audio_key)
if not audio_path.exists():
return False

return True

def get_episodes_file_paths(self) -> list[Path]:
@@ -179,15 +170,6 @@ class DatasetReader:
for ep_idx in episodes
]
fpaths += video_files

if len(self._meta.audio_keys) > 0:
audio_files = [
str(self._meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self._meta.audio_keys
for ep_idx in episodes
]
fpaths += audio_files

# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths
@@ -217,7 +199,7 @@ class DatasetReader:
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self._meta.video_keys + self._meta.audio_keys:
for key in self._meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
@@ -231,10 +213,10 @@ class DatasetReader:
return query_timestamps

def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""Query dataset for indices across keys, skipping video and audio keys."""
"""Query dataset for indices across keys, skipping video keys."""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self._meta.video_keys or key in self._meta.audio_keys:
if key in self._meta.video_keys:
continue
relative_indices = (
q_idx
@@ -264,28 +246,6 @@ class DatasetReader:

return item

# TODO(CarolinePascal): add variable query durations
def _query_audio(
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
) -> dict[str, torch.Tensor]:
ep = self.meta.episodes[ep_idx]
item = {}
for audio_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially in a single audio file to reduce the number of files.
# Thus we load the start timestamp of the episode in this file and
# shift the query timestamps accordingly.
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]

audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
audio_chunk = decode_audio(
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
)
item[audio_key] = audio_chunk.squeeze(0)

return item

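The timestamp shift in `_query_audio` is what makes episode-relative queries work against concatenated files. A minimal sketch with made-up numbers:

```python
# Toy illustration of the shift applied in _query_audio above: queries are
# episode-relative, while audio files store episodes back to back.
from_timestamp = 12.5        # where this episode starts inside the audio file
query_ts = [0.0, 0.5, 1.0]   # episode-relative query timestamps
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
print(shifted_query_ts)  # [12.5, 13.0, 13.5]
```
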
def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded.

@@ -305,12 +265,11 @@ class DatasetReader:
for key, val in query_result.items():
item[key] = val

if len(self._meta.video_keys) > 0 or len(self._meta.audio_keys) > 0:
if len(self._meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
item = {**video_frames, **audio_chunks, **item}
item = {**video_frames, **item}

if self._image_transforms is not None:
image_keys = self._meta.camera_keys

@@ -31,7 +31,6 @@ import PIL.Image
import pyarrow.parquet as pq
import torch

from lerobot.datasets.audio_utils import encode_audio
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
@@ -49,17 +48,14 @@ from lerobot.datasets.io_utils import (
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
DEFAULT_RAW_AUDIO_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
concatenate_media_files,
concatenate_video_files,
encode_video_frames,
get_media_duration_in_s,
get_video_duration_in_s,
)
from lerobot.microphones.microphone import Microphone
from lerobot.microphones.utils import async_microphones_start_recording

logger = logging.getLogger(__name__)

@@ -148,10 +144,6 @@ class DatasetWriter:
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent

def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index)
return self._root / fpath

def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
@@ -216,43 +208,11 @@ class DatasetWriter:
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
elif self._meta.features[key]["dtype"] == "audio":
if (
self._meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
self.episode_buffer[key].append(frame[key])
else: # Otherwise, only the audio file path is stored in the episode buffer
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])

self.episode_buffer["size"] += 1

def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
"""
Starts recording audio data provided by the microphone and directly writes it to a .wav file.
"""

audio_file = self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
microphone.start_recording(output_file=audio_file)

def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
"""
Starts recording audio data provided by multiple microphones and directly writes it to the appropriate .wav files.
"""

output_files = []
for microphone_key in microphones:
output_files.append(
self._get_raw_audio_file_path(self._meta.total_episodes, "observation.audio." + microphone_key)
)

async_microphones_start_recording(microphones, output_files)

def save_episode(
self,
episode_data: dict | None = None,
@@ -281,19 +241,12 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif ft["dtype"] == "audio":
if (
self._meta.robot_type == "lekiwi"
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
continue
episode_buffer[key] = np.stack(episode_buffer[key])

# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()

has_video_keys = len(self._meta.video_keys) > 0
has_audio_keys = len(self._meta.audio_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self._batch_encoding_size > 1

@@ -320,7 +273,7 @@ class DatasetWriter:
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif (has_video_keys or has_audio_keys) and not use_batched_encoding:
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self._meta.video_keys)
if parallel_encoding and num_cameras > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
@@ -356,28 +309,19 @@ class DatasetWriter:
for video_key in self._meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))

# TODO(Caroline): add parallel encoding for audio as well
for audio_key in self._meta.audio_keys:
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))

# `meta.save_episode` needs to be executed after encoding the videos
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)

if (has_video_keys or has_audio_keys) and use_batched_encoding:
if has_video_keys and use_batched_encoding:
self._episodes_since_last_encoding += 1
if self._episodes_since_last_encoding == self._batch_encoding_size:
start_ep = self._meta.total_episodes - self._batch_encoding_size
end_ep = self._meta.total_episodes
if has_video_keys:
self._batch_save_episode_video(start_ep, end_ep)
if has_audio_keys:
self._batch_save_episode_audio(start_ep, end_ep)
self._batch_save_episode_video(start_ep, end_ep)
self._episodes_since_last_encoding = 0

if episode_data is None:
self.clear_episode_buffer(
delete_images=len(self._meta.image_keys) > 0, delete_audio=len(self._meta.audio_keys) > 0
)
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)

def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""Batch save videos for multiple episodes."""
@@ -424,59 +368,6 @@ class DatasetWriter:
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)

def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
"""
Batch save audio for multiple episodes.

Args:
start_episode: Starting episode index (inclusive)
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
"""
if end_episode is None:
end_episode = self._meta.total_episodes

logging.info(
f"Batch encoding {self._batch_encoding_size} audio files for episodes {start_episode} to {end_episode - 1}"
)

chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
file_idx = self._meta.episodes[start_episode]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
episode_df = pd.read_parquet(episode_df_path)

for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding audio for episode {ep_idx}")

if (
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
):
# The current episode is in a new chunk or file.
# Save previous episode dataframe and update the Hugging Face dataset by reloading it.
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)

# Load new episode dataframe
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)

# Save the current episode's audio metadata to the dataframe
audio_ep_metadata = {}
for audio_key in self._meta.audio_keys:
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
audio_ep_metadata.pop("episode_index")
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
) # allows NaN values along with integers

episode_df = episode_df.combine_first(audio_ep_df)
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)

def _save_episode_data(self, episode_buffer: dict) -> dict:
"""Save episode data to a parquet file."""
# Use metadata features as the authoritative schema
@@ -554,7 +445,7 @@ class DatasetWriter:
ep_path = temp_path

ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
ep_duration_in_s = get_video_duration_in_s(ep_path)

if (
episode_index == 0
@@ -594,7 +485,7 @@ class DatasetWriter:
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
concatenate_media_files(
concatenate_video_files(
[latest_path, ep_path],
latest_path,
)
@@ -616,91 +507,7 @@ class DatasetWriter:
}
return metadata

def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert raw audio files into m4a audio files.
Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speed up encoding,
since audio encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{audio_key}_{episode_index:03d}.m4a"
raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
encode_audio(raw_audio_file, temp_path, overwrite=True)
raw_audio_file.unlink()
return temp_path

def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
# Encode episode audio into a temporary audio file
ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")

if (
episode_index == 0
or self._meta.latest_episode is None
or f"audio/{audio_key}/chunk_index" not in self._meta.latest_episode
):
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
old_chunk_idx = self._meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
old_file_idx = self._meta.episodes[-1][f"audio/{audio_key}/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(
old_chunk_idx, old_file_idx, self._meta.chunks_size
)
latest_duration_in_s = 0.0
new_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
# Retrieve information from the latest updated audio file using latest_episode
latest_ep = self._meta.latest_episode
chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]

latest_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_file_size_in_mb(latest_path)
latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]

if latest_size_in_mb + ep_size_in_mb >= self._meta.audio_files_size_in_mb:
# Move temporary episode audio to a new audio file in the dataset
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
new_path = self._root / self._meta.audio_path.format(
audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
# Update latest audio file
concatenate_media_files(
[latest_path, ep_path],
latest_path,
)

# Remove temporary directory
shutil.rmtree(str(ep_path.parent))

# Update audio info (only needed when first episode is encoded since it reads from episode 0)
if episode_index == 0:
self._meta.update_audio_info(audio_key)
write_info(self._meta.info, self._meta.root) # ensure audio info always written properly

metadata = {
"episode_index": episode_index,
f"audio/{audio_key}/chunk_index": chunk_idx,
f"audio/{audio_key}/file_index": file_idx,
f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata

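The metadata returned by `_save_episode_audio` locates one episode inside a concatenated file. A hypothetical value for a 12 s episode appended to a file that already held 30 s of audio:

```python
# Hypothetical return value of _save_episode_audio (all numbers illustrative).
metadata = {
    "episode_index": 1,
    "audio/observation.audio.microphone/chunk_index": 0,
    "audio/observation.audio.microphone/file_index": 0,
    "audio/observation.audio.microphone/from_timestamp": 30.0,
    "audio/observation.audio.microphone/to_timestamp": 42.0,
}
```
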
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
def clear_episode_buffer(self, delete_images: bool = True) -> None:
"""Discard the current episode buffer and optionally delete temp images.

Args:
@@ -724,15 +531,6 @@ class DatasetWriter:
if img_dir.is_dir():
shutil.rmtree(img_dir)

if delete_audio:
episode_index = self.episode_buffer["episode_index"]
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for audio_key in self._meta.audio_keys:
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
if audio_file.is_file():
audio_file.unlink()

self.episode_buffer = self._create_episode_buffer()

def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
@@ -798,7 +596,7 @@ class DatasetWriter:
self._streaming_encoder.cancel_episode()

def cleanup_interrupted_episode(self, episode_index: int) -> None:
"""Remove temporary image and audio directories for an interrupted episode."""
"""Remove temporary image directories for an interrupted episode."""
for key in self._meta.video_keys:
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
@@ -809,14 +607,6 @@ class DatasetWriter:
)
shutil.rmtree(img_dir)

for key in self._meta.audio_keys:
audio_file = self._get_raw_audio_file_path(episode_index=episode_index, audio_key=key)
if audio_file.exists():
logger.debug(
f"Cleaning up interrupted episode audio for episode {episode_index}, microphone {key}"
)
audio_file.unlink()

def finalize(self) -> None:
"""Flush all pending work and release all resources.


@@ -22,8 +22,6 @@ from PIL import Image as PILImage

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import (
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
DEFAULT_AUDIO_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -49,7 +47,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video" or ft["dtype"] == "audio":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -112,12 +110,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
cam_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
}
mic_fts = {
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}

if joint_fts and prefix == ACTION:
features[prefix] = {
@@ -140,14 +133,6 @@ def hw_to_dataset_features(
"names": ["height", "width", "channels"],
}

for key, parameters in mic_fts.items():
features[f"{prefix}.audio.{key}"] = {
"dtype": "audio",
"shape": (len(parameters[1]),),
"names": ["channels"],
"info": {"sample_rate": parameters[0]},
}

_validate_feature_names(features)
return features

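For reference, a hedged example of the audio feature entry this loop produced, assuming a microphone declared as a `(sample_rate, channel_names)` tuple such as `(16000, ["left"])`:

```python
# Hedged example of the feature dict built by hw_to_dataset_features above.
features = {
    "observation.audio.microphone": {
        "dtype": "audio",
        "shape": (1,),  # one channel
        "names": ["channels"],
        "info": {"sample_rate": 16000},
    }
}
```
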
@@ -177,8 +162,6 @@ def build_dataset_frame(
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
elif ft["dtype"] == "audio":
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]

return frame

@@ -212,10 +195,6 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif ft["dtype"] == "audio":
type = FeatureType.AUDIO
if len(shape) != 2:
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
@@ -294,7 +273,6 @@ def create_empty_dataset_info(
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
audio_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.

@@ -304,10 +282,7 @@ def create_empty_dataset_info(
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
chunks_size (int | None): The number of files per chunk.
data_files_size_in_mb (int | None): The maximum size per data file in MB.
video_files_size_in_mb (int | None): The maximum size per video file in MB.
audio_files_size_in_mb (int | None): The maximum size per audio file in MB.

Returns:
dict: A dictionary with the initial dataset metadata.
"""
@@ -320,12 +295,10 @@ def create_empty_dataset_info(
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"audio_path": DEFAULT_AUDIO_PATH,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
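Reviewer note: the `x or DEFAULT` idiom above falls back to the default for both `None` and `0`. A minimal sketch of the behavior, constants assumed from `lerobot.datasets.utils`:

```python
DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # mirrors lerobot.datasets.utils

def pick_size(data_files_size_in_mb: int | None) -> int:
    # `or` treats 0 as falsy, so an explicit 0 also falls back to the default;
    # use `x if x is not None else DEFAULT` if a literal 0 should be respected.
    return data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB

assert pick_size(None) == 100
assert pick_size(0) == 100   # note: 0 is silently replaced
assert pick_size(250) == 250
```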
@@ -462,8 +435,6 @@ def validate_feature_dtype_and_shape(
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "audio":
        return validate_feature_audio(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:
@@ -530,33 +501,6 @@ def validate_feature_image_or_video(
    return error_message


def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
    """Validate a feature that is expected to be an audio frame.

    Args:
        name (str): The name of the feature.
        expected_shape (list[str]): The expected shape (C,).
        value: The audio data to validate.

    Returns:
        str: An error message if validation fails, otherwise an empty string.
    """
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c = expected_shape
        if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[-1]:
            # The number of frames might be different
            error_message += (
                f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
            )
    else:
        error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_string(name: str, value: str) -> str:
    """Validate a feature that is expected to be a string.

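Reviewer note: a quick sketch of the shape contract the removed audio validator enforced (illustrative reimplementation, not repo code):

```python
import numpy as np

def check(actual_shape: tuple[int, ...], expected_channels: int) -> bool:
    # Mirrors the validator: 1-D or 2-D arrays are accepted as long as the
    # last dimension (channels) matches; the frame count is unconstrained.
    return len(actual_shape) in (1, 2) and actual_shape[-1] == expected_channels

assert check(np.zeros((480, 2)).shape, 2)      # 480 frames, stereo -> ok
assert check(np.zeros((2,)).shape, 2)          # single frame, stereo -> ok
assert not check(np.zeros((480, 1)).shape, 2)  # channel mismatch -> error
```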
@@ -23,7 +23,6 @@ import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import soundfile as sf
import torch
from datasets import Dataset
from datasets.table import embed_table_storage
@@ -281,24 +280,6 @@ def load_image_as_numpy(
    return img_array


def load_audio_from_path(fpath: str | Path) -> np.ndarray:
    """Load an audio file from a path into a numpy array.

    Args:
        fpath (str | Path): Path to the audio file.

    Returns:
        np.ndarray: The audio as a numpy array.
    """
    audio_data, _ = sf.read(fpath, dtype="float32")

    # Fill missing channel dimension when loading mono audio data
    if audio_data.ndim == 1:
        audio_data = np.expand_dims(audio_data, axis=1)

    return audio_data


def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
    """Convert a batch from a Hugging Face dataset to torch tensors.

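Reviewer note: the mono expansion in the removed helper matters because soundfile's output rank depends on the file. A minimal sketch (file path hypothetical):

```python
import numpy as np
import soundfile as sf

# Hypothetical path for illustration only.
data, sample_rate = sf.read("episode_000000.wav", dtype="float32")

# soundfile returns (n_frames,) for mono and (n_frames, n_channels) otherwise;
# normalizing to 2-D keeps downstream consumers shape-agnostic.
if data.ndim == 1:
    data = data[:, np.newaxis]  # (n_frames,) -> (n_frames, 1)

assert data.ndim == 2
```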
@@ -54,9 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        revision: str | None = None,
        force_cache_sync: bool = False,
        download_videos: bool = True,
        download_audio: bool = True,
        video_backend: str | None = None,
        audio_backend: str | None = None,
        batch_encoding_size: int = 1,
        vcodec: str = "libsvtav1",
        streaming_encoding: bool = False,
@@ -93,7 +91,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        task-conditioned training.
        - data (backed by datasets.Dataset), which reads values from parquet files.
        - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
        - audio (optional) from which audio is loaded to be synchronous with data from parquet files.

    A typical LeRobotDataset looks like this from its root path:
    .
@@ -119,37 +116,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
    │   ├── info.json
    │   ├── stats.json
    │   └── tasks.parquet
    ├── videos
    │   ├── observation.images.laptop
    │   │   ├── chunk-000
    │   │   │   ├── file-000.mp4
    │   │   │   ├── file-001.mp4
    │   │   │   └── ...
    │   │   ├── chunk-001
    │   │   │   └── ...
    │   │   └── ...
    │   ├── observation.images.phone
    │   │   ├── chunk-000
    │   │   │   ├── file-000.mp4
    │   │   │   ├── file-001.mp4
    │   │   │   └── ...
    │   │   ├── chunk-001
    │   │   │   └── ...
    │   │   └── ...
    │   └── ...
    └── audio
        ├── observation.audio.laptop
    └── videos
        ├── observation.images.laptop
        │   ├── chunk-000
        │   │   ├── file-000.m4a
        │   │   ├── file-001.m4a
        │   │   ├── file-000.mp4
        │   │   ├── file-001.mp4
        │   │   └── ...
        │   ├── chunk-001
        │   │   └── ...
        │   └── ...
        ├── observation.audio.phone
        ├── observation.images.phone
        │   ├── chunk-000
        │   │   ├── file-000.m4a
        │   │   ├── file-001.m4a
        │   │   ├── file-000.mp4
        │   │   ├── file-001.mp4
        │   │   └── ...
        │   ├── chunk-001
        │   │   └── ...
@@ -172,9 +151,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
                ``$HF_LEROBOT_HOME/hub``.
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                their episode_index in this list. Defaults to None.
            image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
                torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
                from videos or images). Defaults to None.
            image_transforms (Callable | None, optional):
                Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
                conversion. This works for both image-backed and video-backed observations and can later be
                updated with `set_image_transforms()` or cleared with `clear_image_transforms()`.
                Defaults to None.
            delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
            tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
                sync with the fps value. It is used at the init of the dataset to make sure that each
@@ -190,10 +171,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                video files are already present on local disk, they won't be downloaded again. Defaults to
                True.
            download_audio (bool, optional): Flag to download the audio. Defaults to True.
            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to
                torchcodec when available on the platform; otherwise, defaults to 'pyav'. You can also use the
                'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader',
                which is another decoder of Torchvision.
            audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
            batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
                Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
            vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
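Reviewer note: the `delta_timestamps` docstring still reads `_description_`. For context while reviewing, a minimal sketch of how this parameter is conventionally used in LeRobot (repo_id and keys illustrative, not from this PR):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

fps = 30
# For each sampled frame, also fetch observations 1 and 2 frames in the past
# and the actions for the current plus next 2 frames (values are in seconds).
delta_timestamps = {
    "observation.images.laptop": [-2 / fps, -1 / fps, 0.0],
    "action": [0.0, 1 / fps, 2 / fps],
}

dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
item = dataset[0]
# item["action"] gains a leading time dimension of length 3.
```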
@@ -215,13 +194,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
        super().__init__()
        self.repo_id = repo_id
        self._requested_root = Path(root) if root else None
        self.image_transforms = image_transforms
        self.reader = None
        self.set_image_transforms(image_transforms)
        self.delta_timestamps = delta_timestamps
        self.episodes = episodes
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self._video_backend = video_backend if video_backend else get_safe_default_codec()
        self._audio_backend = audio_backend if audio_backend else "torchcodec"
        self._batch_encoding_size = batch_encoding_size
        self._vcodec = resolve_vcodec(vcodec)
        self._encoder_threads = encoder_threads
@@ -243,7 +222,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
            episodes=episodes,
            tolerance_s=tolerance_s,
            video_backend=self._video_backend,
            audio_backend=self._audio_backend,
            delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
        )
@@ -252,7 +230,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        if force_cache_sync or not self.reader.try_load():
            if is_valid_version(self.revision):
                self.revision = get_safe_version(self.repo_id, self.revision)
            self._download(download_videos, download_audio)
            self._download(download_videos)
            self.reader.load_and_activate()

        # Detect write-mode params for backward compatibility
@@ -306,7 +284,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
            episodes=self.episodes,
            tolerance_s=self.tolerance_s,
            video_backend=self._video_backend,
            audio_backend=self._audio_backend,
            delta_timestamps=self.delta_timestamps,
            image_transforms=self.image_transforms,
        )
@@ -386,14 +363,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self._require_writer("add_frame")
        self.writer.add_frame(frame)

    def add_microphones_recordings(self, microphones: dict) -> None:
        """Add microphone recordings to the current episode buffer.

        Delegates to :meth:`DatasetWriter.add_microphones_recordings`.
        """
        self._require_writer("add_microphones_recordings")
        self.writer.add_microphones_recordings(microphones)

    def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None:
        """Save the current episode buffer to disk.

@@ -509,6 +478,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
            f"}})"
        )

    def set_image_transforms(self, image_transforms: Callable | None) -> None:
        """Replace the transform applied to visual observations."""
        if image_transforms is not None and not callable(image_transforms):
            raise TypeError("image_transforms must be callable or None.")
        self.image_transforms = image_transforms
        if self.reader is not None:
            self.reader._image_transforms = image_transforms

    def clear_image_transforms(self) -> None:
        """Remove the transform applied to visual observations."""
        self.set_image_transforms(None)
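Reviewer note: a minimal usage sketch for the new transform setters added above (torchvision transform and repo_id are illustrative):

```python
from torchvision.transforms import v2
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo_id

# Swap the visual transform at any point after construction; it applies in
# __getitem__ to both image- and video-backed observations.
dataset.set_image_transforms(v2.Compose([v2.ColorJitter(brightness=0.2)]))
frame = dataset[0]

# Remove augmentation, e.g. for evaluation.
dataset.clear_image_transforms()
```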
    # ── Hub methods (stay on facade) ──────────────────────────────────

    def push_to_hub(
@@ -518,7 +499,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        license: str | None = "apache-2.0",
        tag_version: bool = True,
        push_videos: bool = True,
        push_audio: bool = True,
        private: bool = False,
        allow_patterns: list[str] | str | None = None,
        upload_large_folder: bool = False,
@@ -548,8 +528,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        ignore_patterns = ["images/"]
        if not push_videos:
            ignore_patterns.append("videos/")
        if not push_audio:
            ignore_patterns.append("audio/")

        hub_api = HfApi()
        hub_api.create_repo(
@@ -590,15 +568,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
            hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
        hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

    def _download(self, download_videos: bool = True, download_audio: bool = True) -> None:
    def _download(self, download_videos: bool = True) -> None:
        """Downloads the dataset from the given 'repo_id' at the provided version."""
        ignore_patterns = None if download_videos else "videos/"
        files = None
        ignore_patterns = []
        if not download_videos:
            ignore_patterns.append("videos/")
        if not download_audio:
            ignore_patterns.append("audio/")
        if self.episodes is not None:
            # Reader is guaranteed to exist here (created in __init__ before _download)
            files = self.reader.get_episodes_file_paths()
@@ -645,7 +618,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
        video_backend: str | None = None,
        audio_backend: str | None = None,
        batch_encoding_size: int = 1,
        vcodec: str = "libsvtav1",
        metadata_buffer_size: int = 10,

@@ -89,12 +89,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        )
        self.disabled_features.update(extra_keys)

        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
        # with multiple robots of different ranges. Instead we should have one normalization
        # per robot.
        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
        self.set_image_transforms(image_transforms)

    def set_image_transforms(self, image_transforms: Callable | None) -> None:
        """Replace the transform for this dataset and its children."""
        if image_transforms is not None and not callable(image_transforms):
            raise TypeError("image_transforms must be callable or None.")
        self.image_transforms = image_transforms
        for dataset in getattr(self, "_datasets", []):
            dataset.set_image_transforms(self.image_transforms)

    def clear_image_transforms(self) -> None:
        """Remove the transform from this dataset and its children."""
        self.set_image_transforms(None)

    @property
    def repo_id_to_index(self):

@@ -73,7 +73,6 @@ class ForwardCompatibilityError(CompatibilityError):
DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200  # Max size per file
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100  # Max size per file

INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
@@ -81,7 +80,6 @@ STATS_PATH = "meta/stats.json"
EPISODES_DIR = "meta/episodes"
DATA_DIR = "data"
VIDEO_DIR = "videos"
AUDIO_DIR = "audio"

CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
@@ -89,12 +87,7 @@ DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"

DEFAULT_AUDIO_CHUNK_DURATION = 0.5  # seconds
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0  # seconds

LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"

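Reviewer note: for reference, a quick sketch of how these path templates expand (values illustrative):

```python
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_VIDEO_PATH = "videos" + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"

path = DEFAULT_VIDEO_PATH.format(
    video_key="observation.images.laptop", chunk_index=0, file_index=1
)
assert path == "videos/observation.images.laptop/chunk-000/file-001.mp4"
```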
@@ -486,42 +486,42 @@ def encode_video_frames(
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")


def concatenate_media_files(
    input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
def concatenate_video_files(
    input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
    """
    Concatenate multiple media files (video & audio) into a single media file using pyav.
    Concatenate multiple video files into a single video file using pyav.

    This function takes a list of input media file paths and concatenates them into a single
    output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
    This function takes a list of video input file paths and concatenates them into a single
    output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
    concatenation without re-encoding.

    Args:
        input_media_paths: Ordered list of input media file paths to concatenate.
        output_media_path: Path to the output media file.
        overwrite: Whether to overwrite the output media file if it already exists. Default is True.
        input_video_paths: Ordered list of input video file paths to concatenate.
        output_video_path: Path to the output video file.
        overwrite: Whether to overwrite the output video file if it already exists. Default is True.

    Note:
        - Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
        - Uses ffmpeg's concat demuxer which requires all input media files to have the same
        - Creates a temporary directory for intermediate files that is cleaned up after use.
        - Uses ffmpeg's concat demuxer which requires all input videos to have the same
          codec, resolution, and frame rate for proper concatenation.
    """

    output_media_path = Path(output_media_path)
    output_video_path = Path(output_video_path)

    if output_media_path.exists() and not overwrite:
        logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
    if output_video_path.exists() and not overwrite:
        logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
        return

    output_media_path.parent.mkdir(parents=True, exist_ok=True)
    output_video_path.parent.mkdir(parents=True, exist_ok=True)

    if len(input_media_paths) == 0:
        raise FileNotFoundError("No input media paths provided.")
    if len(input_video_paths) == 0:
        raise FileNotFoundError("No input video paths provided.")

    # Create a temporary .ffconcat file to list the input media paths
    # Create a temporary .ffconcat file to list the input video paths
    with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
        tmp_concatenate_file.write("ffconcat version 1.0\n")
        for input_path in input_media_paths:
        for input_path in input_video_paths:
            tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
        tmp_concatenate_file.flush()
        tmp_concatenate_path = tmp_concatenate_file.name
@@ -531,12 +531,11 @@ def concatenate_media_files(
        tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
    )  # safe = 0 allows absolute paths as well as relative paths

    # Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
    with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
        tmp_output_media_path = tmp_named_file.name
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
        tmp_output_video_path = tmp_named_file.name

    output_container = av.open(
        tmp_output_media_path, mode="w", options={"movflags": "faststart"}
        tmp_output_video_path, mode="w", options={"movflags": "faststart"}
    )  # faststart is to move the metadata to the beginning of the file to speed up loading

    # Replicate input streams in output container
@@ -551,7 +550,6 @@ def concatenate_media_files(
        stream_map[input_stream.index].time_base = input_stream.time_base

    # Demux + remux packets (no re-encode)
    last_dts = None
    for packet in input_container.demux():
        # Skip packets from un-mapped streams
        if packet.stream.index not in stream_map:
@@ -560,16 +558,6 @@ def concatenate_media_files(
        # Skip demux flushing packets
        if packet.dts is None:
            continue
        else:
            # Enforce strictly increasing decoding timestamps (DTS)
            if last_dts is not None and packet.dts <= last_dts:
                shift = last_dts - packet.dts + 1
                packet.dts += shift
                packet.pts += shift  # Presentation timestamps (PTS) are the same as DTS here
                logging.warning(
                    f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
                )
            last_dts = packet.dts

        output_stream = stream_map[packet.stream.index]
        packet.stream = output_stream
@@ -577,7 +565,7 @@ def concatenate_media_files(

    input_container.close()
    output_container.close()
    shutil.move(tmp_output_media_path, output_media_path)
    shutil.move(tmp_output_video_path, output_video_path)
    Path(tmp_concatenate_path).unlink()

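Reviewer note: a minimal usage sketch for the renamed helper (module path assumed, paths illustrative; inputs must share codec, resolution, and frame rate):

```python
from pathlib import Path
from lerobot.datasets.video_utils import concatenate_video_files  # module path assumed

chunks = [Path("chunk-000/file-000.mp4"), Path("chunk-000/file-001.mp4")]
# Stream-copies both inputs into one file without re-encoding.
concatenate_video_files(chunks, Path("merged/file-000.mp4"), overwrite=True)
```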
@@ -959,6 +947,38 @@ with warnings.catch_warnings():
    register_feature(VideoFrame, "VideoFrame")


def get_audio_info(video_path: Path | str) -> dict:
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    # Getting audio stream information
    audio_info = {}
    with av.open(str(video_path), "r") as audio_file:
        try:
            audio_stream = audio_file.streams.audio[0]
        except IndexError:
            # Reset logging level
            av.logging.restore_default_callback()
            return {"has_audio": False}

        audio_info["audio.channels"] = audio_stream.channels
        audio_info["audio.codec"] = audio_stream.codec.canonical_name
        # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
        # In an actual compressed case, the bit rate is set according to the compression level: the lower the bit rate, the more compression is applied.
        audio_info["audio.bit_rate"] = audio_stream.bit_rate
        audio_info["audio.sample_rate"] = audio_stream.sample_rate  # Number of samples per second
        # In an ideal lossless case: fixed number of bits per sample.
        # In an actual compressed case: variable number of bits per sample (often reduced to match a given depth rate).
        audio_info["audio.bit_depth"] = audio_stream.format.bits
        audio_info["audio.channel_layout"] = audio_stream.layout.name
        audio_info["has_audio"] = True

    # Reset logging level
    av.logging.restore_default_callback()

    return audio_info


def get_video_info(video_path: Path | str) -> dict:
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)
@@ -988,6 +1008,9 @@ def get_video_info(video_path: Path | str) -> dict:
    # Reset logging level
    av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(**get_audio_info(video_path))

    return video_info

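Reviewer note: a short usage sketch for the new probe (module path assumed, file path illustrative):

```python
from lerobot.datasets.video_utils import get_audio_info  # module path assumed

info = get_audio_info("episode.mp4")  # illustrative path
if info["has_audio"]:
    # Keys follow the "audio.*" naming set in the function above.
    print(info["audio.codec"], info["audio.sample_rate"], info["audio.channels"])
```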
@@ -1002,22 +1025,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
    raise ValueError("Unknown format")


def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Get the duration of a media file (video & audio) in seconds using PyAV.
    Get the duration of a video file in seconds using PyAV.

    Args:
        media_path: Path to the media file.
        video_path: Path to the video file.

    Returns:
        Duration of the media file in seconds.
        Duration of the video in seconds.
    """
    with av.open(str(media_path)) as container:
        # Get the first stream
        stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
    with av.open(str(video_path)) as container:
        # Get the first video stream
        video_stream = container.streams.video[0]
        # Calculate duration: stream.duration * stream.time_base gives duration in seconds
        if stream.duration is not None:
            duration = float(stream.duration * stream.time_base)
        if video_stream.duration is not None:
            duration = float(video_stream.duration * video_stream.time_base)
        else:
            # Fallback to container duration if stream duration is not available
            duration = float(container.duration / av.time_base)
@@ -1026,12 +1049,12 @@ def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -

class VideoEncodingManager:
    """
    Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
    Context manager that ensures proper video encoding and data cleanup even if exceptions occur.

    This manager handles:
    - Batch encoding for any remaining episodes when recording is interrupted
    - Cleaning up temporary image and audio files from interrupted episodes
    - Removing empty image and audio directories
    - Cleaning up temporary image files from interrupted episodes
    - Removing empty image directories

    Args:
        dataset: The LeRobotDataset instance
@@ -1068,16 +1091,4 @@ class VideoEncodingManager:
        else:
            logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")

        # Clean up any remaining audio directory if it's empty
        audio_dir = self.dataset.root / "raw_audio"
        # Check for any remaining WAV files
        wav_files = list(audio_dir.rglob("*.wav"))
        if len(wav_files) == 0:
            # Only remove the raw_audio directory if no WAV files remain
            if audio_dir.exists():
                shutil.rmtree(audio_dir)
                logging.debug("Cleaned up empty audio directory")
        else:
            logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")

        return False  # Don't suppress the original exception

@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
import importlib
from dataclasses import dataclass, field, fields
from typing import Any

import draccus
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
@@ -67,6 +72,49 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
    def gym_kwargs(self) -> dict:
        raise NotImplementedError()

    def create_envs(
        self,
        n_envs: int,
        use_async_envs: bool = False,
    ) -> dict[str, dict[int, gym.vector.VectorEnv]]:
        """Create {suite: {task_id: VectorEnv}}.

        Default: single-task env via gym.make(). Multi-task benchmarks override.
        """
        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv

        if self.gym_id not in gym_registry:
            print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
            try:
                importlib.import_module(self.package_name)
            except ModuleNotFoundError as e:
                raise ModuleNotFoundError(
                    f"Package '{self.package_name}' required for env '{self.type}' not found. "
                    f"Please install it or check PYTHONPATH."
                ) from e

        if self.gym_id not in gym_registry:
            raise gym.error.NameNotFound(
                f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
            )

        def _make_one():
            return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)

        try:
            from gymnasium.vector import AutoresetMode

            vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=AutoresetMode.SAME_STEP)
        except ImportError:
            vec = env_cls([_make_one for _ in range(n_envs)])
        return {self.type: {0: vec}}

    def get_env_processors(self):
        """Return (preprocessor, postprocessor) for this env. Default: identity."""
        from lerobot.processor.pipeline import PolicyProcessorPipeline

        return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])


@dataclass
class HubEnvConfig(EnvConfig):
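Reviewer note: a sketch of how the new default factory is used (config type and values illustrative):

```python
from lerobot.envs.factory import make_env_config  # module path assumed

cfg = make_env_config("pusht")
envs = cfg.create_envs(n_envs=4, use_async_envs=False)

# Returned shape is {suite: {task_id: VectorEnv}}, e.g. envs["pusht"][0].
vec_env = envs["pusht"][0]
obs, info = vec_env.reset(seed=0)
```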
@@ -338,6 +386,12 @@ class LiberoEnv(EnvConfig):
        else:
            raise ValueError(f"Unsupported obs_type: {self.obs_type}")

        if self.camera_name_mapping is not None:
            mapped_agentview = self.camera_name_mapping.get("agentview_image", "image")
            mapped_eye_in_hand = self.camera_name_mapping.get("robot0_eye_in_hand_image", "image2")
            self.features_map[LIBERO_KEY_PIXELS_AGENTVIEW] = f"{OBS_IMAGES}.{mapped_agentview}"
            self.features_map[LIBERO_KEY_PIXELS_EYE_IN_HAND] = f"{OBS_IMAGES}.{mapped_eye_in_hand}"

    @property
    def gym_kwargs(self) -> dict:
        kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
@@ -345,6 +399,33 @@ class LiberoEnv(EnvConfig):
            kwargs["task_ids"] = self.task_ids
        return kwargs

    def create_envs(self, n_envs: int, use_async_envs: bool = False):
        from lerobot.envs.libero import create_libero_envs

        if self.task is None:
            raise ValueError("LiberoEnv requires a task to be specified")
        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
        return create_libero_envs(
            task=self.task,
            n_envs=n_envs,
            camera_name=self.camera_name,
            init_states=self.init_states,
            gym_kwargs=self.gym_kwargs,
            env_cls=env_cls,
            control_mode=self.control_mode,
            episode_length=self.episode_length,
            camera_name_mapping=self.camera_name_mapping,
        )

    def get_env_processors(self):
        from lerobot.processor.env_processor import LiberoProcessorStep
        from lerobot.processor.pipeline import PolicyProcessorPipeline

        return (
            PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
            PolicyProcessorPipeline(steps=[]),
        )


@EnvConfig.register_subclass("metaworld")
@dataclass
@@ -387,6 +468,19 @@ class MetaworldEnv(EnvConfig):
            "render_mode": self.render_mode,
        }

    def create_envs(self, n_envs: int, use_async_envs: bool = False):
        from lerobot.envs.metaworld import create_metaworld_envs

        if self.task is None:
            raise ValueError("MetaWorld requires a task to be specified")
        env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
        return create_metaworld_envs(
            task=self.task,
            n_envs=n_envs,
            gym_kwargs=self.gym_kwargs,
            env_cls=env_cls,
        )


@EnvConfig.register_subclass("isaaclab_arena")
@dataclass
@@ -454,3 +548,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
    @property
    def gym_kwargs(self) -> dict:
        return {}

    def get_env_processors(self):
        from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
        from lerobot.processor.pipeline import PolicyProcessorPipeline

        state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
        camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
        if not state_keys and not camera_keys:
            raise ValueError("At least one of state_keys or camera_keys must be specified.")
        return (
            PolicyProcessorPipeline(
                steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
            ),
            PolicyProcessorPipeline(steps=[]),
        )
@@ -13,90 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from __future__ import annotations

from typing import Any

import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry

from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import EnvConfig, HubEnvConfig
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline


def make_env_config(env_type: str, **kwargs) -> EnvConfig:
    if env_type == "aloha":
        return AlohaEnv(**kwargs)
    elif env_type == "pusht":
        return PushtEnv(**kwargs)
    elif env_type == "libero":
        return LiberoEnv(**kwargs)
    else:
        raise ValueError(f"Policy type '{env_type}' is not available.")
    try:
        cls = EnvConfig.get_choice_class(env_type)
    except KeyError as err:
        raise ValueError(
            f"Environment type '{env_type}' is not registered. "
            f"Available: {list(EnvConfig.get_known_choices().keys())}"
        ) from err
    return cls(**kwargs)


def make_env_pre_post_processors(
    env_cfg: EnvConfig,
    policy_cfg: PreTrainedConfig,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
    policy_cfg: Any,
) -> tuple[Any, Any]:
    """
    Create preprocessor and postprocessor pipelines for environment observations.

    This function creates processor pipelines that transform raw environment
    observations and actions. By default, it returns identity processors that do nothing.
    For specific environments like LIBERO, it adds environment-specific processing steps.

    Args:
        env_cfg: The configuration of the environment.

    Returns:
        A tuple containing:
        - preprocessor: Pipeline that processes environment observations
        - postprocessor: Pipeline that processes environment outputs (currently identity)
    Returns a tuple of (preprocessor, postprocessor). By default, delegates to
    ``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
    stays here because it depends on the *policy* config, not the env config.
    """
    # Preprocessor and Postprocessor steps are Identity for most environments
    preprocessor_steps: list[ProcessorStep] = []
    postprocessor_steps: list[ProcessorStep] = []
    from lerobot.policies.xvla.configuration_xvla import XVLAConfig

    if isinstance(policy_cfg, XVLAConfig):
        from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors

        return make_xvla_libero_pre_post_processors()

    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
        preprocessor_steps.append(LiberoProcessorStep())

    # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
    if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
        # Parse comma-separated keys (handle None for state-based policies)
        if env_cfg.state_keys:
            state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
        else:
            state_keys = ()
        if env_cfg.camera_keys:
            camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
        else:
            camera_keys = ()
        if not state_keys and not camera_keys:
            raise ValueError("At least one of state_keys or camera_keys must be specified.")
        preprocessor_steps.append(
            IsaaclabArenaProcessorStep(
                state_keys=state_keys,
                camera_keys=camera_keys,
            )
        )

    preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
    postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)

    return preprocessor, postprocessor
    return env_cfg.get_env_processors()


def make_env(
@@ -163,57 +119,4 @@ def make_env(
    if n_envs < 1:
        raise ValueError("`n_envs` must be at least 1")

    env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv

    if "libero" in cfg.type:
        from lerobot.envs.libero import create_libero_envs

        if cfg.task is None:
            raise ValueError("LiberoEnv requires a task to be specified")

        return create_libero_envs(
            task=cfg.task,
            n_envs=n_envs,
            camera_name=cfg.camera_name,
            init_states=cfg.init_states,
            gym_kwargs=cfg.gym_kwargs,
            env_cls=env_cls,
            control_mode=cfg.control_mode,
            episode_length=cfg.episode_length,
        )
    elif "metaworld" in cfg.type:
        from lerobot.envs.metaworld import create_metaworld_envs

        if cfg.task is None:
            raise ValueError("MetaWorld requires a task to be specified")

        return create_metaworld_envs(
            task=cfg.task,
            n_envs=n_envs,
            gym_kwargs=cfg.gym_kwargs,
            env_cls=env_cls,
        )

    if cfg.gym_id not in gym_registry:
        print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
        try:
            importlib.import_module(cfg.package_name)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
                f"Please install it or check PYTHONPATH."
            ) from e

    if cfg.gym_id not in gym_registry:
        raise gym.error.NameNotFound(
            f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
        )

    def _make_one():
        return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))

    vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

    # normalize to {suite: {task_id: vec_env}} for consistency
    suite_name = cfg.type  # e.g., "pusht", "aloha"
    return {suite_name: {0: vec}}
    return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)
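Reviewer note: a sketch of the registry-driven factory in use (module path assumed, kwargs illustrative; any type registered via `EnvConfig.register_subclass` resolves the same way):

```python
from lerobot.envs.factory import make_env, make_env_config  # module path assumed

cfg = make_env_config("libero", task="libero_10")  # kwargs illustrative
envs = make_env(cfg, n_envs=2)  # {suite: {task_id: VectorEnv}}

# Unregistered types now fail with the list of known choices:
try:
    make_env_config("not_an_env")
except ValueError as err:
    print(err)
```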
@@ -223,7 +223,8 @@ class LiberoEnv(gym.Env):

    def render(self):
        raw_obs = self._env.env._get_observations()
        image = self._format_raw_obs(raw_obs)["pixels"]["image"]
        pixels = self._format_raw_obs(raw_obs)["pixels"]
        image = next(iter(pixels.values()))
        image = image[::-1, ::-1]  # flip both H and W for visualization
        return image

@@ -339,12 +340,6 @@ class LiberoEnv(gym.Env):
        )
        observation = self._format_raw_obs(raw_obs)
        if terminated:
            info["final_info"] = {
                "task": self.task,
                "task_id": self.task_id,
                "done": bool(done),
                "is_success": bool(is_success),
            }
            self.reset()
        truncated = False
        return observation, reward, terminated, truncated, info
@@ -364,6 +359,7 @@ def _make_env_fns(
    init_states: bool,
    gym_kwargs: Mapping[str, Any],
    control_mode: str,
    camera_name_mapping: dict[str, str] | None = None,
) -> list[Callable[[], LiberoEnv]]:
    """Build n_envs factory callables for a single (suite, task_id)."""

@@ -379,6 +375,7 @@ def _make_env_fns(
            episode_index=episode_index,
            n_envs=n_envs,
            control_mode=control_mode,
            camera_name_mapping=camera_name_mapping,
            **local_kwargs,
        )

@@ -400,6 +397,7 @@ def create_libero_envs(
    env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
    control_mode: str = "relative",
    episode_length: int | None = None,
    camera_name_mapping: dict[str, str] | None = None,
) -> dict[str, dict[int, Any]]:
    """
    Create vectorized LIBERO environments with a consistent return shape.
@@ -449,6 +447,7 @@ def create_libero_envs(
        init_states=init_states,
        gym_kwargs=gym_kwargs,
        control_mode=control_mode,
        camera_name_mapping=camera_name_mapping,
    )
    out[suite_name][tid] = env_cls(fns)
    print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

@@ -1,17 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configs import MicrophoneConfig
from .microphone import Microphone
from .utils import make_microphones_from_configs
@@ -1,140 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from pathlib import Path
from threading import Barrier
from typing import Any

import numpy as np

from .configs import MicrophoneConfig


class Microphone(abc.ABC):
    """Base class for microphone implementations.

    Defines a standard interface for microphone operations across different backends.
    Subclasses must implement all abstract methods.

    Manages basic microphone properties (sample rate, channels) and core operations:
    - Connection/disconnection
    - Start/stop recording
    - Audio chunk reading

    Attributes:
        sample_rate (int | None): Configured sample rate in Hz
        channels (list[int] | None): List of channel numbers to record

    Example:
        class MyMicrophone(Microphone):
            def __init__(self, config): ...
            @property
            def is_connected(self) -> bool: ...
            def connect(self): ...
            # Plus other required methods
    """

    def __init__(self, config: MicrophoneConfig):
        """Initialize the microphone with the given configuration.

        Args:
            config: Microphone configuration containing sample rate and channels.
        """
        self.sample_rate: int | None = config.sample_rate
        self.channels: list[int] | None = config.channels

    @property
    @abc.abstractmethod
    def is_connected(self) -> bool:
        """Check if the microphone is currently connected.

        Returns:
            bool: True if the microphone is connected and ready to start recording,
            False otherwise.
        """
        pass

    @property
    @abc.abstractmethod
    def is_recording(self) -> bool:
        """Check if the microphone is currently recording.

        Returns:
            bool: True if the microphone is recording, False otherwise.
        """
        pass

    @property
    @abc.abstractmethod
    def is_writing(self) -> bool:
        """Check if the microphone is currently writing to a file.

        Returns:
            bool: True if the microphone is writing to a file, False otherwise.
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def find_microphones() -> list[dict[str, Any]]:
        """Detects available microphones connected to the system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
            where each dictionary contains information about a detected microphone.
        """
        pass

    @abc.abstractmethod
    def connect(self) -> None:
        """Establish connection to the microphone."""
        pass

    @abc.abstractmethod
    def start_recording(
        self,
        output_file: str | Path | None = None,
        multiprocessing: bool | None = False,
        overwrite: bool | None = True,
        barrier: Barrier | None = None,
    ) -> None:
        """Start recording audio from the microphone.

        Args:
            output_file: Optional path to save the recorded audio.
            multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
            overwrite: If True, overwrites existing files at output_file path.
            barrier: If not None, ensures that multiple microphones start recording at the same time.
        """
        pass

    @abc.abstractmethod
    def read(self) -> np.ndarray:
        """Capture and return a single audio chunk from the microphone.

        Returns:
            np.ndarray: Captured audio chunk as a numpy array.
        """
        pass

    @abc.abstractmethod
    def stop_recording(self) -> None:
        """Stop recording audio from the microphone."""
        pass

    @abc.abstractmethod
    def disconnect(self) -> None:
        """Disconnect the microphone and release any resources."""
        pass
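Reviewer note: to gauge what downstream code relied on before this deletion, a minimal hypothetical implementation of the removed contract (not repo code):

```python
import numpy as np

class NullMicrophone(Microphone):
    """Silent stand-in that satisfies the abstract contract above."""

    def __init__(self, config):
        # config is any object exposing .sample_rate and .channels
        super().__init__(config)
        self._connected = False

    @property
    def is_connected(self) -> bool:
        return self._connected

    @property
    def is_recording(self) -> bool:
        return False

    @property
    def is_writing(self) -> bool:
        return False

    @staticmethod
    def find_microphones():
        return []

    def connect(self) -> None:
        self._connected = True

    def start_recording(self, output_file=None, multiprocessing=False, overwrite=True, barrier=None) -> None:
        pass

    def read(self) -> np.ndarray:
        # One chunk of silence, shaped (n_frames, n_channels)
        return np.zeros((1, len(self.channels or [1])), dtype=np.float32)

    def stop_recording(self) -> None:
        pass

    def disconnect(self) -> None:
        self._connected = False
```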
@@ -1,16 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_portaudio import PortAudioMicrophoneConfig
from .microphone_portaudio import PortAudioMicrophone
@@ -1,41 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from ..configs import MicrophoneConfig


@MicrophoneConfig.register_subclass("portaudio")
@dataclass
class PortAudioMicrophoneConfig(MicrophoneConfig):
    """Configuration class for PortAudio-based microphone devices.

    This class provides configuration options for microphones accessed through PortAudio with the
    sounddevice Python package, including device index, sample rate, and channels.

    Example configurations:
    ```python
    # Basic configurations
    PortAudioMicrophoneConfig(0, 16000, [1])  # Device index 0, 16000Hz, mono
    PortAudioMicrophoneConfig(1, 44100, [1, 2])  # Device index 1, 44100Hz, stereo
    ```

    Attributes:
        microphone_index: Device index for the microphone.
        sample_rate: Sample rate in Hz for the microphone.
        channels: List of channel numbers to use for the microphone.
    """

    microphone_index: int
@@ -1,394 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import time
from collections.abc import Callable
from threading import Event, Thread
from typing import Any

import numpy as np
from sounddevice import PortAudioError

from lerobot.utils.robot_utils import precise_sleep


# --- Interface definitions for InputStream ---
class IInputStream(abc.ABC):
    @abc.abstractmethod
    def __init__(
        self,
        samplerate: float | None = None,
        blocksize: int | None = None,
        device: int | str | None = None,
        channels: int | None = None,
        dtype: str | np.dtype | None = None,
        latency: float | str | None = None,
        callback: Callable[[Any, int, Any, Any], None] | None = None,
    ):
        pass

    @abc.abstractmethod
    def start(self) -> None:
        pass

    @abc.abstractmethod
    def stop(self) -> None:
        pass

    @abc.abstractmethod
    def close(self) -> None:
        pass


class ISounddeviceSDK(abc.ABC):
    """Interface defining the contract for the Sounddevice SDK."""

    InputStream: type[IInputStream]

    @abc.abstractmethod
    def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
        pass


# --- Real SDK Adapter ---


class SounddeviceSDKAdapter(ISounddeviceSDK):
    """Adapts the real sounddevice library to the ISounddeviceSDK interface."""

    _sounddevice = None

    def __init__(self):
        try:
            import sounddevice

            SounddeviceSDKAdapter._sounddevice = sounddevice
        except ImportError as e:
            raise ImportError("sounddevice library not found") from e

    # --- Inner Class Implementation ---
    class RealInputStream(IInputStream):
        def __init__(
            self,
            samplerate: int | None = None,
            blocksize: int | None = None,
            device: int | None = None,
            channels: int | None = None,
            dtype: str | np.dtype | None = None,
            latency: float | str | None = None,
            callback: Callable[[Any, int, Any, Any], None] | None = None,
        ):
            import sounddevice

            self._input_stream = sounddevice.InputStream(
                samplerate=samplerate,
                blocksize=blocksize,
                device=device,
                channels=channels,
                dtype=dtype,
                latency=latency,
                callback=callback,
            )

        def start(self) -> None:
            self._input_stream.start()

        def stop(self) -> None:
            self._input_stream.stop()

        def close(self) -> None:
            self._input_stream.close()

        def __del__(self):
            self._input_stream.stop()
            self._input_stream.close()

        @property
        def active(self) -> bool:
            return self._input_stream.active

        @property
        def stopped(self) -> bool:
            return self._input_stream.stopped

        @property
        def closed(self) -> bool:
            return self._input_stream.closed

    InputStream = RealInputStream

    def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
        return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)


# Emulates a 48kHz stereo microphone
VALID_DTYPE = {
    "float32",
    "int32",
    "int16",
    "int8",
    "uint8",
    np.float32,
    np.int32,
    np.int16,
    np.int8,
    np.uint8,
}
VALID_LATENCY = {"low", "high"}

VALID_DEVICES = [
    {
        "index": 0,
        "name": "Built-in Microphone",
        "hostapi": 0,
        "max_input_channels": 2,
        "max_output_channels": 0,
        "default_low_input_latency": 0.01,
        "default_low_output_latency": 0.001,
        "default_high_input_latency": 0.1,
        "default_high_output_latency": 0.01,
        "default_samplerate": 48000.0,
    },
    {
        "index": 1,
        "name": "Built-in Output",
        "hostapi": 0,
        "max_input_channels": 0,
        "max_output_channels": 2,
        "default_low_input_latency": 0.04,
        "default_low_output_latency": 0.04,
        "default_high_input_latency": 0.12,
        "default_high_output_latency": 0.12,
        "default_samplerate": 48000.0,
    },
    {
        "index": 2,
        "name": "USB Audio Device",
        "hostapi": 0,
        "max_input_channels": 1,
        "max_output_channels": 0,
        "default_low_input_latency": 0.03,
        "default_low_output_latency": 0.01,
        "default_high_input_latency": 0.04,
        "default_high_output_latency": 0.03,
        "default_samplerate": 16000.0,
    },
]
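Reviewer note: the adapter pair exists so tests can run without audio hardware. A minimal sketch of the dependency-injection pattern it enables, assuming the fake's `query_devices` (defined further down, past this excerpt) returns the `VALID_DEVICES` table above:

```python
# Hypothetical consumer showing why the SDK is injected rather than imported
# directly: production code passes SounddeviceSDKAdapter(), tests pass the fake.
def list_input_devices(sdk: ISounddeviceSDK) -> list[str]:
    return [d["name"] for d in sdk.query_devices() if d["max_input_channels"] > 0]

fake_sdk = FakeSounddeviceSDKAdapter()
# Expected under the assumption stated above:
# ["Built-in Microphone", "USB Audio Device"]
print(list_input_devices(fake_sdk))
```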
# -- Fake SDK Adapter ---


class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
    """Implements the ISounddeviceSDK interface with fake behaviour for testing."""

    # --- Inner Class Implementation ---
    class FakeInputStream(IInputStream):
        def __init__(
            self,
            samplerate: float | None = None,
            blocksize: int | None = None,
            device: int | str | None = None,
            channels: int | None = None,
            dtype: str | None = None,
            latency: str | None = None,
            callback: Callable[[Any, int, Any, Any], None] | None = None,
        ):
            self.samplerate = samplerate
            self.blocksize = blocksize
            self.device = device
            self.channels = channels
            self.dtype = dtype
            self.latency = latency
            self.callback = callback

            self._validate_settings()

            self._active = False
            self._closed = False

            if self.callback is not None:
                self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
                self._streaming_thread_stop_event = Event()

        @property
        def active(self) -> bool:
            """True when the stream is active, False otherwise."""
            return self._active

        @property
        def stopped(self) -> bool:
            """True when the stream is stopped, False otherwise."""
            return not self._active

        @property
        def closed(self) -> bool:
            """True after a call to close(), False otherwise."""
            return self._closed

        def _get_device_info(self):
            """Returns the device info for the device."""
            for device in VALID_DEVICES:
                if (isinstance(self.device, int) and device["index"] == self.device) or (
                    isinstance(self.device, str) and device["name"] == self.device
                ):
                    return device
            raise PortAudioError(f"No input device matching {self.device}")

        def _validate_device(self):
            """Validates the device against the valid devices."""
            valid_device_indices = [device["index"] for device in VALID_DEVICES]
            valid_device_names = [device["name"] for device in VALID_DEVICES]

            if self.device is not None:
                if isinstance(self.device, (int, str)):
                    # Check if device index is valid
                    if isinstance(self.device, int) and self.device not in valid_device_indices:
                        raise PortAudioError(f"Error querying device {self.device}")

                    # Check if device name is valid
                    if isinstance(self.device, str) and self.device not in valid_device_names:
                        raise PortAudioError(f"No input device matching {self.device}")
                else:
                    raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
            else:
                # Default to first input device
                input_devices = [d for d in VALID_DEVICES if d["max_input_channels"] > 0]
                if input_devices:
                    self.device = input_devices[0]["index"]

        def _validate_samplerate(self):
            """Validates the samplerate against the device's maximum samplerate."""
            device_info = self._get_device_info()
            if self.samplerate is None:
                self.samplerate = device_info["default_samplerate"]
            elif self.samplerate > device_info["default_samplerate"] or self.samplerate < 1000:
                raise PortAudioError("Error opening InputStream: Invalid sample rate")

        def _validate_channels(self):
            """Validates the channels against the device's maximum channels."""
            device_info = self._get_device_info()
            if self.channels is None:
                self.channels = device_info["max_input_channels"]
            elif self.channels > device_info["max_input_channels"] or self.channels < 1:
                raise PortAudioError("Error opening InputStream: Invalid number of channels")

        def _validate_dtype(self):
            """Validates the dtype against the valid dtypes."""
            if self.dtype is not None:
                if self.dtype not in VALID_DTYPE:
                    raise PortAudioError("Invalid input sample format")
            else:
                self.dtype = "float32"  # Default dtype

        def _validate_latency(self):
            """Validates the latency against the valid latencies."""
            if self.latency is not None:
                if self.latency not in VALID_LATENCY:
                    raise PortAudioError("Invalid latency")
            else:
                self.latency = "low"  # Default latency

            if isinstance(self.latency, str):
||||
device_info = self._get_device_info()
|
||||
if self.latency == "low":
|
||||
self.latency = device_info["default_low_input_latency"]
|
||||
elif self.latency == "high":
|
||||
self.latency = device_info["default_high_input_latency"]
|
||||
|
||||
def _validate_settings(self):
|
||||
"""Validates the input parameters against available devices and valid options."""
|
||||
self._validate_device()
|
||||
self._validate_samplerate()
|
||||
self._validate_channels()
|
||||
self._validate_dtype()
|
||||
self._validate_latency()
|
||||
|
||||
def _simulated_audio_data(self) -> np.ndarray:
|
||||
"""Generates a simulated audio signal for testing purposes with proper value ranges."""
|
||||
duration_samples = int(self.samplerate * self.latency)
|
||||
|
||||
# Generate output according to dtype
|
||||
if self.dtype in {"float32", np.float32}:
|
||||
# Generate values between -1 and 1 for float32
|
||||
data = np.random.uniform(-1.0, 1.0, (duration_samples, self.channels)).astype(self.dtype)
|
||||
else:
|
||||
# Use np.iinfo to get proper range for integer types
|
||||
info = np.iinfo(self.dtype)
|
||||
data = np.random.randint(
|
||||
info.min, info.max + 1, (duration_samples, self.channels), dtype=self.dtype
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def _streaming_loop(self):
|
||||
if self.callback is not None:
|
||||
while not self._streaming_thread_stop_event.is_set():
|
||||
precise_sleep(self.latency)
|
||||
tmp_data = self._simulated_audio_data()
|
||||
self.callback(
|
||||
tmp_data,
|
||||
len(tmp_data),
|
||||
time.perf_counter(),
|
||||
None,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the fake input stream."""
|
||||
if not self.active and self.callback is not None:
|
||||
self._streaming_thread.start()
|
||||
self._active = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the fake input stream."""
|
||||
if self.callback is not None:
|
||||
self._streaming_thread_stop_event.set()
|
||||
self._streaming_thread.join()
|
||||
self._active = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the fake input stream."""
|
||||
if self.active and self.callback is not None:
|
||||
self.stop()
|
||||
self._active = False
|
||||
self._closed = True
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
InputStream = FakeInputStream
|
||||
|
||||
def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Returns a realistic list of audio devices including speakers and microphones."""
|
||||
if device is not None:
|
||||
# Return specific device
|
||||
for valid_device in VALID_DEVICES:
|
||||
if (isinstance(device, int) and valid_device["index"] == device) or (
|
||||
isinstance(device, str) and valid_device["name"] == device
|
||||
):
|
||||
return valid_device
|
||||
raise PortAudioError(f"Error querying device {device}")
|
||||
|
||||
elif kind is not None:
|
||||
for valid_device in VALID_DEVICES:
|
||||
if (
|
||||
valid_device["max_input_channels"] > 0
|
||||
and kind == "input"
|
||||
or valid_device["max_output_channels"] > 0
|
||||
and kind == "output"
|
||||
):
|
||||
return valid_device
|
||||
raise PortAudioError(f"No {kind} device found")
|
||||
|
||||
return VALID_DEVICES
|
||||
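
A minimal usage sketch: injecting this fake adapter into the `PortAudioMicrophone` defined below enables hardware-free testing (the test name and body are illustrative, not an existing test in this diff):

```python
# Illustrative test sketch: drive the microphone entirely through the fake SDK.
def test_portaudio_microphone_with_fake_sdk(tmp_path):
    config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
    microphone = PortAudioMicrophone(config, sounddevice_sdk=FakeSounddeviceSDKAdapter())

    microphone.connect()
    microphone.start_recording(str(tmp_path / "test.wav"))
    audio = microphone.read()  # Simulated samples produced by FakeInputStream
    microphone.stop_recording()
    microphone.disconnect()

    assert audio.ndim == 2  # (samples, channels)
```
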
@@ -1,566 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
"""

import logging
import time
from multiprocessing import (
    Event as process_Event,
    JoinableQueue as process_Queue,
    Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any

import numpy as np
from soundfile import SoundFile

from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
from lerobot.utils.errors import (
    DeviceAlreadyConnectedError,
    DeviceAlreadyRecordingError,
    DeviceNotConnectedError,
    DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray

from ..microphone import Microphone
from .configuration_portaudio import PortAudioMicrophoneConfig

logger = logging.getLogger(__name__)


class PortAudioMicrophone(Microphone):
    """
    The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all major operating systems (Linux, macOS, Windows).

    A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.

    Example of usage:
    ```python
    from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig

    config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
    microphone = PortAudioMicrophone(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    audio_readings = microphone.read()  # Gets all audio recorded since the last read or since the beginning of the recording. The longer the period, the longer the read takes!
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """

    def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK | None = None):
        """
        Initializes the PortAudioMicrophone instance.

        Args:
            config: The configuration settings for the microphone.
            sounddevice_sdk: Optional SDK adapter; defaults to the real sounddevice-backed adapter. Passing a fake adapter enables hardware-free testing.
        """
        super().__init__(config)

        if sounddevice_sdk is None:
            self.sounddevice_sdk = SounddeviceSDKAdapter()
        else:
            self.sounddevice_sdk = sounddevice_sdk

        # Microphone index
        self.microphone_index = config.microphone_index

        # Input audio recording process and events
        self.record_process = None
        self.record_stop_event = process_Event()
        self.record_start_event = process_Event()
        self.record_close_event = process_Event()
        self.record_is_started_event = process_Event()
        self.audio_callback_start_event = process_Event()

        # Process-safe concurrent queue to send audio from the recording process to the writing process/thread
        self.write_queue = process_Queue()

        # SharedArray to store audio from the recording process.
        self.read_shared_array = None
        self.local_read_shared_array = None
        # Thread/Process to handle data writing in a separate thread/process (safely)
        self.write_thread = None
        self.write_stop_event = None
        self.write_is_started_event = None

        self.logs = {}

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.microphone_index})"

    @property
    def is_connected(self) -> bool:
        return self.record_process is not None and self.record_process.is_alive()

    @property
    def is_recording(self) -> bool:
        return self.record_is_started_event.is_set()

    @property
    def is_writing(self) -> bool:
        return self.write_thread is not None and self.write_is_started_event.is_set()

    @staticmethod
    def find_microphones(
        device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK | None = None
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """
        Detects available microphones connected to the system.

        Args:
            device: The device to find microphones for. If None, all microphones are found.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
            where each dictionary contains information about a detected microphone: index, name, sample rate, channels.
        """

        if sounddevice_sdk is None:
            sounddevice_sdk = SounddeviceSDKAdapter()

        found_microphones_info = []

        devices = sounddevice_sdk.query_devices()
        for d in devices:
            if d["max_input_channels"] > 0:
                microphone_info = {
                    "index": d["index"],
                    "name": d["name"],
                    "sample_rate": int(d["default_samplerate"]),
                    "channels": np.arange(1, d["max_input_channels"] + 1),
                }

                if device is None or (
                    (isinstance(device, int) and d["index"] == device)
                    or (isinstance(device, str) and d["name"] == device)
                ):
                    found_microphones_info.append(microphone_info)

        if device is not None:
            if len(found_microphones_info) == 0:
                raise RuntimeError(f"No microphone found for device {device}")
            else:
                return found_microphones_info[0]

        if len(found_microphones_info) == 0:
            logger.warning("No microphone found!")

        return found_microphones_info

    def _configure_capture_settings(self) -> None:
        """
        Validates the microphone index, sample rate and channels settings specified in the constructor's config against the not-yet-connected microphone.

        This method checks the specified settings and fills in the sample rate and channels when they are not specified, before attempting to start a PortAudio stream.

        Raises:
            RuntimeError: If one of the specified settings is not compatible with the microphone.
            DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(
                f"Cannot configure settings for {self} as it is already connected."
            )

        self._validate_microphone_index()
        self._validate_sample_rate()
        self._validate_channels()

    def _validate_microphone_index(self) -> None:
        """Validates the microphone index against available devices by checking if it has at least one input channel."""

        try:
            PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
        except RuntimeError as e:
            raise RuntimeError(
                f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
            ) from e

    def _validate_sample_rate(self) -> None:
        """Validates the sample rate against the actual microphone's default sample rate."""

        actual_sample_rate = PortAudioMicrophone.find_microphones(
            self.microphone_index, self.sounddevice_sdk
        )["sample_rate"]

        if self.sample_rate is not None:
            try:
                self.sample_rate = int(self.sample_rate)
            except ValueError as e:
                raise RuntimeError(
                    f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
                ) from e

            if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
                raise RuntimeError(
                    f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
                )
            else:
                if self.sample_rate < actual_sample_rate:
                    logger.warning(
                        "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
                    )
        else:
            self.sample_rate = actual_sample_rate

    def _validate_channels(self) -> None:
        """Validates the channels against the actual microphone's maximum input channels."""

        actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
            "channels"
        ]

        if self.channels is not None and len(self.channels) > 0:
            if not all(channel in actual_channels for channel in self.channels):
                raise RuntimeError(
                    f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
                )
        else:
            self.channels = actual_channels

        # Get channels index instead of number for slicing
        self.channels_index = np.array(self.channels) - 1

    def connect(self) -> None:
        """
        Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")

        self._configure_capture_settings()

        # Create or reset queue and shared array
        self.read_shared_array = SharedArray(
            shape=(self.sample_rate * 10, len(self.channels)),
            dtype=np.dtype("float32"),
        )
        self.local_read_shared_array = self.read_shared_array.get_local_array()
        self.write_queue = process_Queue()

        # Reset events
        self.record_start_event.clear()
        self.record_stop_event.clear()
        self.record_close_event.clear()
        self.record_is_started_event.clear()
        self.audio_callback_start_event.clear()

        # Create and start an audio input stream with a recording callback
        # Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
        process_init_event = process_Event()
        self.record_process = Process(
            target=self._record_process,
            args=(
                self.microphone_index,
                self.sample_rate,
                self.channels,
                process_init_event,
                self.record_start_event,
                self.record_stop_event,
                self.record_close_event,
                self.record_is_started_event,
                self.audio_callback_start_event,
                self.write_queue,
                self.read_shared_array,
                self.sounddevice_sdk,
            ),
        )
        self.record_process.daemon = True
        self.record_process.start()

        is_init = process_init_event.wait(
            timeout=5.0
        )  # Wait for the recording process to be started, and to potentially raise an error on failure.
        if not self.is_connected or not is_init:
            raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")

        logger.info(f"{self} connected.")

    def disconnect(self) -> None:
        """
        Disconnects the microphone and stops the recording.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")

        if self.is_recording:
            self.stop_recording()

        self.record_close_event.set()
        self.read_shared_array.delete()
        self.write_queue.close()
        self.record_process.join()

        if self.is_connected:
            raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")

        logger.info(f"{self} disconnected.")

    def _read(self) -> np.ndarray:
        """
        Thread/Process-safe callback to read available audio data
        """
        return self.read_shared_array.read(self.local_read_shared_array, flush=True)

    def read(self) -> np.ndarray:
        """
        Reads the last audio chunk recorded by the microphone, i.e. all samples recorded since the last read or since the beginning of the recording.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")

        start_time = time.perf_counter()

        audio_readings = self._read()

        # log the number of seconds it took to read the audio chunk
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the time at which the audio chunk was received
        self.logs["timestamp_utc"] = time.perf_counter()

        return audio_readings

    @staticmethod
    def _record_process(
        microphone_index,
        sample_rate,
        channels,
        process_init_event,
        record_start_event,
        record_stop_event,
        record_close_event,
        record_is_started_event,
        audio_callback_start_event,
        write_queue,
        read_shared_array,
        sounddevice_sdk,
    ) -> None:
        """
        Process callback used to create an unpicklable sounddevice audio input stream with a recording callback and start, stop and close it based on multiprocessing events.
        """

        channels_index = np.array(channels) - 1
        local_read_shared_array = read_shared_array.get_local_array()

        def audio_callback(indata, frames, timestamp, status) -> None:
            """
            Low-level sounddevice callback.
            """
            if status:
                logger.warning(status)
            if audio_callback_start_event.is_set():
                write_queue.put_nowait(indata[:, channels_index])
                read_shared_array.write(local_read_shared_array, indata[:, channels_index])

        # Create the audio stream
        # InputStream must be instantiated in the process as it is not picklable.
        stream = sounddevice_sdk.InputStream(
            device=microphone_index,
            samplerate=sample_rate,
            channels=max(channels),
            dtype="float32",
            blocksize=0,  # Varying input buffer length, but no additional latency
            latency="low",  # Low latency mode (not enabled by default!)
            # never_drop_input=True,  # Disabled as it generates an error for some devices
            callback=audio_callback,
        )
        process_init_event.set()

        while True:
            start_flag = record_start_event.wait(timeout=0.1)
            if record_close_event.is_set():
                break
            elif not start_flag:
                continue
            stream.start()
            record_is_started_event.set()
            record_stop_event.wait()
            stream.stop()  # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers!
            record_is_started_event.clear()
        stream.close()

    def start_recording(
        self,
        output_file: str | None = None,
        multiprocessing: bool | None = False,
        overwrite: bool | None = True,
        barrier: Barrier | None = None,
    ) -> None:
        """
        Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if self.is_recording:
            raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")

        # Reset queue and shared memory
        self.read_shared_array.reset()
        self._clear_queue(self.write_queue)

        # Reset stop event
        self.record_stop_event.clear()

        # Write recordings into a file if output_file is provided
        if output_file is not None:
            output_file = Path(output_file)
            output_file.parent.mkdir(parents=True, exist_ok=True)

            if output_file.exists():
                if overwrite:
                    output_file.unlink()
                else:
                    raise FileExistsError(
                        f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
                    )

            if multiprocessing:
                self.write_stop_event = process_Event()
                self.write_is_started_event = process_Event()
                self.write_thread = Process(
                    target=PortAudioMicrophone._write_loop,
                    args=(
                        self.write_queue,
                        self.write_stop_event,
                        self.write_is_started_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            else:
                self.write_stop_event = thread_Event()
                self.write_is_started_event = thread_Event()
                self.write_thread = Thread(
                    target=PortAudioMicrophone._write_loop,
                    args=(
                        self.write_queue,
                        self.write_stop_event,
                        self.write_is_started_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            self.write_thread.daemon = True
            self.write_thread.start()
            self.write_is_started_event.wait()  # Wait for the writing thread/process to be started.

        self.record_start_event.set()  # Start the input audio stream process
        self.record_is_started_event.wait()  # Wait for the input audio stream process to be actually started

        if barrier is not None:
            barrier.wait()  # Wait for multiple input audio streams to be started at the same time

        self.audio_callback_start_event.set()

        if not self.is_recording:
            raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
        if output_file is not None and not self.is_writing:
            raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")

    def stop_recording(self) -> None:
        """
        Stops the recording of the microphone.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
        if not self.is_recording:
            raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

        self.audio_callback_start_event.clear()
        self.record_start_event.clear()  # Ensures the audio stream is not started again!
        self.record_stop_event.set()

        # Wait for the stream to be stopped (might lead to a race condition on array reset and queue clearing if the stream is not properly stopped)
        timeout = 1.0
        while self.is_recording and timeout > 0:
            time.sleep(0.01)
            timeout -= 0.01

        self.read_shared_array.reset()
        self._clear_queue(self.write_queue, join_queue=True)

        if self.is_writing:
            self.write_stop_event.set()
            self.write_thread.join()

        if self.is_recording:
            raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
        if self.is_writing:
            raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")

    @staticmethod
    def _write_loop(
        queue,
        write_stop_event: Event,
        write_is_started_event: Event,
        sample_rate: int,
        channels: list[int],
        output_file: Path,
    ) -> None:
        """
        Thread/Process-safe loop to write audio data into a file.
        """
        # Can only be run on a single process/thread for file writing safety
        with SoundFile(
            output_file,
            mode="w",
            samplerate=sample_rate,
            channels=len(channels),
            format="WAV",
            subtype="FLOAT",  # By default, a much lower quality WAV file is created!
        ) as file:
            write_is_started_event.set()
            while not write_stop_event.is_set():
                try:
                    file.write(
                        queue.get(timeout=0.005)
                    )  # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
                    queue.task_done()
                except Empty:
                    continue
            write_is_started_event.clear()

    def __del__(self) -> None:
        if self.is_connected:
            self.disconnect()

    @staticmethod
    def _clear_queue(queue, join_queue: bool = False):
        """
        Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
        """
        try:
            while True:
                queue.get_nowait()
                queue.task_done()
        except Empty:
            if join_queue:
                queue.join()
            return
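
The static discovery helper above can also be used on its own; a brief sketch based on the dictionaries built in `find_microphones` (the device name is illustrative):

```python
# List every detected input device, then look a specific one up by name.
mics = PortAudioMicrophone.find_microphones()
for mic in mics:
    print(mic["index"], mic["name"], mic["sample_rate"])

usb_mic = PortAudioMicrophone.find_microphones(device="USB Audio Device")
```
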
@@ -1,16 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_touchlab import TouchLabSensorConfig
from .sensor_touchlab import TouchLabSensor
@@ -1,42 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from ..configs import MicrophoneConfig


@MicrophoneConfig.register_subclass("touchlab")
@dataclass
class TouchLabSensorConfig(MicrophoneConfig):
    """Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).

    This class provides configuration options for TouchLab tactile sensors, including serial port, sample rate and channels.

    Example configurations:
    ```python
    # Basic configurations
    TouchLabSensorConfig("/dev/ttyACM0", 16000)  # Serial port /dev/ttyACM0, 16000 Hz
    TouchLabSensorConfig("/dev/ttyACM1", 44100)  # Serial port /dev/ttyACM1, 44100 Hz
    ```

    Attributes:
        sensor_port: Serial port of the tactile sensor.
        baud_rate: Baud rate of the tactile sensor.
        sample_rate: Sample rate in Hz for the tactile sensor.
        channels: List of channel numbers to use for the tactile sensor.
    """

    sensor_port: str
    baud_rate: int = 115_200
@@ -1,469 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
"""

import logging
import time
from multiprocessing import (
    Event as process_Event,
    JoinableQueue as process_Queue,
    Process,
)
from pathlib import Path
from queue import Empty
from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any

import numpy as np
from serial import Serial
from soundfile import SoundFile

from lerobot.utils.errors import (
    DeviceAlreadyConnectedError,
    DeviceAlreadyRecordingError,
    DeviceNotConnectedError,
    DeviceNotRecordingError,
)
from lerobot.utils.shared_array import SharedArray

from ..microphone import Microphone
from .configuration_touchlab import TouchLabSensorConfig

logger = logging.getLogger(__name__)

MAX_SERIAL_READ_SIZE = 512


class TouchLabSensor(Microphone):
    """
    The TouchLabSensor class handles all TouchLab tactile sensors.

    A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.

    Example of usage:
    ```python
    from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig

    config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
    microphone = TouchLabSensor(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    audio_readings = microphone.read()  # Gets all recorded data since the last read or since the beginning of the recording. The longer the period, the longer the read takes!
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """

    def __init__(self, config: TouchLabSensorConfig):
        """
        Initializes the TouchLabSensor instance.

        Args:
            config: The configuration settings for the sensor.
        """
        super().__init__(config)

        # Sensor port
        self.sensor_port = config.sensor_port

        # Baud rate
        self.baud_rate = config.baud_rate

        # Input data recording process and events
        self.record_process = None
        self.record_stop_event = process_Event()
        self.record_start_event = process_Event()
        self.record_close_event = process_Event()
        self.record_is_started_event = process_Event()
        self.audio_callback_start_event = process_Event()

        # Process-safe concurrent queue to send data from the recording process to the writing process/thread
        self.write_queue = process_Queue()

        # SharedArray to store data from the recording process.
        self.read_shared_array = None
        self.local_read_shared_array = None
        # Thread/Process to handle data writing in a separate thread/process (safely)
        self.write_thread = None
        self.write_stop_event = None
        self.write_is_started_event = None

        self.logs = {}

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.sensor_port})"

    @property
    def is_connected(self) -> bool:
        """Check if the sensor is currently connected.

        Returns:
            bool: True if the sensor is connected and ready to start recording,
                False otherwise.
        """
        return self.record_process is not None and self.record_process.is_alive()

    @property
    def is_recording(self) -> bool:
        """Check if the sensor is currently recording.

        Returns:
            bool: True if the sensor is recording, False otherwise.
        """
        return self.record_is_started_event.is_set()

    @property
    def is_writing(self) -> bool:
        """Check if the sensor is currently writing to a file.

        Returns:
            bool: True if the sensor is writing to a file, False otherwise.
        """
        return self.write_thread is not None and self.write_is_started_event.is_set()

    @staticmethod
    def find_microphones() -> list[dict[str, Any]]:
        """Detects available sensors connected to the system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
            where each dictionary contains information about a detected sensor.
        """
        pass

    def connect(self) -> None:
        """
        Establish connection to the sensor.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")

        # Create or reset queue and shared array
        self.read_shared_array = SharedArray(
            shape=(self.sample_rate * 10, len(self.channels)),
            dtype=np.dtype("int16"),
        )
        self.local_read_shared_array = self.read_shared_array.get_local_array()
        self.write_queue = process_Queue()

        # Reset events
        self.record_start_event.clear()
        self.record_stop_event.clear()
        self.record_close_event.clear()
        self.record_is_started_event.clear()
        self.audio_callback_start_event.clear()

        # Create and start an audio input stream with a recording callback
        # Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
        process_init_event = process_Event()
        self.record_process = Process(
            target=self._record_process,
            args=(
                self.sensor_port,
                self.baud_rate,
                self.channels,
                process_init_event,
                self.record_start_event,
                self.record_stop_event,
                self.record_close_event,
                self.record_is_started_event,
                self.audio_callback_start_event,
                self.write_queue,
                self.read_shared_array,
            ),
        )
        self.record_process.daemon = True
        self.record_process.start()

        is_init = process_init_event.wait(
            timeout=5.0
        )  # Wait for the recording process to be started, and to potentially raise an error on failure.
        if not self.is_connected or not is_init:
            raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")

        logger.info(f"{self} connected.")

    @staticmethod
    def _record_process(
        sensor_port,
        baud_rate,
        channels,
        process_init_event,
        record_start_event,
        record_stop_event,
        record_close_event,
        record_is_started_event,
        audio_callback_start_event,
        write_queue,
        read_shared_array,
    ) -> None:
        channels_index = np.array(channels) - 1
        local_read_shared_array = read_shared_array.get_local_array()

        def tactile_callback(serial_connection):
            """
            Parse the tactile data from the raw input data.
            """
            buffer = serial_connection.readline()

            if audio_callback_start_event.is_set():
                strings = buffer.decode("utf8").split(",")
                num_taxels = len(strings)

                if num_taxels > 0 and num_taxels < MAX_SERIAL_READ_SIZE:  # Make sure we didn't read rubbish
                    indata = np.empty((1, num_taxels))
                    for i in range(num_taxels):
                        indata[0, i] = int(strings[i])

                    write_queue.put_nowait(indata[:, channels_index])
                    read_shared_array.write(local_read_shared_array, indata[:, channels_index])

        process_init_event.set()

        while True:
            start_flag = record_start_event.wait(timeout=0.1)
            if record_close_event.is_set():
                break
            elif not start_flag:
                continue

            with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
                serial_connection.flush()
                record_is_started_event.set()
                while not record_stop_event.is_set():
                    tactile_callback(serial_connection)
                record_is_started_event.clear()
                serial_connection.close()

    def disconnect(self) -> None:
        """
        Disconnect the sensor and release any resources.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")

        if self.is_recording:
            self.stop_recording()

        self.record_close_event.set()
        self.read_shared_array.delete()
        self.write_queue.close()
        self.record_process.join()

        if self.is_connected:
            raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")

        logger.info(f"{self} disconnected.")

    def start_recording(
        self,
        output_file: str | Path | None = None,
        multiprocessing: bool | None = False,
        overwrite: bool | None = True,
        barrier: Barrier | None = None,
    ) -> None:
        """
        Start recording tactile data from the sensor.

        Args:
            output_file: Optional path to save the recorded tactile data.
            multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
            overwrite: If True, overwrites existing files at output_file path.
            barrier: If not None, ensures that multiple sensors start recording at the same time.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if self.is_recording:
            raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")

        # Reset queue and shared memory
        self.read_shared_array.reset()
        self._clear_queue(self.write_queue)

        # Reset stop event
        self.record_stop_event.clear()

        # Write recordings into a file if output_file is provided
        if output_file is not None:
            output_file = Path(output_file)
            output_file.parent.mkdir(parents=True, exist_ok=True)

            if output_file.exists():
                if overwrite:
                    output_file.unlink()
                else:
                    raise FileExistsError(
                        f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
                    )

            if multiprocessing:
                self.write_stop_event = process_Event()
                self.write_is_started_event = process_Event()
                self.write_thread = Process(
                    target=TouchLabSensor._write_loop,
                    args=(
                        self.write_queue,
                        self.write_stop_event,
                        self.write_is_started_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            else:
                self.write_stop_event = thread_Event()
                self.write_is_started_event = thread_Event()
                self.write_thread = Thread(
                    target=TouchLabSensor._write_loop,
                    args=(
                        self.write_queue,
                        self.write_stop_event,
                        self.write_is_started_event,
                        self.sample_rate,
                        self.channels,
                        output_file,
                    ),
                )
            self.write_thread.daemon = True
            self.write_thread.start()
            self.write_is_started_event.wait()  # Wait for the writing thread/process to be started.

        self.record_start_event.set()  # Start the input audio stream process
        self.record_is_started_event.wait()  # Wait for the input audio stream process to be actually started

        if barrier is not None:
            barrier.wait()  # Wait for multiple input audio streams to be started at the same time

        self.audio_callback_start_event.set()

        if not self.is_recording:
            raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
        if output_file is not None and not self.is_writing:
            raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")

    def _read(self) -> np.ndarray:
        """
        Thread/Process-safe callback to read available tactile data
        """
        return self.read_shared_array.read(self.local_read_shared_array, flush=True)

    def read(self) -> np.ndarray:
        """Capture and return a single chunk of tactile data from the sensor.

        Returns:
            np.ndarray: Captured tactile readings as a numpy array.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if not self.is_recording:
            raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")

        start_time = time.perf_counter()

        tactile_readings = self._read()

        # log the number of seconds it took to read the tactile chunk
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the time at which the tactile chunk was received
        self.logs["timestamp_utc"] = time.perf_counter()

        return tactile_readings

    def _read_loop(self) -> None:
        """Internal loop run by the background thread for asynchronous reading."""

    def stop_recording(self) -> None:
        """Stop recording tactile data from the sensor."""
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if not self.is_recording:
            raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")

        self.audio_callback_start_event.clear()
        self.record_start_event.clear()  # Ensures the stream is not started again!
        self.record_stop_event.set()

        self.read_shared_array.reset()
        self._clear_queue(self.write_queue, join_queue=True)

        if self.is_writing:
            self.write_stop_event.set()
            self.write_thread.join()

        timeout = 1.0
        while self.is_recording and timeout > 0:
            time.sleep(0.01)
            timeout -= 0.01

        if self.is_recording:
            raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
        if self.is_writing:
            raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")

    def __del__(self) -> None:
        if self.is_connected:
            self.disconnect()

    @staticmethod
    def _clear_queue(queue, join_queue: bool = False):
        """
        Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
        """
        try:
            while True:
                queue.get_nowait()
                queue.task_done()
        except Empty:
            if join_queue:
                queue.join()
            return

    @staticmethod
    def _write_loop(
        queue,
        write_stop_event: Event,
        write_is_started_event: Event,
        sample_rate: int,
        channels: list[int],
        output_file: Path,
    ) -> None:
        """
        Thread/Process-safe loop to write audio data into a file.
        """
        # Can only be run on a single process/thread for file writing safety
        with SoundFile(
            output_file,
            mode="w",
            samplerate=sample_rate,
            channels=len(channels),
            format="WAV",
            subtype="PCM_16",  # Subtype for int16 values
        ) as file:
            write_is_started_event.set()
            while not write_stop_event.is_set():
                try:
                    file.write(
                        queue.get(timeout=0.005)
                    )  # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
                    queue.task_done()
                except Empty:
                    continue
            write_is_started_event.clear()
@@ -1,89 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from multiprocessing import Barrier
from threading import Thread

from .configs import MicrophoneConfig
from .microphone import Microphone


def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
    microphones = {}

    for key, cfg in microphone_configs.items():
        if cfg.type == "portaudio":
            from .portaudio import PortAudioMicrophone

            microphones[key] = PortAudioMicrophone(cfg)
        elif cfg.type == "touchlab":
            from .touchlab import TouchLabSensor

            microphones[key] = TouchLabSensor(cfg)
        else:
            raise ValueError(f"The microphone type '{cfg.type}' is not valid.")

    return microphones


def async_microphones_start_recording(
    microphones: dict[str, Microphone],
    output_files: list[str | None] | None = None,
    multiprocessing: bool = False,
    overwrite: bool = True,
) -> None:
    """
    Starts recording on multiple microphones asynchronously to avoid delays.

    Args:
        microphones: A dictionary of microphones.
        output_files: A list of output files.
        multiprocessing: If True, enables multiprocessing for recording.
        overwrite: If True, overwrites existing files at output_file path.
    """

    start_recording_threads = []
    if output_files is None:
        output_files = [None] * len(microphones)

    barrier = Barrier(len(microphones))

    for microphone, output_file in zip(microphones.values(), output_files, strict=False):
        start_recording_threads.append(
            Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
        )

    for thread in start_recording_threads:
        thread.start()
    for thread in start_recording_threads:
        thread.join()


def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
    """
    Stops recording on multiple microphones asynchronously to avoid delays.

    Args:
        microphones: A dictionary of microphones.
    """

    stop_recording_threads = []

    for microphone in microphones.values():
        stop_recording_threads.append(Thread(target=microphone.stop_recording))

    for thread in stop_recording_threads:
        thread.start()
    for thread in stop_recording_threads:
        thread.join()
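
A rough end-to-end sketch of these helpers (the config import path and device indices are assumptions based on the examples above):

```python
# Hypothetical usage: record two microphones in sync, then stop them together.
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig

configs = {
    "laptop_mic": PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1]),
    "usb_mic": PortAudioMicrophoneConfig(microphone_index=2, sample_rate=16000, channels=[1]),
}
microphones = make_microphones_from_configs(configs)

for microphone in microphones.values():
    microphone.connect()

# Threads plus a shared Barrier align the actual recording start instants.
async_microphones_start_recording(microphones, output_files=["laptop.wav", "usb.wav"])
...  # record for some time
async_microphones_stop_recording(microphones)

for microphone in microphones.values():
    microphone.disconnect()
```
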
@@ -89,7 +89,6 @@ class ACTConfig(PreTrainedConfig):
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "AUDIO": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
@@ -100,10 +99,6 @@ class ACTConfig(PreTrainedConfig):
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: int = False
    # Audio backbone.
    audio_backbone: str = vision_backbone
    pretrained_backbone_weights_audio: str | None = None
    replace_final_stride_with_dilation_audio: int = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
@@ -166,10 +161,8 @@ class ACTConfig(PreTrainedConfig):
        return None

    def validate_features(self) -> None:
        if not (self.image_features or self.audio_features) and not self.env_state_feature:
            raise ValueError(
                "You must provide at least one image/audio or the environment state among the inputs."
            )
        if not self.image_features and not self.env_state_feature:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

    @property
    def observation_delta_indices(self) -> None:

@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
@@ -106,8 +106,6 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
action = self.temporal_ensembler.update(actions)
|
||||
@@ -333,26 +331,12 @@ class ACT(nn.Module):
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

# Backbone for audio feature extraction.
if self.config.audio_features:
    audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
        replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
        weights=config.pretrained_backbone_weights_audio,
        norm_layer=FrozenBatchNorm2d,
    )
    # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
    # feature map).
    # Note: The forward method of this returns a dict: {"feature_map": output}.
    self.audio_backbone = IntermediateLayerGetter(
        audio_backbone_model, return_layers={"layer4": "feature_map"}
    )

# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
self.decoder = ACTDecoder(config)

# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.config.robot_state_feature:
    self.encoder_robot_state_input_proj = nn.Linear(
        self.config.robot_state_feature.shape[0], config.dim_model

@@ -366,10 +350,6 @@ class ACT(nn.Module):
        self.encoder_img_feat_input_proj = nn.Conv2d(
            backbone_model.fc.in_features, config.dim_model, kernel_size=1
        )
        if self.config.audio_features:
            self.encoder_audio_feat_input_proj = nn.Conv2d(
                audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
            )
        # Transformer encoder positional embeddings.
        n_1d_tokens = 1  # for the latent
        if self.config.robot_state_feature:

@@ -379,8 +359,6 @@ class ACT(nn.Module):
        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
        if self.config.image_features:
            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
        if self.config.audio_features:
            self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

        # Transformer decoder.
        # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).

@@ -505,21 +483,6 @@ class ACT(nn.Module):
        encoder_in_tokens.extend(list(cam_features))
        encoder_in_pos_embed.extend(list(cam_pos_embed))

        if self.config.audio_features:
            for audio in batch[OBS_AUDIO]:
                audio_features = self.audio_backbone(audio)["feature_map"]
                audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
                    dtype=audio_features.dtype
                )
                audio_features = self.encoder_audio_feat_input_proj(audio_features)

                # Rearrange features to (sequence, batch, dim).
                audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
                audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")

                encoder_in_tokens.extend(list(audio_features))
                encoder_in_pos_embed.extend(list(audio_pos_embed))

        # Stack all tokens along the sequence dimension.
        encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
        encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
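For shape intuition, a standalone sketch of the rearrange used above: the backbone's spatial grid is flattened into sequence positions, so each spectrogram (or image) cell becomes one encoder token (shapes are illustrative):

```python
import einops
import torch

feature_map = torch.randn(8, 512, 7, 7)  # (batch, channels, h, w) from a ResNet-style backbone
tokens = einops.rearrange(feature_map, "b c h w -> (h w) b c")
print(tokens.shape)  # torch.Size([49, 8, 512]): 49 tokens of dim 512 per batch element
```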
@@ -17,11 +17,9 @@ from typing import Any

import torch

from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    AudioProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,

@@ -65,15 +63,6 @@ def make_act_pre_post_processors(
            stats=dataset_stats,
            device=config.device,
        ),
        AudioProcessorStep(
            output_height=224,
            output_width=224,
            output_channels=3,
            input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
            input_sample_rate=48000,
            intermediate_sample_rate=16000,
            n_fft=1024,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
@@ -1,37 +0,0 @@
# Multitask DiT Policy

## Citation

If you use this work, please cite the following works:

```bibtex
@misc{jones2025multitaskditpolicy,
    author = {Bryson Jones},
    title  = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
    year   = {2025},
    url    = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
    note   = {Blog post}
}
```

```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
    author        = {TRI LBM Team},
    title         = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
    year          = {2025},
    eprint        = {arXiv:2507.05331},
    archivePrefix = {arXiv},
    primaryClass  = {cs.RO},
    url           = {https://arxiv.org/abs/2507.05331}
}
```

```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
    author = {Boston Dynamics and TRI Research Team},
    title  = {Large Behavior Models and Atlas Find New Footing},
    year   = {2025},
    url    = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
    note   = {Blog post}
}
```

@@ -0,0 +1 @@
../../../../docs/source/policy_multi_task_dit_README.md
@@ -1,108 +0,0 @@
# π₀ (pi0)

This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model for general robot control**.

---

## Model Overview

| Feature              | π₀                                                      | π₀.₅                                      |
| -------------------- | ------------------------------------------------------- | ----------------------------------------- |
| Time Conditioning    | Concatenates time with actions via `action_time_mlp_*`  | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS               | Not used                                                 | Used in action expert                     |
| Tokenizer Length     | 48 tokens                                                | 200 tokens                                |
| Discrete State Input | False (Uses `state_proj` layer)                          | True                                      |
| Parameter Count      | Higher (includes state embedding)                        | Lower (no state embedding)                |

---

## Relative Actions

π₀ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
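A minimal sketch of the two conversions above, assuming a hypothetical `joint_names` list aligned with the action/state dimensions (the helper names are illustrative, not the exact LeRobot API):

```python
import torch

def to_relative(action, state, joint_names, exclude=("gripper",)):
    # Joints matched by substring in `exclude` are kept absolute.
    keep_abs = torch.tensor([any(e in n for e in exclude) for n in joint_names])
    relative = action - state
    relative[..., keep_abs] = action[..., keep_abs]
    return relative

def to_absolute(relative, state, joint_names, exclude=("gripper",)):
    keep_abs = torch.tensor([any(e in n for e in exclude) for n in joint_names])
    absolute = relative + state
    absolute[..., keep_abs] = relative[..., keep_abs]  # excluded joints were never offset
    return absolute
```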
### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
    --policy.type=pi0 \
    --dataset.repo_id=your_org/your_dataset \
    --policy.use_relative_actions=true \
    --policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training

### Recomputing stats for an existing dataset

If you want to precompute relative action stats offline, use `recompute_stats` from
`lerobot.datasets.dataset_tools`:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats

dataset = LeRobotDataset("your_org/your_dataset")
dataset = recompute_stats(
    dataset,
    relative_action=True,
    relative_exclude_joints=["gripper"],
)
```

---

## Citation

If you use this work, please cite both **OpenPI** and the π₀ paper:

```bibtex
@misc{openpi2024,
    author       = {Physical Intelligence Lab},
    title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
    year         = {2024},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
    license      = {Apache-2.0}
}

@misc{black2024pi0visionlanguageactionflowmodel,
    title         = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
    author        = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
    year          = {2024},
    eprint        = {2410.24164},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG},
    url           = {https://arxiv.org/abs/2410.24164},
}
```

---

## License

This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

@@ -0,0 +1 @@
../../../../docs/source/policy_pi0_README.md
@@ -1,91 +0,0 @@
# π₀.₅ (pi05)

This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.

---

## Model Overview

| Feature              | π₀                                                      | π₀.₅                                      |
| -------------------- | ------------------------------------------------------- | ----------------------------------------- |
| Time Conditioning    | Concatenates time with actions via `action_time_mlp_*`  | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS               | Not used                                                 | Used in action expert                     |
| Tokenizer Length     | 48 tokens                                                | 200 tokens                                |
| Discrete State Input | False (Uses `state_proj` layer)                          | True                                      |
| Parameter Count      | Higher (includes state embedding)                        | Lower (no state embedding)                |

---

## Relative Actions

π₀.₅ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.

### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
    --policy.type=pi05 \
    --dataset.repo_id=your_org/your_dataset \
    --policy.use_relative_actions=true \
    --policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions; see the sketch below)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
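A minimal sketch of chunk-level relative stats, assuming a hypothetical helper over pre-sampled tensors (names and shapes are illustrative, not the exact LeRobot API):

```python
import torch

def relative_action_stats(actions: torch.Tensor, states: torch.Tensor) -> dict[str, torch.Tensor]:
    # actions: (num_samples, chunk_len, action_dim); states: (num_samples, action_dim)
    # Offsets are taken w.r.t. the robot state observed at the start of each sampled chunk.
    rel = actions - states[:, None, :]
    flat = rel.reshape(-1, rel.shape[-1])
    return {"mean": flat.mean(dim=0), "std": flat.std(dim=0)}
```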
---

## Citation

If you use this work, please cite both **OpenPI** and the π₀.₅ paper:

```bibtex
@misc{openpi2024,
    author       = {Physical Intelligence Lab},
    title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
    year         = {2024},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
    license      = {Apache-2.0}
}

@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
    title         = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
    author        = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
    year          = {2025},
    eprint        = {2504.16054},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG},
    url           = {https://arxiv.org/abs/2504.16054},
}
```

---

## License

This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

@@ -0,0 +1 @@
../../../../docs/source/policy_pi05_README.md
@@ -1,38 +0,0 @@
# Real-Time Chunking (RTC)

This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.

**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).

---

## Citation

If you use Real-Time Chunking in your work, please cite:

```bibtex
@misc{openpi2024,
    author       = {Physical Intelligence Lab},
    title        = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
    year         = {2024},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
    license      = {Apache-2.0}
}

@misc{black2025realtimeexecutionactionchunking,
    title         = {Real-Time Execution of Action Chunking Flow Policies},
    author        = {Kevin Black and Manuel Y. Galliker and Sergey Levine},
    year          = {2025},
    eprint        = {2506.07339},
    archivePrefix = {arXiv},
    primaryClass  = {cs.RO},
    url           = {https://arxiv.org/abs/2506.07339},
}
```

---

## License

This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.

@@ -0,0 +1 @@
../../../../docs/source/policy_rtc_README.md
@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from dataclasses import dataclass
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""

import draccus
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor


@dataclass(kw_only=True)
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
    sample_rate: int | None = None
    channels: list[int] | None = None

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)
__all__ = [
    "ActionInterpolator",
    "ActionQueue",
    "LatencyTracker",
    "RTCConfig",
    "RTCProcessor",
]
@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Action interpolation for smoother robot control.

Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""

from torch import Tensor


class ActionInterpolator:
    """Interpolates between consecutive actions for smoother control.

    When enabled with multiplier N, produces N actions per policy action
    by linearly interpolating between the previous and current action.

    Example with multiplier=3:
        prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]

    This effectively multiplies the control rate for smoother motion.

    Usage:
        interpolator = ActionInterpolator(multiplier=2)  # 2x control rate

        # In control loop:
        if interpolator.needs_new_action():
            new_action = queue.get()
            if new_action:
                interpolator.add(new_action.cpu())

        action = interpolator.get()
        if action:
            robot.send_action(action)
    """

    def __init__(self, multiplier: int = 1):
        """Initialize the interpolator.

        Args:
            multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
        """
        if multiplier < 1:
            raise ValueError(f"multiplier must be >= 1, got {multiplier}")
        self.multiplier = multiplier
        self._prev: Tensor | None = None
        self._buffer: list[Tensor] = []
        self._idx = 0

    @property
    def enabled(self) -> bool:
        """Whether interpolation is active (multiplier > 1)."""
        return self.multiplier > 1

    def reset(self):
        """Reset interpolation state (call between episodes)."""
        self._prev = None
        self._buffer = []
        self._idx = 0

    def needs_new_action(self) -> bool:
        """Check if a new action is needed from the queue."""
        return self._idx >= len(self._buffer)

    def add(self, action: Tensor) -> None:
        """Add a new action and compute interpolated sequence.

        Args:
            action: New action tensor from policy/queue (already on CPU).
        """
        if self.multiplier > 1 and self._prev is not None:
            self._buffer = []
            for i in range(1, self.multiplier + 1):
                t = i / self.multiplier
                interp = self._prev + t * (action - self._prev)
                self._buffer.append(interp)
        else:
            # First step: no previous action yet, so run at base FPS without interpolation.
            self._buffer = [action.clone()]
        self._prev = action.clone()
        self._idx = 0

    def get(self) -> Tensor | None:
        """Get the next interpolated action.

        Returns:
            Next action tensor, or None if buffer is exhausted.
        """
        if self._idx >= len(self._buffer):
            return None
        action = self._buffer[self._idx]
        self._idx += 1
        return action

    def get_control_interval(self, fps: float) -> float:
        """Get the control interval based on interpolation multiplier.

        Args:
            fps: Base frames per second.

        Returns:
            Control interval in seconds (divided by multiplier).
        """
        return 1.0 / (fps * self.multiplier)
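For intuition, a minimal standalone demonstration of the class above with `multiplier=3` (values chosen to make the interpolation arithmetic obvious):

```python
import torch

interp = ActionInterpolator(multiplier=3)
interp.add(torch.tensor([0.0]))  # first action: no previous state, buffer holds it unchanged
assert interp.get().item() == 0.0
interp.add(torch.tensor([3.0]))  # now interpolates from 0.0 toward 3.0
print([interp.get().item() for _ in range(3)])  # [1.0, 2.0, 3.0]
```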
@@ -79,6 +79,13 @@ class ActionQueue:
        self.last_index += 1
        return action.clone()

    def clear(self) -> None:
        """Clear queued actions and reset consumption index."""
        with self.lock:
            self.queue = None
            self.original_queue = None
            self.last_index = 0

    def qsize(self) -> int:
        """Get the number of remaining actions in the queue.

@@ -123,14 +130,26 @@ class ActionQueue:
        with self.lock:
            if self.original_queue is None:
                return None
            return self.original_queue[self.last_index :]
            return self.original_queue[self.last_index :].clone()

    def get_processed_left_over(self) -> Tensor | None:
        """Get leftover processed actions (the actions currently executed by the robot).

        Returns:
            Tensor | None: Remaining processed actions (remaining_steps, action_dim),
            or None if no processed queue exists.
        """
        with self.lock:
            if self.queue is None:
                return None
            return self.queue[self.last_index :].clone()

    def merge(
        self,
        original_actions: Tensor,
        processed_actions: Tensor,
        real_delay: int,
        action_index_before_inference: int | None = 0,
        action_index_before_inference: int | None = None,
    ):
        """Merge new actions into the queue.

@@ -145,10 +164,10 @@ class ActionQueue:
            action_index_before_inference: Index before inference started, for validation.
        """
        with self.lock:
            self._check_delays(real_delay, action_index_before_inference)
            delay = self._check_and_resolve_delays(real_delay, action_index_before_inference)

            if self.cfg.enabled:
                self._replace_actions_queue(original_actions, processed_actions, real_delay)
                self._replace_actions_queue(original_actions, processed_actions, delay)
                return

            self._append_actions_queue(original_actions, processed_actions)

@@ -164,12 +183,13 @@ class ActionQueue:
            processed_actions: Post-processed actions for robot.
            real_delay: Number of time steps to skip due to inference delay.
        """
        self.original_queue = original_actions[real_delay:].clone()
        self.queue = processed_actions[real_delay:].clone()
        clamped_delay = max(0, min(real_delay, len(original_actions), len(processed_actions)))
        self.original_queue = original_actions[clamped_delay:].clone()
        self.queue = processed_actions[clamped_delay:].clone()

        logger.debug(f"original_actions shape: {self.original_queue.shape}")
        logger.debug(f"processed_actions shape: {self.queue.shape}")
        logger.debug(f"real_delay: {real_delay}")
        logger.debug(f"real_delay: {real_delay}, clamped_delay: {clamped_delay}")

        self.last_index = 0
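The clamp above matters most when clock skew yields a negative delay; a small standalone sketch of the slicing it protects (shapes are illustrative):

```python
import torch

actions = torch.arange(10).reshape(10, 1)   # (chunk_len, action_dim)
real_delay = -2                              # clock skew produced a negative delay
print(actions[real_delay:].shape)            # torch.Size([2, 1]): would wrongly keep only the tail
clamped_delay = max(0, min(real_delay, len(actions)))
print(actions[clamped_delay:].shape)         # torch.Size([10, 1]): full chunk preserved
```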
@@ -196,7 +216,9 @@ class ActionQueue:

        self.last_index = 0

    def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
    def _check_and_resolve_delays(
        self, real_delay: int, action_index_before_inference: int | None = None
    ) -> int:
        """Validate that computed delays match expectations.

        Compares the delay computed from inference latency with the actual

@@ -205,15 +227,20 @@ class ActionQueue:
        Args:
            real_delay: Delay computed from inference latency.
            action_index_before_inference: Action index when inference started.
        """
        if action_index_before_inference is None:
            return

        indexes_diff = self.last_index - action_index_before_inference
        if indexes_diff != real_delay:
            # Let's check that the action index difference (real delay calculated from the action queue)
            # is the same as the delay calculated from inference latency
            logger.warning(
                f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
                f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
            )
        Returns:
            int: Delay to use.
        """
        effective_delay = max(0, real_delay)

        if action_index_before_inference is not None:
            indexes_diff = max(0, self.last_index - action_index_before_inference)
            if indexes_diff != real_delay:
                logger.warning(
                    "Indexes diff is not equal to real delay. indexes_diff=%d, real_delay=%d",
                    indexes_diff,
                    real_delay,
                )
            return real_delay

        return effective_delay
@@ -1,14 +0,0 @@
## Paper

https://arxiv.org/abs/2509.25358

## Citation

```bibtex
@article{chen2025sarm,
    title   = {SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
    author  = {Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
    journal = {arXiv preprint arXiv:2509.25358},
    year    = {2025}
}
```

@@ -0,0 +1 @@
../../../../docs/source/policy_sarm_README.md
@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
    This function takes a dictionary of NumPy arrays, performs necessary
    preprocessing, and prepares it for model inference. The steps include:
    1. Converting NumPy arrays to PyTorch tensors.
    2. Normalizing and permuting image data and audio data (if any).
    2. Normalizing and permuting image data (if any).
    3. Adding a batch dimension to each tensor.
    4. Moving all tensors to the specified compute device.
    5. Adding task and robot type information to the dictionary.

@@ -129,9 +129,6 @@ def prepare_observation_for_inference(
        if "image" in name:
            observation[name] = observation[name].type(torch.float32) / 255
            observation[name] = observation[name].permute(2, 0, 1).contiguous()
        elif "audio" in name:
            observation[name] = observation[name].type(torch.float32)
            observation[name] = observation[name].permute(1, 0).contiguous()
        observation[name] = observation[name].unsqueeze(0)
        observation[name] = observation[name].to(device)
@@ -23,7 +23,6 @@ from lerobot.types import (
    TransitionKey,
)

from .audio_processor import AudioProcessorStep
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
    batch_to_transition,

@@ -89,7 +88,6 @@ __all__ = [
    "ActionProcessorStep",
    "AddTeleopActionAsComplimentaryDataStep",
    "AddTeleopEventsAsInfoStep",
    "AudioProcessorStep",
    "ComplementaryDataProcessorStep",
    "batch_to_transition",
    "create_transition",
@@ -1,130 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from torch import Tensor
from torchaudio.functional import amplitude_to_DB
from torchaudio.transforms import MelSpectrogram, Resample
from torchvision.transforms import Compose, Lambda, Resize

from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
from lerobot.utils.constants import OBS_AUDIO

from .pipeline import ObservationProcessorStep, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="audio_processor")
class AudioProcessorStep(ObservationProcessorStep):
    """
    Processes audio waveform data into a mel-spectrogram image representation.

    **Audio Processing:**
    - Averages waveform data over all channels.
    - Resamples the waveform to 16kHz.
    - Converts the waveform to a mel-spectrogram.
    - Converts the mel-spectrogram to decibels.
    - Resizes the mel-spectrogram to 224×224.
    - Converts the mel-spectrogram to a channel-first, normalized tensor.

    Attributes:
        output_height: Height of the output mel-spectrogram image in pixels.
        output_width: Width of the output mel-spectrogram image in pixels.
        output_channels: Number of channels in the output image (3 for RGB-like format).
        input_audio_chunk_duration: Duration of the input audio chunk in seconds.
        input_sample_rate: Original sample rate of the input audio in Hz.

        intermediate_sample_rate: Reduced intermediate sample rate in Hz.
            Downsampling improves the temporal resolution but reduces the frequency range.
        n_fft: Size of the FFT window for spectrogram computation.
            Increasing the window size increases the frequency resolution but decreases the temporal resolution.

        hop_length: Number of samples between successive frames, computed automatically to match the output_width.
            Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
        n_mels: Number of mel filter banks, computed automatically to match the output_height.
            Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
        mel_spectrogram_transform: The complete audio processing pipeline.
    """

    output_height: int = 224
    output_width: int = 224
    output_channels: int = 3
    input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION

    input_sample_rate: int = 48000
    intermediate_sample_rate: int = 16000

    n_fft: int = 1024

    # Parameters computed from other parameters at initialization
    hop_length: int = field(init=False)
    n_mels: int = field(init=False)
    mel_spectrogram_transform: Compose = field(init=False, repr=False)

    def __post_init__(self):
        self.hop_length = int(
            self.intermediate_sample_rate * self.input_audio_chunk_duration
            - self.n_fft // self.output_width
            - 1
        )
        self.n_mels = self.output_height

        self.mel_spectrogram_transform = Compose(
            [
                Lambda(lambda x: x.mean(dim=1)),  # Average over all channels (second dimension after batch)
                Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
                MelSpectrogram(
                    sample_rate=self.intermediate_sample_rate,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    n_mels=self.n_mels,
                    power=2,  # Power spectrum
                ),
                Lambda(
                    lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
                ),  # Convert to decibels
                Resize(
                    (self.output_height, self.output_width)
                ),  # Resize spectrogram to output_height×output_width
                Lambda(
                    lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
                ),  # Duplicate across 3 channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
            ]
        )

    def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
        """
        Processes audio data contained in the provided observation.
        """
        processed_obs = observation.copy()

        # Process single audio observation
        if OBS_AUDIO in processed_obs:
            audio_data = processed_obs[OBS_AUDIO]
            if isinstance(audio_data, Tensor) and audio_data.dim() == 3:  # Batch, Channels, Samples
                processed_obs[OBS_AUDIO] = self.mel_spectrogram_transform(audio_data)

        # Process multiple audio observations
        for key, value in processed_obs.items():
            if (
                key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 3
            ):  # Batch, Channels, Samples
                processed_obs[key] = self.mel_spectrogram_transform(value)

        return processed_obs

    def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
        return self._process_observation(observation)
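For shape intuition, a minimal sketch of running the step shown above (removed in this diff) on a dummy batch; the batch size and channel count are illustrative:

```python
import torch
from lerobot.utils.constants import OBS_AUDIO

step = AudioProcessorStep()  # defaults: 48 kHz input, 16 kHz intermediate, 224×224×3 output
num_samples = int(step.input_sample_rate * step.input_audio_chunk_duration)
obs = {OBS_AUDIO: torch.randn(2, 1, num_samples)}  # (batch, channels, samples)
out = step.observation(obs)
print(out[OBS_AUDIO].shape)  # torch.Size([2, 3, 224, 224]): an image-like spectrogram batch
```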
@@ -25,7 +25,8 @@ from dataclasses import dataclass, field
from torch import Tensor

from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.types import EnvTransition, PolicyAction
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE

from .pipeline import (
    ComplementaryDataProcessorStep,

@@ -35,7 +36,6 @@ from .pipeline import (
    ProcessorStepRegistry,
    TransitionKey,
)
from lerobot.types import PolicyAction, EnvTransition


@dataclass

@@ -88,8 +88,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
    - State vectors (1D tensors).
    - Single images (3D tensors).
    - Dictionaries of multiple images (3D tensors).
    - Single audio waveforms (2D tensors).
    - Dictionaries of multiple audio waveforms (2D tensors).
    """

    def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:

@@ -119,18 +117,6 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
        for key, value in observation.items():
            if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
                observation[key] = value.unsqueeze(0)

        # Process single audio observation - add batch dim if 2D
        if OBS_AUDIO in observation:
            audio_value = observation[OBS_AUDIO]
            if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
                observation[OBS_AUDIO] = audio_value.unsqueeze(0)

        # Process multiple audio observations - add batch dim if 2D
        for key, value in observation.items():
            if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
                observation[key] = value.unsqueeze(0)

        return observation

    def transform_features(
@@ -96,9 +96,11 @@ class BiOpenArmFollower(Robot):
        left_arm_motors_ft = self.left_arm._motors_ft
        right_arm_motors_ft = self.right_arm._motors_ft

        # Right first, then left — matches the teleoperator (OpenArmMini) ordering
        # and the dataset feature names recorded during data collection.
        return {
            **{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
            **{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
            **{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
        }

    @property

@@ -150,14 +152,16 @@ class BiOpenArmFollower(Robot):
        left_cam_keys = set(self.left_arm.cameras.keys())
        right_cam_keys = set(self.right_arm.cameras.keys())

        left_obs = self.left_arm.get_observation()
        for key, value in left_obs.items():
            obs_dict[key if key in left_cam_keys else f"left_{key}"] = value

        # Right first, then left — matches the teleoperator (OpenArmMini) ordering
        # and the dataset feature names recorded during data collection.
        right_obs = self.right_arm.get_observation()
        for key, value in right_obs.items():
            obs_dict[key if key in right_cam_keys else f"right_{key}"] = value

        left_obs = self.left_arm.get_observation()
        for key, value in left_obs.items():
            obs_dict[key if key in left_cam_keys else f"left_{key}"] = value

        return obs_dict

    @check_if_not_connected

@@ -183,7 +187,7 @@ class BiOpenArmFollower(Robot):
        prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
        prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}

        return {**prefixed_sent_action_left, **prefixed_sent_action_right}
        return {**prefixed_sent_action_right, **prefixed_sent_action_left}

    @check_if_not_connected
    def disconnect(self):
@@ -23,10 +23,12 @@ from ..config import RobotConfig


@RobotConfig.register_subclass("bi_openarm_follower")
@dataclass
@dataclass(kw_only=True)
class BiOpenArmFollowerConfig(RobotConfig):
    """Configuration class for Bi OpenArm Follower robots."""

    id: str | None = "bi_openarm_follower"

    left_arm_config: OpenArmFollowerConfigBase
    right_arm_config: OpenArmFollowerConfigBase
@@ -34,13 +34,6 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
                    raise ValueError(
                        f"Specifying '{attr}' is required for the camera to be used in a robot"
                    )
        if hasattr(self, "microphones") and self.microphones:
            for _, config in self.microphones.items():
                for attr in ["sample_rate", "channels"]:
                    if getattr(config, attr) is None:
                        raise ValueError(
                            f"Specifying '{attr}' is required for the microphone to be used in a robot"
                        )

    @property
    def type(self) -> str:
@@ -15,7 +15,6 @@
from dataclasses import dataclass, field

from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig

from ..config import RobotConfig

@@ -36,8 +35,5 @@ class KochFollowerConfig(RobotConfig):
    # cameras
    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    # microphones
    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)

    # Set to `True` for backward compatibility with previous policies/dataset
    use_degrees: bool = False
@@ -19,7 +19,6 @@ import time
from functools import cached_property

from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
    DynamixelMotorsBus,

@@ -62,7 +61,6 @@ class KochFollower(Robot):
            calibration=self.calibration,
        )
        self.cameras = make_cameras_from_configs(config.cameras)
        self.microphones = make_microphones_from_configs(config.microphones)

    @property
    def _motors_ft(self) -> dict[str, type]:

@@ -74,16 +72,9 @@ class KochFollower(Robot):
            cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
        }

    @property
    def _microphones_ft(self) -> dict[str, tuple]:
        return {
            mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
            for mic in self.microphones
        }

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
        return {**self._motors_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:

@@ -91,11 +82,7 @@ class KochFollower(Robot):

    @property
    def is_connected(self) -> bool:
        return (
            self.bus.is_connected
            and all(cam.is_connected for cam in self.cameras.values())
            and all(mic.is_connected for mic in self.microphones.values())
        )
        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())

    @check_if_already_connected
    def connect(self, calibrate: bool = True) -> None:

@@ -114,9 +101,6 @@ class KochFollower(Robot):
        for cam in self.cameras.values():
            cam.connect()

        for mic in self.microphones.values():
            mic.connect()

        self.configure()
        logger.info(f"{self} connected.")

@@ -213,13 +197,6 @@ class KochFollower(Robot):
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        # Read audio frames from microphones
        for mic_key, mic in self.microphones.items():
            start = time.perf_counter()
            obs_dict[mic_key] = mic.read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")

        return obs_dict

    @check_if_not_connected

@@ -255,7 +232,5 @@ class KochFollower(Robot):
        self.bus.disconnect(self.config.disable_torque_on_disconnect)
        for cam in self.cameras.values():
            cam.disconnect()
        for mic in self.microphones.values():
            mic.disconnect()

        logger.info(f"{self} disconnected.")
@@ -16,7 +16,6 @@ from dataclasses import dataclass, field

from lerobot.cameras.configs import CameraConfig, Cv2Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.microphones import MicrophoneConfig

from ..config import RobotConfig

@@ -46,8 +45,6 @@ class LeKiwiConfig(RobotConfig):

    cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)

    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)

    # Set to `True` for backward compatibility with previous policies/dataset
    use_degrees: bool = False

@@ -95,7 +92,5 @@ class LeKiwiClientConfig(RobotConfig):

    cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)

    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)

    polling_timeout_ms: int = 15
    connect_timeout_s: int = 5
@@ -23,7 +23,6 @@ from typing import Any
import numpy as np

from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
    FeetechMotorsBus,

@@ -74,7 +73,6 @@ class LeKiwi(Robot):
        self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
        self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
        self.cameras = make_cameras_from_configs(config.cameras)
        self.microphones = make_microphones_from_configs(config.microphones)

    @property
    def _state_ft(self) -> dict[str, type]:

@@ -99,16 +97,9 @@ class LeKiwi(Robot):
            cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
        }

    @property
    def _microphones_ft(self) -> dict[str, tuple]:
        return {
            mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
            for mic in self.microphones
        }

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
        return {**self._state_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:

@@ -116,11 +107,7 @@ class LeKiwi(Robot):

    @property
    def is_connected(self) -> bool:
        return (
            self.bus.is_connected
            and all(cam.is_connected for cam in self.cameras.values())
            and all(mic.is_connected for mic in self.microphones.values())
        )
        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())

    @check_if_already_connected
    def connect(self, calibrate: bool = True) -> None:

@@ -134,9 +121,6 @@ class LeKiwi(Robot):
        for cam in self.cameras.values():
            cam.connect()

        for mic in self.microphones.values():
            mic.connect()

        self.configure()
        logger.info(f"{self} connected.")

@@ -380,13 +364,6 @@ class LeKiwi(Robot):
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        # Read audio frames from microphones
        for mic_key, mic in self.microphones.items():
            start = time.perf_counter()
            obs_dict[mic_key] = mic.read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")

        return obs_dict

    @check_if_not_connected

@@ -436,7 +413,5 @@ class LeKiwi(Robot):
        self.bus.disconnect(self.config.disable_torque_on_disconnect)
        for cam in self.cameras.values():
            cam.disconnect()
        for mic in self.microphones.values():
            mic.disconnect()

        logger.info(f"{self} disconnected.")
@@ -18,7 +18,6 @@ import base64
import json
import logging
from functools import cached_property
from time import perf_counter

import cv2
import numpy as np

@@ -59,9 +58,8 @@ class LeKiwiClient(Robot):
        self.zmq_observation_socket = None

        self.last_frames = {}

        self.last_remote_state = {}
        self.last_frame_timestamp = None
        self.last_frame_delay = 0.0

        # Define three speed levels and a current index
        self.speed_levels = [

@@ -99,13 +97,9 @@ class LeKiwiClient(Robot):
    def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
        return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}

    @cached_property
    def _microphones_ft(self) -> dict[str, tuple]:
        return {name: (cfg.sample_rate, cfg.channels) for name, cfg in self.config.microphones.items()}

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
        return {**self._state_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:

@@ -141,7 +135,6 @@ class LeKiwiClient(Robot):
        if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
            raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")

        self.last_frame_timestamp = perf_counter()
        self._is_connected = True

    def calibrate(self) -> None:

@@ -174,8 +167,6 @@ class LeKiwiClient(Robot):
            if last_msg is None:
                logging.warning("Poller indicated data, but failed to retrieve message.")

        self.last_frame_delay = perf_counter() - self.last_frame_timestamp
        self.last_frame_timestamp = perf_counter()
        return last_msg

    def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:

@@ -212,16 +203,14 @@ class LeKiwiClient(Robot):

        obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}

        # Decode images and audio data
        # Decode images
        current_frames: dict[str, np.ndarray] = {}
        for frame_name, frame_data in observation.items():
            if frame_name in self._cameras_ft:
                image = self._decode_image_from_b64(frame_data)
                if image is not None:
                    current_frames[frame_name] = image
            elif frame_name in self._microphones_ft:
                if frame_data is not None:
                    current_frames[frame_name] = frame_data
        for cam_name, image_b64 in observation.items():
            if cam_name not in self._cameras_ft:
                continue
            frame = self._decode_image_from_b64(image_b64)
            if frame is not None:
                current_frames[cam_name] = frame

        return current_frames, obs_dict

@@ -265,27 +254,17 @@ class LeKiwiClient(Robot):
        """
        Capture observations from the remote robot: current follower arm positions,
        present wheel speeds (converted to body-frame velocities: x, y, theta),
        and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
        and a camera frame. Receives over ZMQ, translate to body-frame vel
        """

        frames, obs_dict = self._get_data()

        # Loop over each configured camera and microphone
        for frame_name, frame_data in frames.items():
            if frame_data is None:
                if frame_name in self._cameras_ft:
                    logging.warning("Image frame is None")
                    image = np.zeros((640, 480, 3), dtype=np.uint8)
                    obs_dict[frame_name] = image
                elif frame_name in self._microphones_ft:
                    logging.warning("Audio frame is None")
                    obs_dict[frame_name] = np.zeros(
                        (
                            int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
                            self._microphones_ft[frame_name][1],
                        ),
                        dtype=np.float32,
                    )
        # Loop over each configured camera
        for cam_name, frame in frames.items():
            if frame is None:
                logging.warning("Frame is None")
                frame = np.zeros((640, 480, 3), dtype=np.uint8)
            obs_dict[cam_name] = frame

        return obs_dict
@@ -17,7 +17,6 @@
from dataclasses import dataclass, field

from lerobot.cameras import CameraConfig
from lerobot.microphones import MicrophoneConfig

from ..config import RobotConfig

@@ -39,9 +38,6 @@ class SOFollowerConfig:
    # cameras
    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    # microphones
    microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)

    # Set to `True` for backward compatibility with previous policies/dataset
    use_degrees: bool = True
@@ -19,7 +19,6 @@ import time
from functools import cached_property

from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.microphones.utils import make_microphones_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
    FeetechMotorsBus,

@@ -62,7 +61,6 @@ class SOFollower(Robot):
            calibration=self.calibration,
        )
        self.cameras = make_cameras_from_configs(config.cameras)
        self.microphones = make_microphones_from_configs(config.microphones)

    @property
    def _motors_ft(self) -> dict[str, type]:

@@ -74,16 +72,9 @@ class SOFollower(Robot):
            cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
        }

    @property
    def _microphones_ft(self) -> dict[str, tuple]:
        return {
            mic: (self.config.microphones[mic].sample_rate, self.config.microphones[mic].channels)
            for mic in self.microphones
        }

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
        return {**self._motors_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:

@@ -91,11 +82,7 @@ class SOFollower(Robot):

    @property
    def is_connected(self) -> bool:
        return (
            self.bus.is_connected
            and all(cam.is_connected for cam in self.cameras.values())
            and all(mic.is_connected for mic in self.microphones.values())
        )
        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())

    @check_if_already_connected
    def connect(self, calibrate: bool = True) -> None:

@@ -114,9 +101,6 @@ class SOFollower(Robot):
        for cam in self.cameras.values():
            cam.connect()

        for mic in self.microphones.values():
            mic.connect()

        self.configure()
        logger.info(f"{self} connected.")

@@ -206,13 +190,6 @@ class SOFollower(Robot):
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        # Read audio frames from microphones
        for mic_key, mic in self.microphones.items():
            start = time.perf_counter()
            obs_dict[mic_key] = mic.read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")

        return obs_dict

    @check_if_not_connected

@@ -248,8 +225,6 @@ class SOFollower(Robot):
        self.bus.disconnect(self.config.disable_torque_on_disconnect)
        for cam in self.cameras.values():
            cam.disconnect()
        for mic in self.microphones.values():
            mic.disconnect()

        logger.info(f"{self} disconnected.")
@@ -85,7 +85,7 @@ from lerobot.datasets.utils import (
    flatten_dict,
    update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging

@@ -318,12 +318,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f

    for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
        ep_size_in_mb = get_file_size_in_mb(ep_path)
        ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
        ep_duration_in_s = get_video_duration_in_s(ep_path)

        # Check if adding this episode would exceed the limit
        if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
            # Size limit would be exceeded, save current accumulation WITHOUT this episode
            concatenate_media_files(
            concatenate_video_files(
                paths_to_cat,
                new_root
                / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),

@@ -359,7 +359,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f

    # Write remaining videos if any
    if paths_to_cat:
        concatenate_media_files(
        concatenate_video_files(
            paths_to_cat,
            new_root
            / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),

@@ -402,12 +402,7 @@ def generate_episode_metadata_dict(
    if len(ep_ids_set) != 1:
        raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")

    ep_dict = {
        **ep_metadata,
        **ep_video,
        **ep_legacy_metadata,
        **flatten_dict({"stats": ep_stats}),
    }
    ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
    ep_dict["meta/episodes/chunk_index"] = 0
    ep_dict["meta/episodes/file_index"] = 0
    yield ep_dict

@@ -428,10 +423,7 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_

    ds_episodes = Dataset.from_generator(
        lambda: generate_episode_metadata_dict(
            episodes_legacy_metadata,
            episodes_metadata,
            episodes_stats,
            episodes_video_metadata,
            episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
        )
    )
    write_episodes(ds_episodes, new_root)
Some files were not shown because too many files have changed in this diff.