Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 22:59:50 +00:00)

Compare commits: 49 commits
| SHA1 |
| --- |
| 92fba37225 |
| 3e45120272 |
| f0d2b37beb |
| cbc8bfb2e6 |
| 0d1be72dc8 |
| 96b7c212c4 |
| 4303b3c930 |
| 63dca86df8 |
| 8a0cc3d664 |
| 8bb8ed4803 |
| 095856b06a |
| 563f42bdb1 |
| 8fff0fde7c |
| 04de496547 |
| baf9b50365 |
| a0fdbf037a |
| c085531b17 |
| c7c6205332 |
| 4e54be1334 |
| fde9d08281 |
| 46044fed75 |
| 975dcad918 |
| d0b58190da |
| 9a5ab8ffab |
| 7541d72130 |
| 0317a15bf1 |
| f138e5948a |
| 8fef4ddab8 |
| 18d9cb5ac4 |
| 5095ab0845 |
| dac1efd13d |
| 7fd71c83a3 |
| 0f44adbeec |
| 7dbbaa3727 |
| fcabfd32a5 |
| 544cbc5f38 |
| a0c5d19391 |
| e96339a3b4 |
| 5865170d36 |
| 2dd366436e |
| 5f15232271 |
| bc38261321 |
| aaf3707058 |
| 89bd58a9a2 |
| b22e0315b0 |
| fcbf550952 |
| af036ce57e |
| 1c388c0002 |
| 51d3822d75 |
@@ -61,6 +61,7 @@ jobs:
       MUJOCO_GL: egl
       HF_HOME: /mnt/cache/.cache/huggingface
       HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     steps:
       - uses: actions/checkout@v6
         with:
@@ -89,5 +90,10 @@ jobs:
       - name: Install lerobot with test extras
         run: uv sync --extra "test"

+      - name: Login to Hugging Face
+        run: |
+          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          uv run hf auth whoami
+
       - name: Run pytest
         run: uv run pytest tests -vv --maxfail=10
@@ -60,6 +60,7 @@ jobs:
       MUJOCO_GL: egl
       HF_HOME: /mnt/cache/.cache/huggingface
       HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     steps:
       - uses: actions/checkout@v6
         with:
@@ -87,6 +88,11 @@ jobs:
       - name: Install lerobot with all extras
         run: uv sync --extra all # TODO(Steven): Make flash-attn optional

+      - name: Login to Hugging Face
+        run: |
+          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          uv run hf auth whoami
+
       - name: Run pytest (all extras)
         run: uv run pytest tests -vv --maxfail=10

@@ -162,6 +168,7 @@ jobs:
       HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
@@ -173,6 +180,12 @@ jobs:
       shell: bash
       working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
+      - name: Fix ptxas permissions
+        run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
       - name: Run pytest on GPU
         run: pytest tests -vv --maxfail=10
       - name: Run end-to-end tests
@@ -119,6 +119,7 @@ jobs:
       HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --shm-size "16gb"
@@ -130,6 +131,10 @@ jobs:
       shell: bash
       working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
       - name: Run pytest on CPU
         run: pytest tests -vv --maxfail=10
       - name: Run end-to-end tests
@@ -146,6 +151,7 @@ jobs:
       HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
@@ -157,6 +163,10 @@ jobs:
       shell: bash
       working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
       - name: Run pytest on GPU
         run: pytest tests -vv --maxfail=10
       - name: Run end-to-end tests
@@ -174,6 +184,7 @@ jobs:
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
       CUDA_VISIBLE_DEVICES: "0,1,2,3"
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
@@ -185,6 +196,10 @@ jobs:
       shell: bash
       working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
       - name: Verify GPU availability
         run: |
           nvidia-smi
@@ -193,4 +208,3 @@
       - name: Run multi-GPU training tests
         # TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
         run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
-        timeout-minutes: 10
@@ -48,6 +48,7 @@ jobs:
       MUJOCO_GL: egl
       HF_HOME: /mnt/cache/.cache/huggingface
       HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     steps:
       - uses: actions/checkout@v6
         with:
@@ -79,7 +80,10 @@ jobs:

       - name: Install lerobot with all extras
         run: uv sync --extra all # TODO(Steven): Make flash-attn optional

+      - name: Login to Hugging Face
+        run: |
+          uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          uv run hf auth whoami
       - name: Run pytest (all extras)
         run: uv run pytest tests -vv

@@ -137,6 +141,7 @@ jobs:
       HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
       TORCH_HOME: /home/user_lerobot/.cache/torch
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
     container:
       image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
@@ -148,6 +153,10 @@ jobs:
       shell: bash
       working-directory: /lerobot
     steps:
+      - name: Login to Hugging Face
+        run: |
+          hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
+          hf auth whoami
       - name: Run pytest on GPU
         run: pytest tests -vv
       - name: Run end-to-end tests
@@ -0,0 +1,25 @@
# AI Usage Policy

The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:

- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.

Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.

## Remember the Human Maintainers

Please remember that LeRobot is maintained by a dedicated team of humans.

Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.

Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.

## AI is Welcome Here

LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!

Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.

We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+1 -1
@@ -2,7 +2,7 @@

 Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.

-Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
+Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).

 ## Ways to Contribute
@@ -1,2 +1,3 @@
 include src/lerobot/templates/lerobot_modelcard_template.md
 include src/lerobot/datasets/card_template.md
+include src/lerobot/envs/metaworld_config.json
+42 -42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
 For these reasons, we run this benchmark on four representative datasets:

 - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
-- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
-- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
-- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
+- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
+- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
+- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.

 Note: The datasets used for this benchmark need to be image datasets, not video datasets.

@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
     --output-dir outputs/video_benchmark \
     --repo-ids \
         lerobot/pusht_image \
-        aliberts/aloha_mobile_shrimp_image \
+        lerobot/aloha_mobile_shrimp_image \
     --vcodec libx264 libx265 \
     --pix-fmt yuv444p yuv420p \
     --g 2 20 None \
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
     --output-dir outputs/video_benchmark \
     --repo-ids \
         lerobot/pusht_image \
-        aliberts/aloha_mobile_shrimp_image \
-        aliberts/paris_street \
-        aliberts/kitchen \
+        lerobot/aloha_mobile_shrimp_image \
+        lerobot/paris_street \
+        lerobot/kitchen \
     --vcodec libx264 libx265 \
     --pix-fmt yuv444p yuv420p \
     --g 1 2 3 4 5 6 10 15 20 40 None \
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
     --output-dir outputs/video_benchmark \
     --repo-ids \
         lerobot/pusht_image \
-        aliberts/aloha_mobile_shrimp_image \
-        aliberts/paris_street \
-        aliberts/kitchen \
+        lerobot/aloha_mobile_shrimp_image \
+        lerobot/paris_street \
+        lerobot/kitchen \
     --vcodec libsvtav1 \
     --pix-fmt yuv420p \
     --g 1 2 3 4 5 6 10 15 20 40 None \
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read

 These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`

-| video_images_size_ratio | vcodec | pix_fmt | | | |
-| --- | --- | --- | --- | --- | --- |
-| | libx264 | | libx265 | | libsvtav1 |
-| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
-| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
-| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
-| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
-| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
+| video_images_size_ratio | vcodec | pix_fmt | | | |
+| --- | --- | --- | --- | --- | --- |
+| | libx264 | | libx265 | | libsvtav1 |
+| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
+| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
+| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
+| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
+| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |

-| video_images_load_time_ratio | vcodec | pix_fmt | | | |
-| --- | --- | --- | --- | --- | --- |
-| | libx264 | | libx265 | | libsvtav1 |
-| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
-| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
-| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
-| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
-| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
+| video_images_load_time_ratio | vcodec | pix_fmt | | | |
+| --- | --- | --- | --- | --- | --- |
+| | libx264 | | libx265 | | libsvtav1 |
+| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
+| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
+| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
+| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
+| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |

-| | | vcodec | pix_fmt | | | |
-| --- | --- | --- | --- | --- | --- | --- |
-| | | libx264 | | libx265 | | libsvtav1 |
-| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
-| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
-| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
-| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
-| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
-| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
-| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
-| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
-| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
-| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
-| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
-| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
-| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
+| | | vcodec | pix_fmt | | | |
+| --- | --- | --- | --- | --- | --- | --- |
+| | | libx264 | | libx265 | | libsvtav1 |
+| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
+| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
+| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
+| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
+| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
+| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
+| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
+| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
+| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
+| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
+| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
+| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
+| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \

 RUN uv pip install --no-cache ".[all]"

+RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
+
 # Copy the rest of the application source code
 # Make sure to have the git-LFS files for testing
 COPY --chown=user_lerobot:user_lerobot . .
@@ -29,6 +29,8 @@
       title: Using the Dataset Tools
     - local: dataset_subtask
       title: Using Subtasks in the Dataset
+    - local: streaming_video_encoding
+      title: Streaming Video Encoding
   title: "Datasets"
 - sections:
   - local: act
@@ -88,5 +88,8 @@ lerobot-record \
   --dataset.repo_id=${HF_USER}/eval_act_your_dataset \
   --dataset.num_episodes=10 \
   --dataset.single_task="Your task description" \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --policy.path=${HF_USER}/act_policy
 ```
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
     --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
     --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
     --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
-    --policy_device=mps \ # POLICY: the device to run the policy on, on the server
+    --policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
    --actions_per_chunk=50 \ # POLICY: the number of actions to output at once
    --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
    --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
 We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):

 ```bash
-huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

 Store your Hugging Face username:

 ```bash
-HF_USER=$(huggingface-cli whoami | head -n 1)
+HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
 echo $HF_USER
 ```
@@ -185,13 +185,16 @@ echo $HF_USER
 Use the standard recording command:

 ```bash
-python src/lerobot/scripts/lerobot_record.py \
+lerobot-record \
   --robot.type=earthrover_mini_plus \
   --teleop.type=keyboard_rover \
   --dataset.repo_id=your_username/dataset_name \
   --dataset.num_episodes=2 \
   --dataset.fps=10 \
   --dataset.single_task="Navigate around obstacles" \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --display_data=true
 ```
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
 pip install huggingface_hub

 # Login to Hugging Face
-huggingface-cli login
+hf auth login

 # Create a new repository
-huggingface-cli repo create my-custom-env --type space --org my-org
+hf repo create my-org/my-custom-env

 # Initialize git and push
 git init
@@ -120,9 +120,12 @@ lerobot-record \
   --display_data=true \
   --dataset.repo_id=<user>/eval_groot-bimanual \
   --dataset.num_episodes=10 \
-  --dataset.single_task="Grab and handover the red cube to the other arm"
-  --policy.path=<user>/groot-bimanual # your trained model
-  --dataset.episode_time_s=30
+  --dataset.single_task="Grab and handover the red cube to the other arm" \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
+  --policy.path=<user>/groot-bimanual \ # your trained model
+  --dataset.episode_time_s=30 \
+  --dataset.reset_time_s=10
 ```
+11 -5
@@ -224,12 +224,15 @@ lerobot-record \
   --teleop.port=/dev/tty.usbmodem1201 \
   --teleop.id=right \
   --teleop.side=right \
-  --dataset.repo_id=nepyope/hand_record_test_with_video_data \
+  --dataset.repo_id=<USER>/hand_record_test_with_video_data \
   --dataset.single_task="Hand recording test with video data" \
   --dataset.num_episodes=1 \
   --dataset.episode_time_s=5 \
   --dataset.push_to_hub=true \
   --dataset.private=true \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --display_data=true
 ```

@@ -241,7 +244,7 @@ lerobot-replay \
   --robot.port=/dev/tty.usbmodem58760432281 \
   --robot.id=right \
   --robot.side=right \
-  --dataset.repo_id=nepyope/hand_record_test_with_camera \
+  --dataset.repo_id=<USER>/hand_record_test_with_camera \
   --dataset.episode=0
 ```

@@ -249,13 +252,13 @@ lerobot-replay \

 ```bash
 lerobot-train \
-  --dataset.repo_id=nepyope/hand_record_test_with_video_data \
+  --dataset.repo_id=<USER>/hand_record_test_with_video_data \
   --policy.type=act \
   --output_dir=outputs/train/hopejr_hand \
   --job_name=hopejr \
   --policy.device=mps \
   --wandb.enable=true \
-  --policy.repo_id=nepyope/hand_test_policy
+  --policy.repo_id=<USER>/hand_test_policy
 ```

 ### Evaluate

@@ -270,8 +273,11 @@ lerobot-record \
   --robot.side=right \
   --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
   --display_data=false \
-  --dataset.repo_id=nepyope/eval_hopejr \
+  --dataset.repo_id=<USER>/eval_hopejr \
   --dataset.single_task="Evaluate hopejr hand policy" \
   --dataset.num_episodes=10 \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
 ```
@@ -159,13 +159,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
 Add your token to the CLI by running this command:

 ```bash
-huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

 Then store your Hugging Face repository name in a variable:

 ```bash
-HF_USER=$(hf auth whoami | head -n 1)
+HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
 echo $HF_USER
 ```
@@ -185,7 +185,10 @@ lerobot-record \
   --display_data=true \
   --dataset.repo_id=${HF_USER}/record-test \
   --dataset.num_episodes=5 \
-  --dataset.single_task="Grab the black cube"
+  --dataset.single_task="Grab the black cube" \
+  --dataset.streaming_encoding=true \
+  # --dataset.vcodec=auto \
+  --dataset.encoder_threads=2
 ```
 </hfoption>
 <hfoption id="API example">
@@ -324,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
 You can also push your local dataset to the Hub manually, running:

 ```bash
-huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
+hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
 ```

 #### Record function
@@ -488,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
 Once training is done, upload the latest checkpoint with:

 ```bash
-huggingface-cli upload ${HF_USER}/act_so101_test \
+hf upload ${HF_USER}/act_so101_test \
   outputs/train/act_so101_test/checkpoints/last/pretrained_model
 ```
@@ -496,7 +499,7 @@ You can also upload intermediate checkpoints with:

 ```bash
 CKPT=010000
-huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
+hf upload ${HF_USER}/act_so101_test${CKPT} \
   outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
 ```
@@ -515,6 +518,9 @@ lerobot-record \
   --display_data=false \
   --dataset.repo_id=${HF_USER}/eval_so100 \
   --dataset.single_task="Put lego brick into the transparent box" \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   # <- Teleop optional if you want to teleoperate in between episodes \
   # --teleop.type=so100_leader \
   # --teleop.port=/dev/ttyACM0 \
@@ -40,6 +40,13 @@ conda install ffmpeg -c conda-forge
 >
 > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.

+> [!NOTE]
+> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command:
+>
+> ```bash
+> conda install evdev -c conda-forge
+> ```
+
 ## Step 3: Install LeRobot 🤗

 ### From Source
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
 Add your token to the CLI by running this command:

 ```bash
-huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
 ```

 Then store your Hugging Face repository name in a variable:

 ```bash
-HF_USER=$(huggingface-cli whoami | head -n 1)
+HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
 echo $HF_USER
 ```
@@ -41,7 +41,10 @@ lerobot-record \
   --display_data=true \
   --dataset.repo_id=${HF_USER}/record-test \
   --dataset.num_episodes=5 \
-  --dataset.single_task="Grab the black cube"
+  --dataset.single_task="Grab the black cube" \
+  --dataset.streaming_encoding=true \
+  # --dataset.vcodec=auto \
+  --dataset.encoder_threads=2
 ```

 See the [recording guide](./il_robots#record-a-dataset) for more details.
@@ -66,12 +66,13 @@ Run one of the example scripts to teleoperate, record a dataset, replay a datase

 All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.

-Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
+Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`.

 - Run this example to teleoperate:

 ```bash
-python examples/phone_to_so100/teleoperate.py
+cd examples/phone_to_so100
+python teleoperate.py
 ```

 After running the example:
@@ -84,19 +85,22 @@ Additionally you can customize mapping or safety limits by editing the processor
 - Run this example to record a dataset, which saves absolute end effector observations and actions:

 ```bash
-python examples/phone_to_so100/record.py
+cd examples/phone_to_so100
+python record.py
 ```

 - Run this example to replay recorded episodes:

 ```bash
-python examples/phone_to_so100/replay.py
+cd examples/phone_to_so100
+python replay.py
 ```

 - Run this example to evaluate a pretrained policy:

 ```bash
-python examples/phone_to_so100/evaluate.py
+cd examples/phone_to_so100
+python evaluate.py
 ```

 ### Important pipeline steps and options
+1 -1
@@ -60,7 +60,7 @@ policy.type=pi0
 For training π₀, you can use the standard LeRobot training script with the appropriate configuration:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your_dataset \
   --policy.type=pi0 \
   --output_dir=./outputs/pi0_training \

@@ -56,7 +56,7 @@ policy.type=pi05
 Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your_dataset \
   --policy.type=pi05 \
   --output_dir=./outputs/pi05_training \
+10 -10
@@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr

 You have two options for the FAST tokenizer:

-1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
+1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.

 2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
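For reference, a minimal sketch of loading and round-tripping the pre-trained tokenizer, following the `AutoProcessor` interface published for the original FAST tokenizer; the exact interface of `lerobot/fast-action-tokenizer` may differ, so treat these calls as assumptions:

```python
import numpy as np
from transformers import AutoProcessor

# Assumed to follow the physical-intelligence/fast AutoProcessor interface.
tokenizer = AutoProcessor.from_pretrained(
    "lerobot/fast-action-tokenizer", trust_remote_code=True
)

# A batch of one action chunk: (batch, horizon, action_dim)
action_chunk = np.random.uniform(-1, 1, size=(1, 50, 6))
tokens = tokenizer(action_chunk)  # continuous actions -> discrete tokens
decoded = tokenizer.decode(tokens, time_horizon=50, action_dim=6)  # and back
```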
@@ -114,15 +114,15 @@ lerobot-train \

 ### Key Training Parameters

-| Parameter | Description | Default |
-| --- | --- | --- |
-| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
-| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
-| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
-| `--policy.n_action_steps` | Number of action steps to execute | `50` |
-| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
-| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
-| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
+| Parameter | Description | Default |
+| --- | --- | --- |
+| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
+| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
+| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
+| `--policy.n_action_steps` | Number of action steps to execute | `50` |
+| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
+| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
+| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |

 ## Inference
@@ -159,6 +159,9 @@ lerobot-record \
   --dataset.fps=15 \
   --dataset.push_to_hub=true \
   --dataset.private=true \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --display_data=true
 ```

@@ -198,6 +201,9 @@ lerobot-record \
   --dataset.fps=15 \
   --dataset.push_to_hub=true \
   --dataset.private=true \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   --display_data=true
 ```
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
 Train with **no annotations** - uses linear progress from 0 to 1:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your-username/your-dataset \
   --policy.type=sarm \
   --policy.annotation_mode=single_stage \
@@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \
 Train with **dense annotations only** (sparse auto-generated):

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your-username/your-dataset \
   --policy.type=sarm \
   --policy.annotation_mode=dense_only \
@@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \
 Train with **both sparse and dense annotations**:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your-username/your-dataset \
   --policy.type=sarm \
   --policy.annotation_mode=dual \
@@ -468,7 +468,7 @@ This script:
 Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your-username/your-dataset \
   --policy.type=pi0 \
   --use_rabc=true \
@@ -106,6 +106,9 @@ lerobot-record \
   --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
   --dataset.episode_time_s=50 \
   --dataset.num_episodes=10 \
+  --dataset.streaming_encoding=true \
+  --dataset.encoder_threads=2 \
+  # --dataset.vcodec=auto \
   # <- Teleop optional if you want to teleoperate in between episodes \
   # --teleop.type=so100_leader \
   # --teleop.port=/dev/ttyACM0 \
@@ -0,0 +1,155 @@
# Streaming Video Encoding Guide

## 1. Overview

Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of:

1. Capture frame -> write PNG to disk -> (at episode end) read PNGs -> encode to MP4 -> delete PNGs

frames are encoded in real time during capture:

1. Capture frame -> queue to encoder thread -> encode to MP4 directly

This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras and long episodes.
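A minimal sketch of this capture-side pattern, with hypothetical `encode_frame`/`finalize` stand-ins for the real encoder calls (the actual implementation lives in LeRobot's dataset recording code):

```python
import queue
import threading

frame_queue: queue.Queue = queue.Queue(maxsize=60)  # bounded: ~2 s of frames at 30 fps
dropped_frames = 0

def encode_frame(frame) -> None:
    """Placeholder: the real code appends one frame to the open MP4 stream."""

def finalize() -> None:
    """Placeholder: the real code flushes the encoder and closes the container."""

def encoder_worker() -> None:
    # One such thread pool exists per camera.
    while True:
        frame = frame_queue.get()
        if frame is None:  # sentinel: episode ended
            finalize()
            return
        encode_frame(frame)

threading.Thread(target=encoder_worker, daemon=True).start()

def on_new_frame(frame) -> None:
    """Called from the capture loop; never blocks, drops when the queue is full."""
    global dropped_frames
    try:
        frame_queue.put_nowait(frame)
    except queue.Full:
        dropped_frames += 1  # reported at episode end (see "Backpressure" below)
```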
## 2. Tuning Parameters

| Parameter | CLI Flag | Type | Default | Description |
| --- | --- | --- | --- | --- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects the best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` lets the vcodec decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2 s at 30 fps). Consumes RAM |
## 3. Performance Considerations

Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between:

- **Control loop** (reading cameras, controlling the robot, writing non-video data)
- **Encoder threads** (one pool per camera)
- **Rerun visualization** (if enabled)
- **OS and other processes**

### Resolution & Number of Cameras Impact

| Setup | Throughput (px/sec) | CPU Encoding Load | Notes |
| --- | --- | --- | --- |
| 2 cams x 640x480x3 @ 30fps | 55M | Low | Works on most systems |
| 2 cams x 1280x720x3 @ 30fps | 165M | Moderate | Comfortable on modern systems |
| 2 cams x 1920x1080x3 @ 30fps | 373M | High | Requires a powerful high-end CPU |
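The throughput column is simply cameras × width × height × channels × fps; a quick check of the figures above:

```python
def throughput(cams: int, w: int, h: int, c: int, fps: int) -> int:
    """Pixels per second pushed through the encoders."""
    return cams * w * h * c * fps

print(throughput(2, 640, 480, 3, 30))    # 55_296_000   (~55M px/sec)
print(throughput(2, 1280, 720, 3, 30))   # 165_888_000  (~165M px/sec)
print(throughput(2, 1920, 1080, 3, 30))  # 373_248_000  (~373M px/sec)
```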
### `encoder_threads` Tuning

This parameter controls how many threads each encoder instance uses internally:

- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
- **`None` (default)**: Lets the codec decide. This information is available in the codec logs.

### Backpressure and Frame Dropping

Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:

1. The queue fills up (consuming RAM)
2. New frames are **dropped** (not blocked), and the capture loop continues uninterrupted
3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"`
4. At episode end, the total dropped frames per camera are reported
### Symptoms of the Encoder Falling Behind

- **System feels laggy or freezes**: all CPUs are at 100%
- **Dropped frame warnings** in the log, or fewer frames/lower FPS than expected in the recorded dataset
- **Choppy robot movement**: if the CPU is severely overloaded, even the capture loop may be affected
- **Accumulated Rerun lag**: visualization falls behind real time
## 4. Hardware-Accelerated Encoding

### When to Use

Use HW encoding when:

- CPU is the bottleneck (dropped frames, choppy robot, Rerun lag)
- You have compatible hardware (GPU or dedicated encoder)
- You're recording at high throughput (high resolution or many cameras)

### Choosing a Codec

| Codec | CPU Usage | File Size | Quality | Notes |
| --- | --- | --- | --- | --- |
| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive |
| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU |
| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems |
### Available HW Encoders

| Encoder | Platform | Hardware | CLI Value |
| --- | --- | --- | --- |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |

> [!NOTE]
> To use the HW-accelerated encoders you might need to upgrade your GPU drivers.

> [!NOTE]
> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time.
## 5. Troubleshooting

| Symptom | Likely Cause | Fix |
| --- | --- | --- |
| System freezes, choppy robot movement, or Rerun visualization lag | CPU starved (100% load) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, or set `display_data=False`. If the CPU stays at 100%, it may be insufficient for your setup; consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in the dataset | Encoder can't keep up (queue overflow) | If the CPU is not at 100%: increase `encoder_threads`, increase `encoder_queue_maxsize`, or use HW encoding (`--dataset.vcodec=auto`) |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using a HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check that the `vcodec` is installed, or try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded; follow the recommendations above. If fewer are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected |

## 6. Recommended Configurations

These estimates are conservative; we recommend testing them on your setup, starting with a low load and increasing it gradually.
### High-End Systems: modern 12+ cores (24+ threads)

A throughput between ~250-500M px/sec should be comfortable on CPU. For even better results, try HW encoding if available.

```bash
# 3 cams x 1280x720x3 @ 30fps: defaults work well. Optionally increase encoder parallelism.
# 2 cams x 1920x1080x3 @ 30fps: defaults work well. Optionally increase encoder parallelism.
lerobot-record --dataset.encoder_threads=5 ...

# 3 cams x 1920x1080x3 @ 30fps: might require some tuning.
```

### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon

A throughput between ~80-300M px/sec should be possible on CPU.

```bash
# 3 cams x 640x480x3 @ 30fps: defaults work well. Optionally decrease encoder parallelism.
# 2 cams x 1280x720x3 @ 30fps: defaults work well. Optionally decrease encoder parallelism.
lerobot-record --dataset.encoder_threads=2 ...

# 2 cams x 1920x1080x3 @ 30fps: might require some tuning.
```

### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5

On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach, where encoding happens between episodes (blocking, but it doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing the codec to `h264` and using batch encoding.

```bash
# 2 cams x 640x480x3 @ 30fps: requires some tuning.

# Use H.264 and disable streaming (batch encoding between episodes)
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
```

## 7. Closing note

Performance ultimately depends on your exact setup: frames per second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful of your CPU and system capabilities, and tune `encoder_threads`, `encoder_queue_maxsize`, and `vcodec` accordingly. That said, a common practical configuration for many applications is three cameras at 640x480x3 @ 30fps, which usually runs fine with the default streaming video encoding settings on modern systems. Always verify that your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming that the row count equals FPS × CLI duration.
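A rough sketch of that last check, assuming a hypothetical episode parquet path (adapt it to your dataset layout):

```python
import pandas as pd

fps, episode_time_s = 30, 60  # the values you passed to lerobot-record
df = pd.read_parquet("data/chunk-000/episode_000000.parquet")  # hypothetical path
expected_rows = fps * episode_time_s

# A healthy recording should be at or very near 100%.
print(f"{len(df)}/{expected_rows} frames ({100 * len(df) / expected_rows:.1f}%)")
```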
@@ -216,7 +216,7 @@ lerobot-teleoperate \
 ### Record Dataset in Simulation

 ```bash
-python -m lerobot.scripts.lerobot_record \
+lerobot-record \
   --robot.type=unitree_g1 \
   --robot.is_simulation=true \
   --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -229,7 +232,10 @@
   --dataset.num_episodes=2 \
   --dataset.episode_time_s=5 \
   --dataset.reset_time_s=5 \
-  --dataset.push_to_hub=true
+  --dataset.push_to_hub=true \
+  --dataset.streaming_encoding=true \
+  # --dataset.vcodec=auto \
+  --dataset.encoder_threads=2
 ```

 Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
@@ -266,7 +269,7 @@ lerobot-teleoperate \
 ### Record Dataset on Real Robot

 ```bash
-python -m lerobot.scripts.lerobot_record \
+lerobot-record \
   --robot.type=unitree_g1 \
   --robot.is_simulation=false \
   --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -279,7 +282,10 @@
   --dataset.num_episodes=2 \
   --dataset.episode_time_s=5 \
   --dataset.reset_time_s=5 \
-  --dataset.push_to_hub=true
+  --dataset.push_to_hub=true \
+  --dataset.streaming_encoding=true \
+  # --dataset.vcodec=auto \
+  --dataset.encoder_threads=2
 ```

 **Note**: Update `server_address` to match your robot's camera server IP.
@@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets:
 4. **Add Features** - Add new features to a dataset
 5. **Remove Features** - Remove features from a dataset
 6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
+7. **Show Dataset Info** - Show a summary of dataset information, such as the number of episodes

 The core implementation is in `lerobot.datasets.dataset_tools`.
 An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -156,6 +157,30 @@ lerobot-edit-dataset \

 **Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.

+### Show dataset information
+
+Show dataset information such as the number of episodes, number of frames, file size, and so on.
+No changes are made to the dataset.
+
+```bash
+# Show dataset information without feature details
+lerobot-edit-dataset \
+  --repo_id lerobot/pusht_image \
+  --operation.type info
+
+# Show dataset information with feature details
+lerobot-edit-dataset \
+  --repo_id lerobot/pusht_image \
+  --operation.type info \
+  --operation.show_features true
+```
+
+**Parameters:**
+
+- `--operation.show_features`: Whether to also show the dataset's feature details (default: `false`)
+
 ### Push to Hub

 Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
@@ -45,7 +45,7 @@ policy.type=wall_x
 For training WallX, you can use the standard LeRobot training script with the appropriate configuration:

 ```bash
-python src/lerobot/scripts/lerobot_train.py \
+lerobot-train \
   --dataset.repo_id=your_dataset \
   --policy.type=wall_x \
   --output_dir=./outputs/wallx_training \

@@ -154,7 +154,7 @@ lerobot-train \

 ```bash
 lerobot-train \
-  --dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
+  --dataset.repo_id=<USER>/bimanual-so100-handover-cube \
   --output_dir=./outputs/xvla_bimanual \
   --job_name=xvla_so101_training \
   --policy.path="lerobot/xvla-base" \
@@ -22,7 +22,7 @@ lerobot-replay \
     --robot.type=so100_follower \
     --robot.port=/dev/tty.usbmodem58760431541 \
     --robot.id=black \
-    --dataset.repo_id=aliberts/record-test \
+    --dataset.repo_id=<USER>/record-test \
     --dataset.episode=2
 ```
 """
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
     repo_id: str
     # Episode to replay.
     episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int = 30
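A small sketch of the default documented in that comment (illustrative only; the real resolution happens inside LeRobot's dataset code, and the fallback path here is an assumption based on the cache paths used elsewhere in this changeset):

```python
import os
from pathlib import Path

def resolve_root(repo_id: str, root: str | None) -> Path:
    """With root=None, fall back to $HF_LEROBOT_HOME/repo_id."""
    if root is not None:
        return Path(root)
    default_home = Path.home() / ".cache/huggingface/lerobot"
    return Path(os.environ.get("HF_LEROBOT_HOME", default_home)) / repo_id

print(resolve_root("lerobot/pusht_image", None))
```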
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
 Usage:
     # Basic usage with smolvla policy
     uv run python examples/rtc/eval_dataset.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
-        --dataset.repo_id=helper2424/check_rtc \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
+        --dataset.repo_id=<USER>/check_rtc \
         --rtc.execution_horizon=8 \
         --device=mps \
         --rtc.max_guidance_weight=10.0 \
@@ -58,16 +58,16 @@ Usage:
         --device=cuda

     uv run python examples/rtc/eval_dataset.py \
-        --policy.path=lipsop/reuben_pi0 \
-        --dataset.repo_id=ReubenLim/so101_cube_in_cup \
+        --policy.path=<USER>/reuben_pi0 \
+        --dataset.repo_id=<USER>/so101_cube_in_cup \
         --rtc.execution_horizon=8 \
         --device=cuda

     # With torch.compile for faster inference (PyTorch 2.0+)
     # Note: CUDA graphs disabled by default due to in-place ops in denoising loop
     uv run python examples/rtc/eval_dataset.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
-        --dataset.repo_id=helper2424/check_rtc \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
+        --dataset.repo_id=<USER>/check_rtc \
         --rtc.execution_horizon=8 \
         --device=mps \
         --use_torch_compile=true \
@@ -75,8 +75,8 @@ Usage:

     # With torch.compile on CUDA (CUDA graphs disabled by default)
     uv run python examples/rtc/eval_dataset.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
-        --dataset.repo_id=helper2424/check_rtc \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
+        --dataset.repo_id=<USER>/check_rtc \
         --rtc.execution_horizon=8 \
         --device=cuda \
         --use_torch_compile=true \
@@ -84,8 +84,8 @@ Usage:

     # Enable CUDA graphs (advanced - may cause tensor aliasing errors)
     uv run python examples/rtc/eval_dataset.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
-        --dataset.repo_id=helper2424/check_rtc \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
+        --dataset.repo_id=<USER>/check_rtc \
         --use_torch_compile=true \
         --torch_compile_backend=inductor \
         --torch_compile_mode=max-autotune \

@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
 Usage:
     # Run RTC with Real robot with RTC
     uv run examples/rtc/eval_with_real_robot.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
         --policy.device=mps \
         --rtc.enabled=true \
         --rtc.execution_horizon=20 \
@@ -41,7 +41,7 @@ Usage:

     # Run RTC with Real robot without RTC
     uv run examples/rtc/eval_with_real_robot.py \
-        --policy.path=helper2424/smolvla_check_rtc_last3 \
+        --policy.path=<USER>/smolvla_check_rtc_last3 \
         --policy.device=mps \
         --rtc.enabled=false \
         --robot.type=so100_follower \
@@ -53,7 +53,7 @@ Usage:

     # Run RTC with Real robot with pi0.5 policy
     uv run examples/rtc/eval_with_real_robot.py \
-        --policy.path=helper2424/pi05_check_rtc \
+        --policy.path=<USER>/pi05_check_rtc \
         --policy.device=mps \
         --rtc.enabled=true \
         --rtc.execution_horizon=20 \
+27 -101
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"

[project]
name = "lerobot"
version = "0.4.4"
version = "0.4.5"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
@@ -59,9 +59,9 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [

# Hugging Face dependencies
"datasets>=4.0.0,<4.2.0",
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"huggingface-hub[cli]>=1.0.0,<2.0.0",
"accelerate>=1.10.0,<2.0.0",

# Core dependencies
@@ -76,9 +76,9 @@ dependencies = [
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",

"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency

"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
@@ -96,13 +96,18 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]

# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]

# Robots
openarms = ["lerobot[damiao]"]
@@ -127,17 +132,17 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]

# Policies
wallx = [
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
"lerobot[transformers-dep]",
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"lerobot[peft]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -146,13 +151,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]

# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -174,8 +179,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -212,6 +217,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"

# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]

[tool.setuptools.packages.find]
where = ["src"]

@@ -392,85 +400,3 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false

[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "peft" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]
@@ -49,23 +49,18 @@ import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so_follower,
koch_follower,
from lerobot.robots import (
RobotConfig, # noqa: F401
make_robot_from_config,
omx_follower,
so_follower,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins

from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
@@ -485,8 +480,9 @@ class RobotClient:
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))

if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
# TODO: Check whether validating robot support is still needed with the plugin system
# if cfg.robot.type not in SUPPORTED_ROBOTS:
#     raise ValueError(f"Robot {cfg.robot.type} not yet supported!")

client = RobotClient(cfg)

@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):

if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client
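With the plugin system, robot support is no longer gated by a hard-coded list; third-party robot types register themselves and are instantiated by name. A minimal sketch of the new flow, using only the names imported in the diff above and assuming a parsed `RobotClientConfig` named `cfg`:

```python
from lerobot.robots import make_robot_from_config
from lerobot.utils.import_utils import register_third_party_plugins

# Sketch only: register external plugins first, so robot types packaged
# outside the core library are resolvable when the config names them.
register_third_party_plugins()
robot = make_robot_from_config(cfg.robot)  # cfg: RobotClientConfig (assumed already parsed)
```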
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
"""
pass

def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).

This method is non-blocking and returns whatever is currently in the

@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
return frame

@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).

This method is non-blocking and returns whatever is currently in the

@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
return self.read()

@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).

This method is non-blocking and returns whatever is currently in the

@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):

# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).

This method is non-blocking and returns whatever is currently in the
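A hedged usage sketch of the tightened default: `read_latest` peeks at the freshest buffered frame, and the new default treats anything older than 500 ms as stale. How staleness surfaces depends on the camera implementation; the broad `except` below is an assumption for illustration:

```python
# Sketch of a control loop that prefers a cheap non-blocking peek but falls
# back to a blocking read when no sufficiently fresh frame is buffered.
camera.connect()
try:
    frame = camera.read_latest(max_age_ms=500)  # non-blocking "peek" at the latest frame
except Exception:
    # Exception type is implementation-specific; assumed here for illustration.
    frame = camera.read()  # blocking read guarantees a fresh frame
```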
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)

@@ -289,7 +289,9 @@ def aggregate_datasets(

logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
dst_meta.tasks = pd.DataFrame(
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
)

meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
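Naming the index matters because the tasks table round-trips through parquet, and downstream readers (see `load_tasks` later in this diff) expect the index to be called `task`. A small illustration of the pattern:

```python
import pandas as pd

tasks = ["pick the cube", "place in cup"]
df = pd.DataFrame({"task_index": range(len(tasks))}, index=pd.Index(tasks, name="task"))
df.to_parquet("/tmp/tasks.parquet")

restored = pd.read_parquet("/tmp/tasks.parquet")
assert restored.index.name == "task"  # the index name survives the parquet round-trip
```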
@@ -7,6 +7,13 @@

This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).

{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}

## Dataset Description

{{ dataset_description | default("", true) }}
@@ -89,8 +89,8 @@ def delete_episodes(
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -152,7 +152,7 @@ def split_dataset(
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Base directory for output datasets. If None, uses default location.
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.

Examples:
Split by specific episodes
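A hedged usage sketch of the two accepted `splits` forms described in the docstring above, assuming a loaded `LeRobotDataset` named `dataset`:

```python
# Split by explicit episode indices (argument names follow the docstring above).
split_dataset(dataset, splits={"train": [0, 1, 2, 3], "val": [4, 5]})

# Split by fractions; per the docstring they must sum to <= 1.0.
split_dataset(dataset, splits={"train": 0.8, "val": 0.2})
```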
@@ -243,8 +243,8 @@ def merge_datasets(

Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +288,8 @@ def modify_features(
dataset: The source LeRobotDataset.
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.

Returns:
New dataset with features modified.
@@ -390,8 +390,8 @@ def add_features(
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.

Returns:
New dataset with all features added.
@@ -427,8 +427,8 @@ def remove_feature(
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.

Returns:
New dataset with features removed.
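A hedged sketch of the `(feature_values, feature_info)` tuple these editors expect. The `feature_info` keys mirror LeRobot's usual feature schema (dtype/shape/names), which is an assumption for illustration, not a spec taken from this diff:

```python
import numpy as np

# One value per frame; feature_info layout (dtype/shape/names) is assumed
# to match the dataset's features dict convention.
reward_values = np.zeros(len(dataset), dtype=np.float32)
reward_info = {"dtype": "float32", "shape": (1,), "names": None}

new_ds = add_features(dataset, features={"reward": (reward_values, reward_info)})
```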
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[float, float]],
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.

This function decodes frames from specified time ranges and re-encodes them with
This function decodes frames from specified frame ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.

Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(

# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
time_ranges = sorted(episodes_to_keep)
frame_ranges = sorted(episodes_to_keep)

# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0

@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue

# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0

# Check if frame is in any of our desired time ranges.
# Check if frame is in any of our desired frame ranges.
# Skip ranges that have already passed.
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
range_idx += 1

# If we've passed all ranges, stop processing.
if range_idx >= len(time_ranges):
if range_idx >= len(frame_ranges):
break

# Check if frame is in current range.
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
start_frame = frame_ranges[range_idx][0]

if src_frame_count < start_frame:
src_frame_count += 1
continue

# Frame is in range - create a new frame with reset timestamps.
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)

src_frame_count += 1
frame_count += 1

# Flush encoder.
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of time ranges to keep, in sorted order.
# Build list of frame ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[float, float]] = []

episodes_to_keep_ranges: list[tuple[int, int]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))

# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
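The switch from float timestamps to half-open frame intervals makes the kept-segment length exact and robust to rounding. A worked example of the conversion shown above:

```python
fps = 30
from_ts, to_ts = 2.0, 4.0             # episode boundaries in seconds
from_frame = round(from_ts * fps)     # 60 (inclusive)
to_frame = round(to_ts * fps)         # 120 (exclusive)
assert to_frame - from_frame == 60    # exactly fps * (to_ts - from_ts) frames kept
```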
@@ -1470,7 +1475,9 @@ def modify_tasks(

# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
new_task_df = pd.DataFrame(
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}

logging.info(f"Modifying tasks in {dataset.repo_id}")
@@ -1524,7 +1531,7 @@ def modify_tasks(

def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -1543,8 +1550,8 @@ def convert_image_to_video_dataset(

Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
@@ -1595,6 +1602,7 @@ def convert_image_to_video_dataset(
# Video info will be updated after episodes are encoded

# Create new metadata for video dataset
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
@@ -68,6 +68,7 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
@@ -75,11 +76,11 @@ from lerobot.datasets.video_utils import (
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME

CODEBASE_VERSION = "v3.0"
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}


class LeRobotDatasetMetadata:
@@ -313,7 +314,7 @@ class LeRobotDatasetMetadata:
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
@@ -545,12 +546,19 @@ class LeRobotDatasetMetadata:


def _encode_video_worker(
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
video_key: str,
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
shutil.rmtree(img_dir)
return temp_path

@@ -570,6 +578,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -653,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).

Args:
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -683,12 +694,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
encoding is CPU-heavy.
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
"""
super().__init__()
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
@@ -700,7 +716,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
self.vcodec = vcodec
self.vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads

# Unused attributes
self.image_writer = None
@@ -708,6 +725,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self._streaming_encoder = None

self.root.mkdir(exist_ok=True, parents=True)

@@ -729,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -749,6 +767,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

# Initialize streaming encoder for resumed recording
if streaming_encoding and len(self.meta.video_keys) > 0:
self._streaming_encoder = StreamingVideoEncoder(
fps=self.meta.fps,
vcodec=self.vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)

def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
writer = getattr(self, "writer", None)
@@ -808,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)

card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

@@ -1104,6 +1135,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
self._close_writer()
self.meta._close_writer()
if self._streaming_encoder is not None:
self._streaming_encoder.close()

def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
@@ -1158,6 +1191,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing

# Start streaming encoder on first frame of episode (once, before iterating keys)
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self.meta.video_keys),
temp_dir=self.root,
)

# Add frame features to episode_buffer
for key in frame:
if key not in self.features:
@@ -1165,7 +1205,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)

if self.features[key]["dtype"] in ["image", "video"]:
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
self._streaming_encoder.feed_frame(key, frame[key])
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
elif self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@@ -1226,13 +1269,38 @@ class LeRobotDataset(torch.utils.data.Dataset):

# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
ep_stats = compute_episode_stats(episode_buffer, self.features)

ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self.batch_encoding_size > 1

if has_video_keys and not use_batched_encoding:
if use_streaming:
# Compute stats for non-video features only (video stats come from encoder)
non_video_buffer = {
k: v
for k, v in episode_buffer.items()
if self.features.get(k, {}).get("dtype") not in ("video",)
}
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
else:
ep_stats = compute_episode_stats(episode_buffer, self.features)

ep_metadata = self._save_episode_data(episode_buffer)

if use_streaming:
# Finish streaming encoding and collect results
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self.meta.video_keys:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
@@ -1246,6 +1314,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root,
self.fps,
self.vcodec,
self._encoder_threads,
): video_key
for video_key in self.meta.video_keys
}
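The reshape in the streaming branch converts the encoder's per-channel stats into the same layout `compute_episode_stats` produces for images. A worked example of the shape arithmetic:

```python
import numpy as np

mean_rgb = np.array([120.0, 95.0, 60.0])  # per-channel means in [0, 255] from the encoder
formatted = np.squeeze(mean_rgb.reshape(1, -1, 1, 1) / 255.0, axis=0)
assert formatted.shape == (3, 1, 1)       # (C, 1, 1), normalized to [0, 1]
```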
@@ -1514,6 +1583,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata

def clear_episode_buffer(self, delete_images: bool = True) -> None:
# Cancel streaming encoder if active
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()

# Clean up image files for the current episode buffer
if delete_images:
# Wait for the async image writer to finish
@@ -1561,7 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
return _encode_video_worker(
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
)

@classmethod
def create(
@@ -1578,10 +1653,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -1590,6 +1668,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
features=features,
root=root,
use_videos=use_videos,
metadata_buffer_size=metadata_buffer_size,
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
@@ -1599,6 +1678,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
obj.vcodec = vcodec
obj._encoder_threads = encoder_threads

if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
@@ -1620,6 +1700,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._lazy_loading = False
obj._recorded_frames = 0
obj._writer_closed_for_reading = False

# Initialize streaming encoder
if streaming_encoding and len(obj.meta.video_keys) > 0:
obj._streaming_encoder = StreamingVideoEncoder(
fps=fps,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
else:
obj._streaming_encoder = None

return obj

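A hedged recording-setup sketch combining the new options. The streaming-related keyword names follow the signature shown above; the usual `repo_id`/`fps`/`features` arguments and a predefined `features` dict are assumed:

```python
# Sketch: streaming encoding with an auto-selected hardware codec and capped
# per-encoder threads. Frames go straight to encoder threads (no PNG round-trip),
# so save_episode() returns near-instantly.
dataset = LeRobotDataset.create(
    repo_id="<USER>/my_task",
    fps=30,
    features=features,
    streaming_encoding=True,
    vcodec="auto",             # resolved to a hardware encoder when one is available
    encoder_queue_maxsize=30,  # ~1 s of buffering per camera at 30 fps
    encoder_threads=2,
)
```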
@@ -1675,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
if extra_keys:
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)

self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
@@ -122,19 +122,9 @@ def load_nested_dataset(
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")

with SuppressProgressBars():
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)

arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)

if features is not None:
table = table.cast(features.arrow_schema)

return Dataset(table)
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)


def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -351,6 +341,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:

def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks

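The unified path pushes the episode filter into the parquet read itself, so unrequested episodes are never materialized. A small illustration of the pyarrow expression it builds (the `filters=` forwarding to `Dataset.from_parquet` is taken from the diff above):

```python
import pyarrow.dataset as pa_ds

episodes = [0, 2, 5]
# Same expression the new code builds; passed as `filters=`, only rows whose
# episode_index matches are read from the parquet files.
filter_expr = pa_ds.field("episode_index").isin(episodes)
```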
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory
--root=/path/to/local/dataset/directory \
--push-to-hub=false

N.B. Path semantics (v2): --root is the exact dataset folder containing
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
```

"""
@@ -105,7 +108,7 @@ episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}

NEW
meta/episodes/chunk-000/episodes_000.parquet
meta/episodes/chunk-000/file_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
@@ -113,15 +116,16 @@ tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}

NEW
meta/tasks/chunk-000/file_000.parquet
meta/tasks.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}

NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
meta/episodes/chunk-000/file_000.parquet
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
-------------------------
UPDATE
meta/info.json
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
write_tasks(df_tasks, new_root)


@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):

image_keys = get_image_keys(root)

ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
@@ -211,9 +214,23 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):

logging.info(f"Converting data files from {len(ep_paths)} episodes")

for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)

# Check if we need to start a new file BEFORE creating metadata
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
# Write the accumulated data files
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)

# Move to next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)

# Reset for the next file
size_in_mb = 0
paths_to_cat = []

# Now create metadata with correct chunk/file indices
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
@@ -224,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
ep_idx += 1

if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue

if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)

# Reset for the next file
size_in_mb = ep_size_in_mb
paths_to_cat = [ep_path]

chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
paths_to_cat.append(ep_path)

# Write remaining data if any
if paths_to_cat:
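The restructured loop flushes before appending, so an episode's metadata always points at the file it actually lands in. A toy trace of the size-based batching implemented above:

```python
# Toy trace (sizes in MB): target file size 100, episodes of 60, 60, 30.
# ep0: 0 + 60 < 100  -> no flush; metadata -> file_000; batch=[ep0], size=60
# ep1: 60 + 60 >= 100 -> flush [ep0] to file_000, advance to file_001, reset;
#                        metadata -> file_001; batch=[ep1], size=60
# ep2: 60 + 30 < 100  -> no flush; metadata -> file_001; batch=[ep1, ep2], size=90
# final flush writes [ep1, ep2] to file_001
```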
@@ -469,7 +473,7 @@ def convert_dataset(

# Set root based on whether local dataset path is provided
use_local_dataset = False
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
if root.exists():
validate_local_dataset_version(root)
use_local_dataset = True
@@ -529,7 +533,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
@@ -553,7 +557,7 @@ if __name__ == "__main__":
"--root",
type=str,
default=None,
help="Local directory to use for downloading/writing the dataset.",
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
)
parser.add_argument(
"--push-to-hub",
@@ -13,25 +13,106 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import glob
import importlib
import logging
import queue
import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar

import av
import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image

# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]

VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)


def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}

# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)

# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)

# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"

return options

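The VideoToolbox branch converts the CRF-style knob into that encoder's 1-100 quality scale (higher means better quality, the opposite direction of CRF). A worked check of the mapping:

```python
crf = 30
quality = max(1, min(100, int(100 - crf * 2)))
assert quality == 40  # default crf=30 maps to q:v=40; lower crf -> higher quality value
```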
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
pass # nosec B110
return available


def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logging.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logging.info(f"Auto-selected video codec: {encoder}")
return encoder
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"

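A minimal sketch of how the two helpers compose at call sites, using only the names defined above:

```python
# Probe once, then resolve. "auto" returns the first available encoder in
# HW_ENCODERS preference order, or "libsvtav1" when none is present.
print(detect_available_hw_encoders())  # e.g. ["h264_videotoolbox", "hevc_videotoolbox"] on a Mac
vcodec = resolve_vcodec("auto")
assert vcodec in VALID_VIDEO_CODECS - {"auto"}
```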
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
@@ -146,16 +227,17 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)

is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)

# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -167,7 +249,11 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255

assert len(timestamps) == len(closest_frames)
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
return closest_frames


@@ -272,15 +358,16 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)

is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)

# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
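Turning these asserts into `FrameTimestampError` lets callers skip a bad sample instead of crashing, and keeps the check alive under `python -O`. A hedged sketch of a caller; the import location and the `decode_video_frames` argument order are assumptions for illustration:

```python
import logging

from lerobot.datasets.video_utils import FrameTimestampError  # assumed import location

try:
    frames = decode_video_frames(video_path, timestamps, tolerance_s, backend)
except FrameTimestampError as e:
    # Follow the advice in the error message: ignore this item during training.
    logging.warning(f"Skipping item with unsynchronized timestamps: {e}")
    frames = None
```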
@@ -309,14 +396,13 @@ def encode_video_frames(
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
if vcodec not in ["h264", "hevc", "libsvtav1"]:
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
vcodec = resolve_vcodec(vcodec)

video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -347,21 +433,22 @@ def encode_video_frames(
width, height = dummy_image.size

# Define video codec options
video_options = {}

if g is not None:
video_options["g"] = str(g)

if crf is not None:
video_options["crf"] = str(crf)
video_options = _get_codec_options(vcodec, g, crf, preset)

if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value

if vcodec == "libsvtav1":
video_options["preset"] = str(preset) if preset is not None else "12"
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)

# Set logging level
if log_level is not None:
@@ -480,6 +567,348 @@ def concatenate_video_files(
Path(tmp_concatenate_path).unlink()

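A worked example of the thread-cap plumbing above: with `fast_decode=1` and `encoder_threads=2` on libsvtav1, both settings end up in the same `svtav1-params` string.

```python
video_options = {"preset": "12", "svtav1-params": "fast-decode=1"}  # after fast_decode handling
encoder_threads = 2
lp_param = f"lp={encoder_threads}"
video_options["svtav1-params"] += f":{lp_param}"
assert video_options["svtav1-params"] == "fast-decode=1:lp=2"
```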
class _CameraEncoderThread(threading.Thread):
|
||||
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
||||
|
||||
One instance is created per camera per episode. Frames are received as numpy arrays
|
||||
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
||||
encoding), and written to disk. Stats are computed incrementally using
|
||||
RunningQuantileStats and returned via result_queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
while True:
|
||||
try:
|
||||
frame_data = self.frame_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
continue
|
||||
|
||||
if frame_data is None:
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
if output_stream is not None:
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
if container is not None:
|
||||
container.close()
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
self.result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||


class StreamingVideoEncoder:
    """Manages per-camera encoder threads for real-time video encoding during recording.

    Instead of writing frames as PNG images and then encoding to MP4 at episode end,
    this class streams frames directly to encoder threads, eliminating the
    PNG round-trip and making save_episode() near-instant.

    Uses threading instead of multiprocessing to avoid the overhead of pickling large
    numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
    so encoding runs in parallel with the main recording loop.
    """

    def __init__(
        self,
        fps: int,
        vcodec: str = "libsvtav1",
        pix_fmt: str = "yuv420p",
        g: int | None = 2,
        crf: int | None = 30,
        preset: int | None = None,
        queue_maxsize: int = 30,
        encoder_threads: int | None = None,
    ):
        self.fps = fps
        self.vcodec = resolve_vcodec(vcodec)
        self.pix_fmt = pix_fmt
        self.g = g
        self.crf = crf
        self.preset = preset
        self.queue_maxsize = queue_maxsize
        self.encoder_threads = encoder_threads

        self._frame_queues: dict[str, queue.Queue] = {}
        self._result_queues: dict[str, queue.Queue] = {}
        self._threads: dict[str, _CameraEncoderThread] = {}
        self._stop_events: dict[str, threading.Event] = {}
        self._video_paths: dict[str, Path] = {}
        self._dropped_frames: dict[str, int] = {}
        self._episode_active = False

    def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
        """Start encoder threads for a new episode.

        Args:
            video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
            temp_dir: Base directory for temporary MP4 files
        """
        if self._episode_active:
            self.cancel_episode()

        self._dropped_frames.clear()

        for video_key in video_keys:
            frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
            result_queue: queue.Queue = queue.Queue(maxsize=1)
            stop_event = threading.Event()

            temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
            video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"

            encoder_thread = _CameraEncoderThread(
                video_path=video_path,
                fps=self.fps,
                vcodec=self.vcodec,
                pix_fmt=self.pix_fmt,
                g=self.g,
                crf=self.crf,
                preset=self.preset,
                frame_queue=frame_queue,
                result_queue=result_queue,
                stop_event=stop_event,
                encoder_threads=self.encoder_threads,
            )
            encoder_thread.start()

            self._frame_queues[video_key] = frame_queue
            self._result_queues[video_key] = result_queue
            self._threads[video_key] = encoder_thread
            self._stop_events[video_key] = stop_event
            self._video_paths[video_key] = video_path

        self._episode_active = True

    def feed_frame(self, video_key: str, image: np.ndarray) -> None:
        """Feed a frame to the encoder for a specific camera.

        A copy of the image is made before enqueueing to prevent race conditions
        with camera drivers that may reuse buffers. If the encoder queue is full
        (encoder can't keep up), the frame is dropped with a warning instead of
        crashing the recording session.

        Args:
            video_key: The video feature key
            image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float

        Raises:
            RuntimeError: If the encoder thread has crashed
        """
        if not self._episode_active:
            raise RuntimeError("No active episode. Call start_episode() first.")

        thread = self._threads[video_key]
        if not thread.is_alive():
            # Check for error
            try:
                status, msg = self._result_queues[video_key].get_nowait()
                if status == "error":
                    raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
            except queue.Empty:
                pass
            raise RuntimeError(f"Encoder thread for {video_key} is not alive")

        try:
            self._frame_queues[video_key].put(image.copy(), timeout=0.1)
        except queue.Full:
            self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
            count = self._dropped_frames[video_key]
            # Log periodically to avoid spam (1st, then every 10th)
            if count == 1 or count % 10 == 0:
                logging.warning(
                    f"Encoder queue full for {video_key}, dropped {count} frame(s). "
                    f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
                )

    def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
        """Finish encoding the current episode.

        Sends sentinel values, waits for encoder threads to complete,
        and collects results.

        Returns:
            Dict mapping video_key to (mp4_path, stats_dict_or_None)
        """
        if not self._episode_active:
            raise RuntimeError("No active episode to finish.")

        results = {}

        # Report dropped frames
        for video_key, count in self._dropped_frames.items():
            if count > 0:
                logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")

        # Send sentinel to all queues
        for video_key in self._frame_queues:
            self._frame_queues[video_key].put(None)

        # Wait for all threads and collect results
        for video_key in self._threads:
            self._threads[video_key].join(timeout=120)
            if self._threads[video_key].is_alive():
                logging.error(f"Encoder thread for {video_key} did not finish in time")
                self._stop_events[video_key].set()
                self._threads[video_key].join(timeout=5)
                results[video_key] = (self._video_paths[video_key], None)
                continue

            try:
                status, data = self._result_queues[video_key].get(timeout=5)
                if status == "error":
                    raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
                results[video_key] = (self._video_paths[video_key], data)
            except queue.Empty:
                logging.error(f"No result from encoder thread for {video_key}")
                results[video_key] = (self._video_paths[video_key], None)

        self._cleanup()
        self._episode_active = False
        return results

    def cancel_episode(self) -> None:
        """Cancel the current episode, stopping encoder threads and cleaning up."""
        if not self._episode_active:
            return

        # Signal all threads to stop
        for video_key in self._stop_events:
            self._stop_events[video_key].set()

        # Wait for threads to finish
        for video_key in self._threads:
            self._threads[video_key].join(timeout=5)

            # Clean up temp MP4 files
            video_path = self._video_paths.get(video_key)
            if video_path is not None and video_path.exists():
                shutil.rmtree(str(video_path.parent), ignore_errors=True)

        self._cleanup()
        self._episode_active = False

    def close(self) -> None:
        """Close the encoder, canceling any in-progress episode."""
        if self._episode_active:
            self.cancel_episode()

    def _cleanup(self) -> None:
        """Clean up queues and thread tracking dicts."""
        for q in self._frame_queues.values():
            with contextlib.suppress(Exception):
                while not q.empty():
                    q.get_nowait()
        self._frame_queues.clear()
        self._result_queues.clear()
        self._threads.clear()
        self._stop_events.clear()
        self._video_paths.clear()
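
For orientation, a minimal usage sketch of the StreamingVideoEncoder above (not part of the diff; the camera key, fps, and zero-filled frames are hypothetical):

    import numpy as np
    from pathlib import Path

    encoder = StreamingVideoEncoder(fps=30)
    encoder.start_episode(["observation.images.laptop"], temp_dir=Path("/tmp"))
    for _ in range(90):  # ~3 s of dummy 480x640 RGB frames
        encoder.feed_frame("observation.images.laptop", np.zeros((480, 640, 3), dtype=np.uint8))
    results = encoder.finish_episode()  # {video_key: (mp4_path, stats_dict_or_None)}
    encoder.close()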

@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo

@@ -514,7 +943,7 @@ with warnings.catch_warnings():

def get_audio_info(video_path: Path | str) -> dict:
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    # Getting audio stream information
    audio_info = {}

@@ -546,7 +975,7 @@ def get_audio_info(video_path: Path | str) -> dict:

def get_video_info(video_path: Path | str) -> dict:
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    # Getting video stream information
    video_info = {}

@@ -632,8 +1061,15 @@ class VideoEncodingManager:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Handle any remaining episodes that haven't been batch encoded
        if self.dataset.episodes_since_last_encoding > 0:
        streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)

        if streaming_encoder is not None:
            # Handle streaming encoder cleanup
            if exc_type is not None:
                streaming_encoder.cancel_episode()
            streaming_encoder.close()
        elif self.dataset.episodes_since_last_encoding > 0:
            # Handle any remaining episodes that haven't been batch encoded
            if exc_type is not None:
                logging.info("Exception occurred. Encoding remaining episodes before exit...")
            else:

@@ -650,8 +1086,8 @@ class VideoEncodingManager:
            # Finalize the dataset to properly close all writers
            self.dataset.finalize()

        # Clean up episode images if recording was interrupted
        if exc_type is not None:
        # Clean up episode images if recording was interrupted (only for non-streaming mode)
        if exc_type is not None and streaming_encoder is None:
            interrupted_episode_index = self.dataset.num_episodes
            for key in self.dataset.meta.video_keys:
                img_dir = self.dataset._get_image_file_path(

@@ -665,14 +1101,12 @@ class VideoEncodingManager:

        # Clean up any remaining images directory if it's empty
        img_dir = self.dataset.root / "images"
        # Check for any remaining PNG files
        png_files = list(img_dir.rglob("*.png"))
        if len(png_files) == 0:
            # Only remove the images directory if no PNG files remain
            if img_dir.exists():
        if img_dir.exists():
            png_files = list(img_dir.rglob("*.png"))
            if len(png_files) == 0:
                shutil.rmtree(img_dir)
                logging.debug("Cleaned up empty images directory")
            else:
                logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
            else:
                logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")

        return False  # Don't suppress the original exception

@@ -0,0 +1,18 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .robstride import RobstrideMotorsBus
from .tables import *

File diff suppressed because it is too large

@@ -0,0 +1,120 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration tables for Damiao motors."""

from enum import IntEnum


# Motor type definitions
class MotorType(IntEnum):
    O0 = 0
    O1 = 1
    O2 = 2
    O3 = 3
    O4 = 4
    O5 = 5
    ELO5 = 6
    O6 = 7


class CommMode(IntEnum):
    PrivateProtocole = 0
    CANopen = 1
    MIT = 2


# Control modes
class ControlMode(IntEnum):
    MIT = 0
    POS_VEL = 1
    VEL = 2


# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = {
    MotorType.O0: (12.57, 33, 14),
    MotorType.O1: (12.57, 44, 17),
    MotorType.O2: (12.57, 33, 20),
    MotorType.O3: (12.57, 33, 60),
    MotorType.O4: (12.57, 33, 120),
    MotorType.O5: (12.57, 50, 5.5),
    MotorType.ELO5: (12.57, 50, 6),
    MotorType.O6: (112.5, 50, 36),
}

# Motor model names
MODEL_NAMES = {
    MotorType.O0: "O0",
    MotorType.O1: "O1",
    MotorType.O2: "O2",
    MotorType.O3: "O3",
    MotorType.O4: "O4",
    MotorType.O5: "O5",
    MotorType.ELO5: "ELO5",
    MotorType.O6: "O6",
}

# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
    "O0": 65536,
    "O1": 65536,
    "O2": 65536,
    "O3": 65536,
    "O4": 65536,
    "O5": 65536,
    "ELO5": 65536,
    "O6": 65536,
}

# CAN baudrates supported by Robstride motors
AVAILABLE_BAUDRATES = [
    1000000,  # 4: 1 mbps (default)
]
DEFAULT_BAUDRATE = 1000000

# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 0  # disabled by default, otherwise 20000 is 1s


# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]


# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)

# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_CLEAR_FAULT = 0xFB


CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA

# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF


RUNNING_TIMEOUT = 0.001
PARAM_TIMEOUT = 0.01

STATE_CACHE_TTL_S = 0.02
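
A small sketch of how these tables compose (not part of the diff; the clamp helper is hypothetical): an MIT-mode command is bounded by the [PMAX, VMAX, TMAX] tuple for the motor type.

    def clamp_mit_command(motor: MotorType, pos: float, vel: float, torque: float) -> tuple[float, float, float]:
        """Clamp a requested (pos, vel, torque) to the motor's [PMAX, VMAX, TMAX] limits."""
        p_max, v_max, t_max = MOTOR_LIMIT_PARAMS[motor]

        def clip(value: float, limit: float) -> float:
            return max(-limit, min(limit, value))

        return clip(pos, p_max), clip(vel, v_max), clip(torque, t_max)

    print(clamp_mit_command(MotorType.O4, 15.0, 10.0, 200.0))  # (12.57, 10.0, 120)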

@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
        normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
            a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
        resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
            backbone. If None, no resizing is done and the original image resolution is used.
        crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
            (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
            Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
        crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
            this is computed automatically. Can also be set directly for legacy configs that use
            crop-only (without resize). If None and no derivation applies, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center
            crop in eval mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.

@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    resize_shape: tuple[int, int] | None = None
    crop_ratio: float = 1.0
    crop_shape: tuple[int, int] | None = None
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True

@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
    # Inference
    num_inference_steps: int | None = None

    # Optimization
    compile_model: bool = False
    compile_mode: str = "reduce-overhead"

    # Loss computation
    do_mask_loss_for_padding: bool = False

@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
                f"Got {self.noise_scheduler_type}."
            )

        if self.resize_shape is not None and (
            len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
        ):
            raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
        if not (0 < self.crop_ratio <= 1.0):
            raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")

        if self.resize_shape is not None:
            if self.crop_ratio < 1.0:
                self.crop_shape = (
                    int(self.resize_shape[0] * self.crop_ratio),
                    int(self.resize_shape[1] * self.crop_ratio),
                )
            else:
                # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
                self.crop_shape = None
        if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
            raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")

        # Check that the horizon size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
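
As a concrete check of the crop derivation above (not part of the diff; the shapes are made up):

    resize_shape = (96, 128)  # hypothetical (H, W)
    crop_ratio = 0.875
    crop_shape = (int(resize_shape[0] * crop_ratio), int(resize_shape[1] * crop_ratio))
    print(crop_shape)  # (84, 112); crop_ratio=1.0 would disable cropping instead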

@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
        if self.resize_shape is None and self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for `{key}`."
                    )

        # Check that all input images have the same shape.

@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
        """Run the batch through the model and compute the loss for training or validation."""
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            for key in self.config.image_features:
                if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
                    batch[key] = batch[key].unsqueeze(1)
            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        loss = self.diffusion.compute_loss(batch)
        # no output_dict so returning None
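
A tiny sketch of the shape handling in that hunk (not part of the diff; batch size and camera count are made up): with n_obs_steps == 1, a (B, C, H, W) image is promoted to (B, 1, C, H, W), and the per-camera tensors are then stacked on dim=-4.

    import torch

    img = torch.rand(8, 3, 84, 84).unsqueeze(1)  # (B, 1, C, H, W)
    stacked = torch.stack([img, img], dim=-4)    # two hypothetical cameras
    print(stacked.shape)                         # torch.Size([8, 1, 2, 3, 84, 84])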

@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):

        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)

        if config.compile_model:
            # Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
            # common in diffusion inference.
            self.unet = torch.compile(self.unet, mode=config.compile_mode)

        self.noise_scheduler = _make_noise_scheduler(
            config.noise_scheduler_type,
            num_train_timesteps=config.num_train_timesteps,

@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
    def __init__(self, config: DiffusionConfig):
        super().__init__()
        # Set up optional preprocessing.
        if config.crop_shape is not None:
        if config.resize_shape is not None:
            self.resize = torchvision.transforms.Resize(config.resize_shape)
        else:
            self.resize = None

        crop_shape = config.crop_shape
        if crop_shape is not None:
            self.do_crop = True
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:

@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.image_features` and it should
        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
        # height and width from `config.image_features`.
        # The dummy shape mirrors the runtime preprocessing order: resize -> crop.

        # Note: we have a check in the config class to make sure all images have the same shape.
        images_shape = next(iter(config.image_features.values())).shape
        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
        if config.crop_shape is not None:
            dummy_shape_h_w = config.crop_shape
        elif config.resize_shape is not None:
            dummy_shape_h_w = config.resize_shape
        else:
            dummy_shape_h_w = images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
        Returns:
            (B, D) image feature.
        """
        # Preprocess: maybe crop (if it was set up in the __init__).
        # Preprocess: resize if configured, then crop if configured.

        if self.resize is not None:
            x = self.resize(x)
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
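
The runtime order is resize first, then crop, mirroring the dummy-shape logic above; a minimal sketch (not part of the diff; shapes are made up):

    import torch
    import torchvision

    resize = torchvision.transforms.Resize((96, 128))
    crop = torchvision.transforms.CenterCrop((84, 112))
    x = torch.rand(1, 3, 480, 640)
    print(crop(resize(x)).shape)  # torch.Size([1, 3, 84, 112])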

@@ -14,7 +14,7 @@ from transformers.image_processing_utils import (
)
from transformers.image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    ImagesKwargs,
    group_images_by_shape,
    reorder_images,
)

@@ -77,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
    return img[:, top:bottom, left:right]


class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
    max_dynamic_tiles: int | None
    min_dynamic_tiles: int | None
    use_thumbnail: bool | None

@@ -15,6 +15,7 @@
# limitations under the License.

import builtins
import copy
import logging
import math
from collections import deque

@@ -32,13 +33,21 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
    from transformers.models.auto import CONFIG_MAPPING
    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration

    from lerobot.policies.pi_gemma import (
        PaliGemmaForConditionalGenerationWithPiGemma,
        PiGemmaForCausalLM,
        _gated_residual,
        layernorm_forward,
    )
else:
    CONFIG_MAPPING = None
    modeling_gemma = None
    GemmaForCausalLM = None
    PaliGemmaForConditionalGeneration = None
    PiGemmaForCausalLM = None
    _gated_residual = None
    layernorm_forward = None
    PaliGemmaForConditionalGenerationWithPiGemma = None


from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config

@@ -191,7 +200,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
        resized_images = resized_images.clamp(0.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

@@ -202,7 +211,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    pad_w1 = pad_w0 + remainder_w

    # Pad
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    constant_value = 0 if images.dtype == torch.uint8 else 0.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom

@@ -221,14 +230,14 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    models = [paligemma.language_model, gemma_expert.model]
    models = [paligemma.model.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
        hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

@@ -254,10 +263,10 @@ def compute_layer_complete(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )
    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
    scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        paligemma.model.language_model.layers[layer_idx].self_attn,
        query_states,
        key_states,
        value_states,

@@ -265,7 +274,7 @@ def compute_layer_complete(
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
    att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
    # Process layer outputs
    outputs_embeds = []

@@ -277,15 +286,15 @@ def compute_layer_complete(
        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # first residual
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        out_emb = _gated_residual(hidden_states, out_emb, gates[i])
        after_first_residual = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        out_emb = _gated_residual(after_first_residual, out_emb, gate)
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds

@@ -358,7 +367,7 @@ class PaliGemmaWithExpertModel(
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None

@@ -366,7 +375,7 @@ class PaliGemmaWithExpertModel(
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"
        vlm_config_hf.vision_config.dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,

@@ -377,13 +386,13 @@ class PaliGemmaWithExpertModel(
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
        self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)

@@ -398,10 +407,11 @@ class PaliGemmaWithExpertModel(
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Keep full vision path in float32 so we never toggle (toggle causes optimizer
        # "same dtype" error). Align with PI05.
        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "vision_tower",
            "multi_modal_projector",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",

@@ -413,8 +423,8 @@ class PaliGemmaWithExpertModel(

    def _set_requires_grad(self):
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for param in self.paligemma.vision_tower.parameters():
            self.paligemma.model.vision_tower.eval()
            for param in self.paligemma.model.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            self.paligemma.eval()

@@ -424,15 +434,23 @@ class PaliGemmaWithExpertModel(
    def train(self, mode: bool = True):
        super().train(mode)
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            self.paligemma.model.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)
        # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
        out_dtype = image.dtype
        if image.dtype != torch.float32:
            image = image.to(torch.float32)
        image_outputs = self.paligemma.model.get_image_features(image)
        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
        if features.dtype != out_dtype:
            features = features.to(out_dtype)
        return features

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.embed_tokens(tokens)
        return self.paligemma.model.language_model.embed_tokens(tokens)

    def forward(
        self,

@@ -446,7 +464,7 @@ class PaliGemmaWithExpertModel(
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            prefix_output = self.paligemma.language_model.forward(
            prefix_output = self.paligemma.model.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,

@@ -470,7 +488,7 @@ class PaliGemmaWithExpertModel(
            prefix_output = None
            prefix_past_key_values = None
        else:
            models = [self.paligemma.language_model, self.gemma_expert.model]
            models = [self.paligemma.model.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models

@@ -510,7 +528,7 @@ class PaliGemmaWithExpertModel(
        def compute_final_norms(inputs_embeds, adarms_cond):
            outputs_embeds = []
            for i, hidden_states in enumerate(inputs_embeds):
                out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
                outputs_embeds.append(out_emb)
            return outputs_embeds

@@ -576,29 +594,19 @@ class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
            # Also compile the main forward pass used during training
            self.forward = torch.compile(self.forward, mode=config.compile_mode)

        msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""

        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
        logging.info("Enabled gradient checkpointing for PI0Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
        logging.info("Disabled gradient checkpointing for PI0Pytorch model")

@@ -760,7 +768,7 @@ class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)

        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)

@@ -834,7 +842,7 @@ class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
        self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,

@@ -908,6 +916,7 @@ class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

        past_key_values = copy.deepcopy(past_key_values)
        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,

@@ -997,14 +1006,12 @@ class PI0Policy(PreTrainedPolicy):
        # Check if dataset_stats were provided in kwargs
        model = cls(config, **kwargs)

        # Now manually load and remap the state dict
        # Load state dict (expects keys with "model." prefix)
        try:
            # Try to load the pytorch_model.bin or model.safetensors file
            print(f"Loading model from: {pretrained_name_or_path}")
            try:
                from transformers.utils import cached_file

                # Try safetensors first
                resolved_file = cached_file(
                    pretrained_name_or_path,
                    "model.safetensors",

@@ -1012,7 +1019,7 @@ class PI0Policy(PreTrainedPolicy):
                    force_download=kwargs.get("force_download", False),
                    resume_download=kwargs.get("resume_download"),
                    proxies=kwargs.get("proxies"),
                    use_auth_token=kwargs.get("use_auth_token"),
                    token=kwargs.get("token"),
                    revision=kwargs.get("revision"),
                    local_files_only=kwargs.get("local_files_only", False),
                )

@@ -1025,7 +1032,7 @@ class PI0Policy(PreTrainedPolicy):
                print("Returning model without loading pretrained weights")
                return model

            # First, fix any key differences  # see openpi `model.py, _fix_pytorch_state_dict_keys`
            # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
            fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

            # Then add "model." prefix for all keys that don't already have it

@@ -1070,7 +1077,7 @@ class PI0Policy(PreTrainedPolicy):
            print("All keys loaded successfully!")

        except Exception as e:
            print(f"Warning: Could not remap state dict keys: {e}")
            print(f"Warning: Could not load state dict: {e}")

        return model

@@ -1120,6 +1127,14 @@ class PI0Policy(PreTrainedPolicy):
            # Some checkpoints might have this, but current model expects different structure
            logging.warning(f"Vision embedding key might need handling: {key}")

            if (
                key == "model.paligemma_with_expert.paligemma.lm_head.weight"
                or key == "paligemma_with_expert.paligemma.lm_head.weight"
            ):
                fixed_state_dict[
                    "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
                ] = value.clone()

            fixed_state_dict[new_key] = value

        return fixed_state_dict
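
A minimal sketch of the prefix-remapping step used by this load path (not part of the diff; the key is illustrative): every checkpoint key that lacks the "model." prefix gets one before load_state_dict.

    original = {"paligemma_with_expert.gemma_expert.lm_head.weight": 0}
    remapped = {k if k.startswith("model.") else f"model.{k}": v for k, v in original.items()}
    print(list(remapped))  # ['model.paligemma_with_expert.gemma_expert.lm_head.weight']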

@@ -15,6 +15,7 @@
# limitations under the License.

import builtins
import copy
import logging
import math
from collections import deque

@@ -32,14 +33,20 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
    from transformers.models.auto import CONFIG_MAPPING
    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration

    from lerobot.policies.pi_gemma import (
        PaliGemmaForConditionalGenerationWithPiGemma,
        PiGemmaForCausalLM,
        _gated_residual,
        layernorm_forward,
    )
else:
    CONFIG_MAPPING = None
    modeling_gemma = None
    GemmaForCausalLM = None
    PaliGemmaForConditionalGeneration = None

    PiGemmaForCausalLM = None
    _gated_residual = None
    layernorm_forward = None
    PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T

@@ -92,10 +99,11 @@ def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedd


def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta` (exact copy)
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    # Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
    alpha_t = torch.tensor(alpha, dtype=torch.float32)
    beta_t = torch.tensor(beta, dtype=torch.float32)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))
    return dist.sample((bsize,)).to(device)


def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks` (exact copy)
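
The CPU detour exists because, per the comment above, torch.distributions.Beta samples via _sample_dirichlet, which has no MPS kernel; a quick sketch (not part of the diff; the device name is a placeholder):

    import torch

    dist = torch.distributions.Beta(torch.tensor(1.5), torch.tensor(1.0))
    samples = dist.sample((8,))  # drawn on CPU
    samples = samples.to("cpu")  # move to the target device afterwards, e.g. "mps"
    print(samples.shape)         # torch.Size([8])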

@@ -189,7 +197,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
        resized_images = resized_images.clamp(0.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

@@ -200,7 +208,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
    pad_w1 = pad_w0 + remainder_w

    # Pad
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    constant_value = 0 if images.dtype == torch.uint8 else 0.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom

@@ -219,14 +227,14 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    models = [paligemma.language_model, gemma_expert.model]
    models = [paligemma.model.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
        hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

@@ -252,10 +260,10 @@ def compute_layer_complete(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )
    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
    scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        paligemma.model.language_model.layers[layer_idx].self_attn,
        query_states,
        key_states,
        value_states,

@@ -263,7 +271,7 @@ def compute_layer_complete(
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
    head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
    att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
    # Process layer outputs
    outputs_embeds = []

@@ -275,15 +283,15 @@ def compute_layer_complete(
        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
        # first residual
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        out_emb = _gated_residual(hidden_states, out_emb, gates[i])
        after_first_residual = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        out_emb = _gated_residual(after_first_residual, out_emb, gate)
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds

@@ -356,7 +364,7 @@ class PaliGemmaWithExpertModel(
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None

@@ -364,7 +372,7 @@ class PaliGemmaWithExpertModel(
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"
        vlm_config_hf.vision_config.dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,

@@ -375,13 +383,13 @@ class PaliGemmaWithExpertModel(
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
        self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)

@@ -396,10 +404,11 @@ class PaliGemmaWithExpertModel(
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Keep full vision path in float32 so we never toggle (toggle causes optimizer
        # "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "vision_tower",
            "multi_modal_projector",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",

@@ -411,8 +420,8 @@ class PaliGemmaWithExpertModel(

    def _set_requires_grad(self):
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for param in self.paligemma.vision_tower.parameters():
            self.paligemma.model.vision_tower.eval()
            for param in self.paligemma.model.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            self.paligemma.eval()

@@ -422,15 +431,23 @@ class PaliGemmaWithExpertModel(
    def train(self, mode: bool = True):
        super().train(mode)
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            self.paligemma.model.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()

    def embed_image(self, image: torch.Tensor):
        return self.paligemma.model.get_image_features(image)
        # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
        out_dtype = image.dtype
        if image.dtype != torch.float32:
            image = image.to(torch.float32)
        image_outputs = self.paligemma.model.get_image_features(image)
        features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
        if features.dtype != out_dtype:
            features = features.to(out_dtype)
        return features

    def embed_language_tokens(self, tokens: torch.Tensor):
        return self.paligemma.language_model.embed_tokens(tokens)
        return self.paligemma.model.language_model.embed_tokens(tokens)

    def forward(
        self,

@@ -444,7 +461,7 @@ class PaliGemmaWithExpertModel(
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            prefix_output = self.paligemma.language_model.forward(
            prefix_output = self.paligemma.model.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,

@@ -468,7 +485,7 @@ class PaliGemmaWithExpertModel(
            prefix_output = None
            prefix_past_key_values = None
        else:
            models = [self.paligemma.language_model, self.gemma_expert.model]
            models = [self.paligemma.model.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models

@@ -508,7 +525,7 @@ class PaliGemmaWithExpertModel(
        def compute_final_norms(inputs_embeds, adarms_cond):
            outputs_embeds = []
            for i, hidden_states in enumerate(inputs_embeds):
                out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
                outputs_embeds.append(out_emb)
            return outputs_embeds

@@ -573,29 +590,19 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
            # Also compile the main forward pass used during training
            self.forward = torch.compile(self.forward, mode=config.compile_mode)

        msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""

        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
        logging.info("Enabled gradient checkpointing for PI05Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
        logging.info("Disabled gradient checkpointing for PI05Pytorch model")

@@ -737,7 +744,7 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)

        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)

@@ -808,7 +815,7 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
        self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,

@@ -880,6 +887,7 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

        past_key_values = copy.deepcopy(past_key_values)
        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
@@ -969,14 +977,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
# Load state dict (expects keys with "model." prefix)
|
||||
try:
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_name_or_path}")
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_name_or_path,
|
||||
"model.safetensors",
|
||||
@@ -984,7 +990,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
token=kwargs.get("token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
@@ -997,7 +1003,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print("Returning model without loading pretrained weights")
|
||||
return model
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
@@ -1009,8 +1015,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
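The remap the hunk above trims the per-key logging from is a simple prefixing pass: every key that does not already start with `"model."` gets the prefix. A hedged, self-contained sketch (the dictionary contents below are illustrative, not real checkpoint keys):

```python
# Illustrative state dict; real keys come from the loaded safetensors file.
state_dict = {"paligemma_with_expert.w": 1, "model.action_head.w": 2}
remapped = {
    k if k.startswith("model.") else f"model.{k}": v  # add prefix only when missing
    for k, v in state_dict.items()
}
assert set(remapped) == {"model.paligemma_with_expert.w", "model.action_head.w"}
```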

@@ -1044,7 +1048,7 @@ class PI05Policy(PreTrainedPolicy):
print("All keys loaded successfully!")

except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")

return model

@@ -1098,6 +1102,14 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")

if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()

fixed_state_dict[new_key] = value

return fixed_state_dict

@@ -23,7 +23,6 @@ import torch

from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# TODO: check if this is necessary
state = deepcopy(state)

# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)

# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
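The comment above refers to the openpi-style binning: a state already normalized to [-1, 1] is mapped to one of 256 integer bins before tokenization. A hedged sketch of that step; the exact bin-edge construction is an assumption consistent with the comment, not the verbatim lerobot implementation:

```python
import numpy as np

state_np = np.random.uniform(-1.0, 1.0, size=(32,))  # stand-in for the padded state
state_np = np.clip(state_np, -1.0, 1.0)              # guard against values outside [-1, 1]
bins = np.linspace(-1.0, 1.0, 257)[1:-1]             # 255 inner edges -> bin ids 0..255
discretized = np.digitize(state_np, bins)
assert discretized.min() >= 0 and discretized.max() <= 255
```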

@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):

tokenizer_max_length: int = 200  # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "physical-intelligence/fast"
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128

@@ -38,11 +38,16 @@ else:
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration

from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
else:
CONFIG_MAPPING = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None

from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -121,7 +126,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")

@@ -132,7 +137,7 @@ def resize_with_pad_torch(  # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w

# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
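The padding math this hunk touches splits the height/width deficit across both sides, with the odd remainder going to the second side (that is what `pad_w1 = pad_w0 + remainder_w` above encodes). A self-contained sketch under those assumptions; the function name and target-size parameters are illustrative:

```python
import torch
import torch.nn.functional as F

def pad_to(images: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    h, w = images.shape[-2:]
    pad_h0, remainder_h = divmod(target_h - h, 2)
    pad_w0, remainder_w = divmod(target_w - w, 2)
    pad_h1 = pad_h0 + remainder_h  # odd remainder goes to the bottom
    pad_w1 = pad_w0 + remainder_w  # odd remainder goes to the right
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(images, (pad_w0, pad_w1, pad_h0, pad_h1), value=0.0)
```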
@@ -206,16 +211,22 @@ class PI0FastPaliGemma(nn.Module):
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"

self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)

# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
self.paligemma.model.language_model = PiGemmaModel(text_config)

self.to_bfloat16_for_selected_params(precision)

@@ -228,10 +239,11 @@ class PI0FastPaliGemma(nn.Module):
else:
raise ValueError(f"Invalid precision: {precision}")

# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -242,10 +254,18 @@ class PI0FastPaliGemma(nn.Module):
param.data = param.data.to(dtype=torch.float32)
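The list above feeds a selective-precision pass: parameters whose names match any entry stay in float32, the rest are cast to bfloat16. A minimal sketch of that loop, assuming substring matching against `named_parameters()` as the `param.data` line in this hunk suggests; the function name is illustrative:

```python
import torch
from torch import nn

def cast_selected_params(model: nn.Module, keep_float32: list[str]) -> None:
    # Substring match against parameter names, mirroring the hunk above.
    for name, param in model.named_parameters():
        keep = any(key in name for key in keep_float32)
        param.data = param.data.to(dtype=torch.float32 if keep else torch.bfloat16)
```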

def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features

def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)

def forward(
self,
@@ -259,7 +279,7 @@ class PI0FastPaliGemma(nn.Module):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -306,24 +326,14 @@ class PI0FastPytorch(nn.Module):  # see openpi `PI0Pytorch`
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)

msg = """An incorrect transformers version is installed; please open an issue at https://github.com/huggingface/lerobot/issues"""

try:
from transformers.models.siglip import check

if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None

def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
@@ -332,8 +342,8 @@ class PI0FastPytorch(nn.Module):  # see openpi `PI0Pytorch`
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
# Call the proper gradient_checkpointing_disable() method
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")

def _apply_checkpoint(self, func, *args, **kwargs):
@@ -523,7 +533,7 @@ class PI0FastPytorch(nn.Module):  # see openpi `PI0Pytorch`

# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -616,7 +626,7 @@ class PI0FastPytorch(nn.Module):  # see openpi `PI0Pytorch`
)

if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -714,7 +724,7 @@ class PI0FastPytorch(nn.Module):  # see openpi `PI0Pytorch`

# Ensure correct precision (bfloat16/float32)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -897,14 +907,12 @@ class PI0FastPolicy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)

# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file

# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -912,7 +920,7 @@ class PI0FastPolicy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -925,8 +933,9 @@ class PI0FastPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model

# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -936,8 +945,6 @@ class PI0FastPolicy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10:  # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value

@@ -971,7 +978,7 @@ class PI0FastPolicy(PreTrainedPolicy):
print("All keys loaded successfully!")

except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")

return model


@@ -23,7 +23,6 @@ import torch

from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
# TODO: check if this is necessary
state = deepcopy(state)

# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)

# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()

@@ -0,0 +1,363 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import nn

from lerobot.utils.import_utils import _transformers_available

if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
)
else:
GemmaAttention = None
GemmaConfig = None
GemmaForCausalLM = None
GemmaMLP = None
GemmaModel = None
PaliGemmaModel = None
PaliGemmaForConditionalGeneration = None
DynamicCache = None
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None


def _gated_residual(
x: torch.Tensor | None,
y: torch.Tensor | None,
gate: torch.Tensor | None,
) -> torch.Tensor | None:
"""Gated residual: x + y when gate is None, else x + y * gate."""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
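A tiny usage check for the gated residual just defined (values are illustrative; this assumes the `_gated_residual` helper above is in scope):

```python
import torch

x = torch.ones(2)
y = torch.full((2,), 2.0)
gate = torch.tensor([0.5, 0.5])
assert torch.equal(_gated_residual(x, y, None), torch.tensor([3.0, 3.0]))  # plain residual
assert torch.equal(_gated_residual(x, y, gate), torch.tensor([2.0, 2.0]))  # x + y * gate
```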


def layernorm_forward(
layernorm: nn.Module,
x: torch.Tensor,
cond: torch.Tensor | None = None,
):
"""
Call the layernorm and return (hidden_states, gate).
If cond is not None, use the conditional (AdaRMS) norm;
otherwise, use the standard Gemma norm.
"""
if cond is not None:
return layernorm(x, cond=cond)
else:
return layernorm(x)


class PiGemmaRMSNorm(nn.Module):
"""
Adaptive RMSNorm for PI Gemma (AdaRMS).
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
"""

def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
if cond_dim is not None:
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim))
self.dense = None

def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs

def forward(
self,
x: torch.Tensor,
cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
dtype = x.dtype
normed = self._norm(x)
if cond is None or self.dense is None:
normed = normed * (1.0 + self.weight.float())
return normed.type_as(x), None
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()
return normed.to(dtype), gate.to(dtype)

def extra_repr(self) -> str:
if self.dense is not None:
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
return f"dim={self.dim}, eps={self.eps}"


def _get_pi_gemma_decoder_layer_base():
"""Build the base class for PiGemmaDecoderLayer (defined lazily because GradientCheckpointingLayer is only available when transformers is installed)."""

class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""

def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = (
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
)
self.input_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
self.post_attention_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
hidden_states, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = _gated_residual(residual, hidden_states, gate)

residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
return hidden_states

return _PiGemmaDecoderLayerBase


class PiGemmaModel(GemmaModel):  # type: ignore[misc]
"""
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
"""

def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# if not getattr(config, "use_adarms", False):
#     return
cond_dim = getattr(config, "adarms_cond_dim", None)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)

def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for AdaRMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
import logging

logging.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)

hidden_states = layer_outputs

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states, _ = self.norm(hidden_states, adarms_cond)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)


class PiGemmaForCausalLM(GemmaForCausalLM):  # type: ignore[misc]
"""
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
"""

def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PiGemmaModel(config)


class PaliGemmaModelWithPiGemma(PaliGemmaModel):
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""

def __init__(self, config):
super().__init__(config)
self.language_model = PiGemmaModel(config.text_config)


class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
"""PaliGemmaForConditionalGeneration using the PiGemma decoder for the language model."""

def __init__(self, config):
super().__init__(config)
self.model = PaliGemmaModelWithPiGemma(config)

# Expose submodules through the conditional-generation class for backward compatibility
@property
def language_model(self):
return self.model.language_model


__all__ = [
"PiGemmaModel",
"PiGemmaForCausalLM",
"PiGemmaRMSNorm",
"_gated_residual",
"layernorm_forward",
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
model_name: str = "helper2424/resnet10"  # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
device: str = "cpu"
model_type: str = "cnn"  # "transformer" or "cnn"
num_cameras: int = 2

@@ -27,18 +27,18 @@ Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
--reward-model-path <USER>/sarm_single_uni4

# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--stride 5

# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5

@@ -714,12 +714,12 @@ Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
--reward-model-path <USER>/sarm_single_uni4

# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",

@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):

# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = self.dataset_meta.episodes.to_pandas()

# Generate sparse targets
if self.sparse_temporal_proportions is not None:

@@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig):
scheduler_decay_lr: float = 2.5e-6

vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # Select the VLM backbone.
load_vlm_weights: bool = False  # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
load_vlm_weights: bool = False  # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights

add_image_special_tokens: bool = False  # Whether to use special image tokens around image features.

@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

compile_model: bool = False  # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune"  # Torch compile mode

def __post_init__(self):
super().__post_init__()


@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -40,7 +40,7 @@ and an action expert.
```bash
lerobot-train \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor

# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)

def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled


@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map=device,
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)

@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"

# Tokenizer settings
action_tokenizer_path: str | None = "physical-intelligence/fast"
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"

# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"

@@ -261,10 +265,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and optional LoRA fine-tuning support.
"""

_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]

def init_weights(self):
if getattr(self.model, "language_model", None) is not None:
return
super().init_weights()

@classmethod
def from_pretrained(
cls,
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
processor.action_processor = action_tokenizer
else:
action_tokenizer = None

# add pad_token_id to config
config.pad_token_id = processor.tokenizer.pad_token_id
config.text_config.pad_token_id = processor.tokenizer.pad_token_id

# Initialize model with configuration and processor
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)

@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)

@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range


class Qwen2_5_VLConfig(PretrainedConfig):

@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
@@ -31,6 +30,15 @@ from transformers.utils import (

from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig


# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
# always return False (which is the correct behavior when no sliding window cache is in use).
class _SlidingWindowCachePlaceholder:
pass


SlidingWindowCache = _SlidingWindowCachePlaceholder
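Why an empty placeholder class is enough here: no instance of the placeholder is ever constructed, so any `isinstance(..., SlidingWindowCache)` check downstream evaluates to `False`, which is the behavior the TODO above asks for. A tiny self-contained check:

```python
class _SlidingWindowCachePlaceholder:
    pass

cache = object()  # stand-in for whatever cache object the model actually uses
assert not isinstance(cache, _SlidingWindowCachePlaceholder)
```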

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states


def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
"""
Compute the default RoPE parameters for Qwen2_5_VL.
"""
base = config.text_config.rope_parameters["rope_theta"]
dim = config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, 1.0
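A quick sanity check of the inverse-frequency formula this helper computes: frequencies start at 1 (exponent 0) and decay geometrically toward `base**-1`. The `base` and `dim` values below are illustrative, not from a real config:

```python
import torch

base, dim = 10000.0, 128  # illustrative values
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
assert inv_freq.shape == (dim // 2,)   # one frequency per rotary pair
assert inv_freq[0] == 1.0              # exponent 0 -> frequency 1
assert inv_freq[-1] < 1e-3             # slowest-rotating pair
```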


class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
self.rope_type = config.rope_parameters.get("rope_type", "default")
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

if self.rope_type == "default":
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
self.rope_kwargs = {}
else:
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
self.rope_kwargs = {}

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""


class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]


@@ -144,7 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -152,7 +152,7 @@ def preprocesser_call(

# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}

@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
)

# ensure backward compatibility for BART CNN models
if not hasattr(self, "forced_bos_token_id"):
self.forced_bos_token_id = None
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(

@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):


class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}

def __init__(self, config: Florence2LanguageConfig):
super().__init__(config)
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):

class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"model.encoder.embed_tokens.weight": "model.shared.weight",
"model.decoder.embed_tokens.weight": "model.shared.weight",
}
_keys_to_ignore_on_load_missing = ["final_logits_bias"]

def __init__(self, config: Florence2LanguageConfig):
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
_tied_weights_keys = [
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
"language_model.lm_head.weight",
]
_tied_weights_keys = {
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
}

def __init__(self, config: Florence2Config):
super().__init__(config)
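All of the `_tied_weights_keys` changes above follow one convention shift: transformers v4 declared tied weights as a bare list of parameter names, while v5 expects a mapping from each tied parameter to the parameter that owns the storage. A hedged illustration using the names from this patch:

```python
# transformers v4: only which parameters are tied is declared.
tied_v4 = ["lm_head.weight"]
# transformers v5: each tied key maps to its source parameter, as in the hunks above.
tied_v5 = {"lm_head.weight": "model.embed_tokens.weight"}
assert set(tied_v5) == set(tied_v4)
```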

@@ -44,6 +44,7 @@ from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
@@ -87,6 +88,7 @@ __all__ = [
"DoneProcessorStep",
"EnvAction",
"EnvTransition",
"GymHILAdapterProcessorStep",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"IdentityProcessorStep",

@@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature

from .converters import to_tensor
from .core import EnvAction, EnvTransition, PolicyAction
from .hil_processor import TELEOP_ACTION_KEY
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry


@@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
torch_action = to_tensor(action, dtype=None)  # Preserve original dtype
new_transition[TransitionKey.ACTION] = torch_action

complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in complementary_data:
teleop_action = complementary_data[TELEOP_ACTION_KEY]
if isinstance(teleop_action, EnvAction):
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data

return new_transition

def transform_features(

@@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return features


@ProcessorStepRegistry.register("gym_hil_adapter_processor")
class GymHILAdapterProcessorStep(ProcessorStep):
"""
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.

This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
"""

def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})

if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]

if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]

transition[TransitionKey.INFO] = info
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data

return transition

def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features


@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):

@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Requires the `transformers` library to be installed.

Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.

@@ -36,6 +36,7 @@ from lerobot.processor import (
DeviceProcessorStep,
EnvTransition,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
MapDeltaActionToRobotActionStep,
@@ -379,6 +380,7 @@ def make_processors(
]

env_pipeline_steps = [
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
@@ -608,7 +610,14 @@ def control_loop(

dataset = None
if cfg.mode == "record":
action_features = teleop_device.action_features
if teleop_device:
action_features = teleop_device.action_features
else:
action_features = {
"dtype": "float32",
"shape": (4,),
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
}
features = {
ACTION: action_features,
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
@@ -656,7 +665,7 @@ def control_loop(
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])])  # Gripper stay
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])])  # Gripper stay

# Use the new step function
transition = step_env_and_process_transition(
@@ -725,6 +734,8 @@ def control_loop(
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))

if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")
dataset.finalize()
logging.info("Pushing dataset to hub")
dataset.push_to_hub()


@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -193,7 +193,7 @@ class KochFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -360,7 +360,7 @@ class LeKiwi(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -176,7 +176,7 @@ class OmxFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -241,7 +241,7 @@ class OpenArmFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):

# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()

return obs_dict


@@ -187,7 +187,7 @@ class SOFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")


@@ -324,7 +324,7 @@ class UnitreeG1(Robot):

# Cameras - read images from ZMQ cameras
for cam_name, cam in self._cameras.items():
obs[cam_name] = cam.async_read()
obs[cam_name] = cam.read_latest()

return obs


@@ -56,6 +56,7 @@ from lerobot.teleoperators import (  # noqa: F401
make_teleoperator_from_config,
omx_leader,
openarm_leader,
openarm_mini,
so_leader,
unitree_g1,
)
|
||||
|
||||
@@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--ws-port 9087
|
||||
--grpc-port 9876
|
||||
|
||||
local$ rerun ws://localhost:9087
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -75,6 +73,7 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -93,10 +92,11 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
grpc_port: int = 9876,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,14 +126,19 @@ def visualize_dataset(
     gc.collect()

     if mode == "distant":
-        rr.serve_web_viewer(open_browser=False, web_port=web_port)
+        server_uri = rr.serve_grpc(grpc_port=grpc_port)
+        logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
+        rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)

     logging.info("Logging to Rerun")

+    first_index = None
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        if first_index is None:
+            first_index = batch["index"][0].item()
         # iterate over the batch
         for i in range(len(batch["index"])):
-            rr.set_time("frame_index", sequence=batch["frame_index"][i].item())
+            rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
             rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())

             # display each camera image
@@ -226,7 +231,7 @@ def main():
             "Mode of viewing between 'local' or 'distant'. "
             "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
             "'distant' creates a server on the distant machine where the data is stored. "
-            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+            "Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
         ),
     )
     parser.add_argument(
@@ -238,8 +243,13 @@ def main():
     parser.add_argument(
         "--ws-port",
         type=int,
         default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
+        help="deprecated, please use --grpc-port instead.",
     )
     parser.add_argument(
+        "--grpc-port",
+        type=int,
+        default=9876,
+        help="gRPC port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
         "--save",
@@ -265,9 +275,7 @@ def main():

     parser.add_argument(
         "--display-compressed-images",
-        type=bool,
-        required=True,
-        default=False,
+        action="store_true",
         help="If set, display compressed images in Rerun instead of uncompressed ones.",
     )
@@ -277,6 +285,14 @@ def main():
     root = kwargs.pop("root")
     tolerance_s = kwargs.pop("tolerance_s")

+    if kwargs["ws_port"] is not None:
+        logging.warning(
+            "--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
+        )
+        logging.warning("Setting grpc_port to ws_port value.")
+        kwargs["grpc_port"] = kwargs.pop("ws_port")
+
     init_logging()
     logging.info("Loading dataset")
     dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
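A note on the deprecation shim in this hunk: as dumped, `--ws-port` keeps `default=9087`, so the `is not None` check would fire even when the flag is omitted; detecting explicit use normally requires a `None` default. A minimal sketch of the intended pattern (names mirror the hunk; the `None` default is an assumption on our part):

```python
import argparse
import logging

parser = argparse.ArgumentParser()
# A None default lets us tell "flag omitted" apart from "flag given explicitly".
parser.add_argument("--ws-port", type=int, default=None, help="deprecated, please use --grpc-port instead.")
parser.add_argument("--grpc-port", type=int, default=9876)
kwargs = vars(parser.parse_args())

if kwargs["ws_port"] is not None:
    logging.warning("--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead.")
    kwargs["grpc_port"] = kwargs.pop("ws_port")  # legacy flag overrides the new one
```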
@@ -21,97 +21,142 @@ This script allows you to delete episodes, split datasets, merge datasets,
 remove features, modify tasks, and convert image datasets to video format.
 When new_repo_id is specified, creates a new dataset.

+Path semantics (v2): --root and --new_root are exact dataset folders containing
+meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
+
 Usage Examples:

 Delete episodes 0, 2, and 5 from a dataset:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type delete_episodes \
         --operation.episode_indices "[0, 2, 5]"

-Delete episodes and save to a new dataset:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+Delete episodes from a local dataset at a specific path:
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
+        --root /path/to/pusht \
+        --operation.type delete_episodes \
+        --operation.episode_indices "[0, 2, 5]"
+
+Delete episodes and save to a new dataset at a specific path and with a new repo_id:
+    lerobot-edit-dataset \
+        --repo_id lerobot/pusht \
         --new_repo_id lerobot/pusht_filtered \
+        --new_root /path/to/pusht_filtered \
         --operation.type delete_episodes \
         --operation.episode_indices "[0, 2, 5]"

-Split dataset by fractions:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+Split dataset by fractions (pusht_train, pusht_val):
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type split \
         --operation.splits '{"train": 0.8, "val": 0.2}'

+Split dataset by fractions and save split datasets to a specific folder (base_folder/train, base_folder/val):
+    lerobot-edit-dataset \
+        --repo_id lerobot/pusht \
+        --new_root /path/to/base_folder \
+        --operation.type split \
+        --operation.splits '{"train": 0.8, "val": 0.2}'
+
 Split dataset by episode indices:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type split \
         --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'

 Split into more than two splits:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type split \
         --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'

 Merge multiple datasets:
-    python -m lerobot.scripts.lerobot_edit_dataset \
-        --repo_id lerobot/pusht_merged \
+    lerobot-edit-dataset \
+        --new_repo_id lerobot/pusht_merged \
         --operation.type merge \
         --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"

+Merge multiple datasets to a specific output path:
+    lerobot-edit-dataset \
+        --new_repo_id lerobot/pusht_merged \
+        --new_root /path/to/pusht_merged \
+        --operation.type merge \
+        --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
+
+Merge multiple datasets from a list of local dataset paths:
+    lerobot-edit-dataset \
+        --new_repo_id lerobot/pusht_merged \
+        --operation.type merge \
+        --operation.repo_ids "['pusht_train', 'pusht_val']" \
+        --operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
+
 Remove camera feature:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type remove_feature \
-        --operation.feature_names "['observation.images.top']"
+        --operation.feature_names "['observation.image']"

 Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type modify_tasks \
         --operation.new_task "Pick up the cube and place it"

 Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type modify_tasks \
         --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'

 Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht \
         --operation.type modify_tasks \
         --operation.new_task "Default task" \
         --operation.episode_tasks '{"5": "Special task for episode 5"}'

 Convert image dataset to video format and save locally:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht_image \
-        --operation.type convert_image_to_video \
-        --operation.output_dir /path/to/output/pusht_video
+        --new_root /path/to/output/pusht_video \
+        --operation.type convert_image_to_video

 Convert image dataset to video format and save with new repo_id:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht_image \
         --new_repo_id lerobot/pusht_video \
         --operation.type convert_image_to_video

 Convert image dataset to video format and push to hub:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --repo_id lerobot/pusht_image \
         --new_repo_id lerobot/pusht_video \
         --operation.type convert_image_to_video \
         --push_to_hub true

+Show dataset information:
+    lerobot-edit-dataset \
+        --repo_id lerobot/pusht_image \
+        --operation.type info \
+        --operation.show_features true
+
+Show dataset information without feature details:
+    lerobot-edit-dataset \
+        --repo_id lerobot/pusht_image \
+        --operation.type info \
+        --operation.show_features false
+
 Using JSON config file:
-    python -m lerobot.scripts.lerobot_edit_dataset \
+    lerobot-edit-dataset \
         --config_path path/to/edit_config.json
 """

 import abc
 import logging
 import shutil
+import sys
 from dataclasses import dataclass
 from pathlib import Path
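For readers new to the v2 path semantics described above, `--root`/`--new_root` point directly at a dataset folder whose layout looks like this (illustrative tree; only the three top-level folders are named by the docstring):

```
/path/to/pusht
├── meta/      # dataset metadata
├── data/      # per-frame tabular data
└── videos/    # encoded camera streams
```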
@@ -154,6 +199,7 @@ class SplitConfig(OperationConfig):
 @dataclass
 class MergeConfig(OperationConfig):
     repo_ids: list[str] | None = None
+    roots: list[str] | None = None


 @OperationConfig.register_subclass("remove_feature")
@@ -184,32 +230,49 @@ class ConvertImageToVideoConfig(OperationConfig):
     max_frames_per_batch: int | None = None


+@OperationConfig.register_subclass("info")
+@dataclass
+class InfoConfig(OperationConfig):
+    show_features: bool = False
+
+
 @dataclass
 class EditDatasetConfig:
-    repo_id: str
+    # Operation configuration.
     operation: OperationConfig
+    # Input dataset identifier. Always required unless for Merge operation.
+    repo_id: str | None = None
+    # Root directory where the input dataset is stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | None = None
+    # Edited dataset identifier. When both new_repo_id (resp. new_root) and repo_id (resp. root) are identical, modifications are applied in-place and a backup of the original dataset is created. Required for Merge operation.
     new_repo_id: str | None = None
+    # Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/new_repo_id. For Split operation, this is the base directory for the split datasets.
+    new_root: str | None = None
+    # Upload dataset to Hugging Face hub.
     push_to_hub: bool = False


-def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
-    if new_repo_id:
-        output_repo_id = new_repo_id
-        output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
-    else:
-        output_repo_id = repo_id
-        dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
-        old_path = Path(str(dataset_path) + "_old")
-
-        if dataset_path.exists():
-            if old_path.exists():
-                shutil.rmtree(old_path)
-            shutil.move(str(dataset_path), str(old_path))
-
-        output_dir = dataset_path
-
-    return output_repo_id, output_dir
+def get_output_path(
+    repo_id: str,
+    new_repo_id: str | None,
+    root: Path | str | None,
+    new_root: Path | str | None,
+) -> tuple[str, Path]:
+    input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
+
+    output_repo_id = new_repo_id if new_repo_id else repo_id
+    output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
+
+    # In case of in-place modification, create a backup of the original dataset (if it exists)
+    if output_path == input_path:
+        backup_path = input_path.with_name(input_path.name + "_old")
+
+        if input_path.exists():
+            if backup_path.exists():
+                shutil.rmtree(backup_path)
+            shutil.move(input_path, backup_path)
+
+    return output_repo_id, output_path


 def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
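A short usage sketch of the reworked `get_output_path` (paths illustrative): in-place edits trigger the `_old` backup, while a new destination leaves the input untouched.

```python
# In-place: input and output resolve to the same folder, so the original is
# moved aside to /data/pusht_old before the edit writes a fresh copy.
repo, out = get_output_path("lerobot/pusht", new_repo_id=None, root="/data/pusht", new_root="/data/pusht")
# repo == "lerobot/pusht", out == Path("/data/pusht")

# New destination: nothing is moved; the input dataset is left untouched.
repo, out = get_output_path("lerobot/pusht", new_repo_id="lerobot/pusht_filtered", root=None, new_root=None)
# out == HF_LEROBOT_HOME / "lerobot/pusht_filtered"
```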
@@ -221,11 +284,15 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:

     dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
     output_repo_id, output_dir = get_output_path(
-        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+        cfg.repo_id,
+        new_repo_id=cfg.new_repo_id,
+        root=cfg.root,
+        new_root=cfg.new_root,
     )

-    if cfg.new_repo_id is None:
-        dataset.root = Path(str(dataset.root) + "_old")
+    # In case of in-place modification, make the dataset point to the backup directory
+    if output_dir == dataset.root:
+        dataset.root = dataset.root.with_name(dataset.root.name + "_old")

     logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
     new_dataset = delete_episodes(
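The backup redirect in this hunk is worth spelling out, since `get_output_path` has already moved the original folder by the time the dataset object is re-pointed (a sketch with illustrative paths):

```python
# 1. LeRobotDataset(...) loads metadata from /data/pusht (the original location).
# 2. get_output_path(...) detects the in-place case, moves /data/pusht -> /data/pusht_old,
#    and returns output_dir == /data/pusht.
# 3. dataset.root is re-pointed to /data/pusht_old so frames are read from the backup.
# 4. delete_episodes(...) writes the edited dataset into the now-vacant /data/pusht.
```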
@@ -252,19 +319,27 @@ def handle_split(cfg: EditDatasetConfig) -> None:
             "splits dict must be specified with split names as keys and fractions/episode lists as values"
         )

+    if cfg.new_repo_id is not None:
+        logging.warning(
+            "split uses the original dataset identifier --repo_id to generate split names. The --new_repo_id parameter is ignored."
+        )
+
     dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)

     logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
-    split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
+    split_datasets = split_dataset(
+        dataset,
+        splits=cfg.operation.splits,
+        output_dir=cfg.new_root,
+    )

     for split_name, split_ds in split_datasets.items():
-        split_repo_id = f"{cfg.repo_id}_{split_name}"
         logging.info(
             f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
         )

         if cfg.push_to_hub:
-            logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
+            logging.info(f"Pushing {split_name} split to hub as {split_ds.repo_id}")
             LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
@@ -275,18 +350,29 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
     if not cfg.operation.repo_ids:
         raise ValueError("repo_ids must be specified for merge operation")

-    if not cfg.repo_id:
-        raise ValueError("repo_id must be specified as the output repository for merged dataset")
+    if cfg.repo_id is not None or cfg.root is not None:
+        logging.warning(
+            "merge uses --new_repo_id and --new_root for the merged dataset. The --repo_id and --root parameters are ignored."
+        )

-    logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
-    datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
+    if cfg.operation.roots:
+        if len(cfg.operation.roots) != len(cfg.operation.repo_ids):
+            raise ValueError("repo_ids and roots must have the same length for merge operation")
+        logging.info(f"Loading {len(cfg.operation.roots)} datasets to merge")
+        datasets = [
+            LeRobotDataset(repo_id=repo_id, root=root)
+            for repo_id, root in zip(cfg.operation.repo_ids, cfg.operation.roots, strict=True)
+        ]
+    else:
+        logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
+        datasets = [LeRobotDataset(repo_id) for repo_id in cfg.operation.repo_ids]

-    output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
+    output_dir = Path(cfg.new_root) if cfg.new_root else HF_LEROBOT_HOME / cfg.new_repo_id

-    logging.info(f"Merging datasets into {cfg.repo_id}")
+    logging.info(f"Merging datasets into {cfg.new_repo_id}")
     merged_dataset = merge_datasets(
         datasets,
-        output_repo_id=cfg.repo_id,
+        output_repo_id=cfg.new_repo_id,
         output_dir=output_dir,
     )

@@ -296,7 +382,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
     )

     if cfg.push_to_hub:
-        logging.info(f"Pushing to hub as {cfg.repo_id}")
+        logging.info(f"Pushing to hub as {cfg.new_repo_id}")
         LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
@@ -309,11 +395,15 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:

     dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
     output_repo_id, output_dir = get_output_path(
-        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+        cfg.repo_id,
+        new_repo_id=cfg.new_repo_id,
+        root=cfg.root,
+        new_root=cfg.new_root,
     )

-    if cfg.new_repo_id is None:
-        dataset.root = Path(str(dataset.root) + "_old")
+    # In case of in-place modification, make the dataset point to the backup directory
+    if output_dir == dataset.root:
+        dataset.root = dataset.root.with_name(dataset.root.name + "_old")

     logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
     new_dataset = remove_feature(
@@ -341,9 +431,10 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
     if new_task is None and episode_tasks_raw is None:
         raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")

     # Warn about in-place modification behavior
-    if cfg.new_repo_id is not None:
-        logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
+    if cfg.new_repo_id is not None or cfg.new_root is not None:
+        logging.warning(
+            "modify_tasks modifies datasets in-place. The --new_repo_id and --new_root parameters are ignored."
+        )

     dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
     logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
@@ -379,32 +470,30 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
     dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)

     # Determine output directory and repo_id
-    # Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
+    # Priority: 1) new_root, 2) new_repo_id, 3) operation.output_dir, 4) auto-generated name
     output_dir_config = getattr(cfg.operation, "output_dir", None)
+    if output_dir_config:
+        logging.warning(
+            "--operation.output_dir is deprecated and will be removed in future versions. "
+            "Please use --new_root instead."
+        )

-    if cfg.new_repo_id:
-        # Use new_repo_id for both local storage and hub push
+    if cfg.new_root:
+        output_dir = Path(cfg.new_root)
+        output_repo_id = cfg.new_repo_id or f"{cfg.repo_id}_video"
+        logging.info(f"Saving to new_root: {output_dir} as {output_repo_id}")
+    elif cfg.new_repo_id:
         output_repo_id = cfg.new_repo_id
-        # Place new dataset as a sibling to the original dataset
-        # Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
-        # Extract just the dataset name (after last slash) for the local directory
-        local_dir_name = cfg.new_repo_id.split("/")[-1]
-        output_dir = dataset.root.parent / local_dir_name
+        output_dir = HF_LEROBOT_HOME / cfg.new_repo_id
         logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
     elif output_dir_config:
         # Use custom output directory for local-only storage
         output_dir = Path(output_dir_config)
         # Extract repo name from output_dir for the dataset
         output_repo_id = output_dir.name
-        logging.info(f"Saving to local directory: {output_dir}")
+        logging.info(f"Saving to local directory: {output_dir} as {output_repo_id}")
     else:
         # Auto-generate name: append "_video" to original repo_id
         output_repo_id = f"{cfg.repo_id}_video"
-        # Place new dataset as a sibling to the original dataset
-        # Extract just the dataset name (after last slash) for the local directory
-        local_dir_name = output_repo_id.split("/")[-1]
-        output_dir = dataset.root.parent / local_dir_name
-        logging.info(f"Saving to auto-generated location: {output_dir}")
+        output_dir = HF_LEROBOT_HOME / output_repo_id
+        logging.info(f"Saving to auto-generated location: {output_dir} as {output_repo_id}")

     logging.info(f"Converting dataset {cfg.repo_id} to video format")
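The destination logic in this hunk reduces to a small pure function; a sketch mirroring the priority chain (`home` stands in for `HF_LEROBOT_HOME`; all names here are hypothetical, not the module's API):

```python
from pathlib import Path

def resolve_output(repo_id, new_repo_id, new_root, output_dir_config, home: Path):
    # Priority: 1) new_root, 2) new_repo_id, 3) deprecated output_dir, 4) auto "_video" suffix.
    if new_root:
        return new_repo_id or f"{repo_id}_video", Path(new_root)
    if new_repo_id:
        return new_repo_id, home / new_repo_id
    if output_dir_config:
        out = Path(output_dir_config)
        return out.name, out
    return f"{repo_id}_video", home / f"{repo_id}_video"
```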
@@ -436,8 +525,63 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
         logging.info("Dataset saved locally (not pushed to hub)")


+def _get_dataset_size(repo_path):
+    import os
+
+    total = 0
+    with os.scandir(repo_path) as it:
+        for entry in it:
+            if entry.is_file():
+                total += entry.stat().st_size
+            elif entry.is_dir():
+                total += _get_dataset_size(entry.path)
+    return total
+
+
+def handle_info(cfg: EditDatasetConfig):
+    if not isinstance(cfg.operation, InfoConfig):
+        raise ValueError("Operation config must be InfoConfig")
+
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+    sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
+    sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
+    sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
+    sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
+    sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
+    sys.stdout.write(
+        f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
+    )
+    sys.stdout.write(
+        f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
+    )
+    sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
+
+    total_file_size = _get_dataset_size(dataset.root)
+    sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
+    if cfg.operation.show_features:
+        import json
+
+        feature_dump_str = json.dumps(
+            dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
+        )
+        sys.stdout.write("Features:\n")
+        sys.stdout.write(f"{feature_dump_str}\n")
+
+
+def _validate_config(cfg: EditDatasetConfig) -> None:
+    if isinstance(cfg.operation, MergeConfig):
+        if not cfg.new_repo_id:
+            raise ValueError("--new_repo_id is required for merge operation (the merged dataset identifier)")
+    else:
+        if not cfg.repo_id:
+            raise ValueError(
+                f"--repo_id is required for {cfg.operation.type} operation (the input dataset identifier)"
+            )
+
+
 @parser.wrap()
 def edit_dataset(cfg: EditDatasetConfig) -> None:
+    _validate_config(cfg)
     operation_type = cfg.operation.type

     if operation_type == "delete_episodes":
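`_get_dataset_size` recurses with `os.scandir`; for comparison, an equivalent one-pass sketch using `pathlib` (not part of the diff):

```python
from pathlib import Path

def get_dataset_size(repo_path) -> int:
    # Sum the sizes of all regular files under repo_path, in bytes.
    return sum(p.stat().st_size for p in Path(repo_path).rglob("*") if p.is_file())
```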
@@ -452,6 +596,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
         handle_modify_tasks(cfg)
     elif operation_type == "convert_image_to_video":
         handle_convert_image_to_video(cfg)
+    elif operation_type == "info":
+        handle_info(cfg)
     else:
         available = ", ".join(OperationConfig.get_known_choices())
         raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
@@ -61,6 +61,7 @@ from lerobot.teleoperators import (  # noqa: F401
     make_teleoperator_from_config,
     omx_leader,
     openarm_leader,
+    openarm_mini,
     so_leader,
 )
 from lerobot.utils.robot_utils import precise_sleep
@@ -26,8 +26,10 @@ lerobot-record \
     --dataset.repo_id=<my_username>/<my_dataset_name> \
     --dataset.num_episodes=2 \
     --dataset.single_task="Grab the cube" \
+    --dataset.streaming_encoding=true \
+    --dataset.encoder_threads=2 \
     --display_data=true
-    # <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
+    # <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
     # --dataset.vcodec=h264 \
     # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
     # --teleop.type=so100_leader \

@@ -58,7 +60,10 @@ lerobot-record \
     --display_data=true \
     --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
     --dataset.num_episodes=25 \
-    --dataset.single_task="Grab and handover the red cube to the other arm"
+    --dataset.single_task="Grab and handover the red cube to the other arm" \
+    --dataset.streaming_encoding=true \
+    # --dataset.vcodec=auto \
+    --dataset.encoder_threads=2
 ```
 """
@@ -120,6 +125,7 @@ from lerobot.teleoperators import (  # noqa: F401
     make_teleoperator_from_config,
     omx_leader,
     openarm_leader,
+    openarm_mini,
     reachy2_teleoperator,
     so_leader,
     unitree_g1,
@@ -149,7 +155,7 @@ class DatasetRecordConfig:
     repo_id: str
     # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
     single_task: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second.
     fps: int = 30
@@ -179,9 +185,19 @@ class DatasetRecordConfig:
     # Number of episodes to record before batch encoding videos
     # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
     video_encoding_batch_size: int = 1
-    # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
-    # Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
+    # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
+    # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
+    # Use 'auto' to auto-detect the best available hardware encoder.
     vcodec: str = "libsvtav1"
+    # Enable streaming video encoding: encode frames in real-time during capture instead
+    # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
+    streaming_encoding: bool = False
+    # Maximum number of frames to buffer per camera when using streaming encoding.
+    # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
+    encoder_queue_maxsize: int = 30
+    # Number of threads per encoder instance. None = auto (codec default).
+    # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
+    encoder_threads: int | None = None
     # Rename map for the observation to override the image and state keys
     rename_map: dict[str, str] = field(default_factory=dict)
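The `encoder_queue_maxsize` comment describes classic producer/consumer backpressure: a bounded queue blocks the capture side whenever the encoder falls behind. A minimal sketch of the idea (not the library's encoder code; `grab_frame` and `encode` are hypothetical stand-ins):

```python
import queue
import threading

frames: queue.Queue = queue.Queue(maxsize=30)  # ~1 s of frames at 30 fps

def capture_loop():
    while True:
        frame = grab_frame()   # hypothetical camera read
        frames.put(frame)      # blocks when full, throttling capture to encoder speed

def encode_loop():
    while True:
        encode(frames.get())   # hypothetical encoder call
        frames.task_done()

threading.Thread(target=encode_loop, daemon=True).start()
```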
@@ -318,6 +334,7 @@ def record_loop(
         preprocessor.reset()
         postprocessor.reset()

+    no_action_count = 0
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
@@ -365,11 +382,13 @@ def record_loop(
             act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
             act_processed_teleop = teleop_action_processor((act, obs))
         else:
-            logging.info(
-                "No policy or teleoperator provided, skipping action generation."
-                "This is likely to happen when resetting the environment without a teleop device."
-                "The robot won't be at its rest position at the start of the next episode."
-            )
+            no_action_count += 1
+            if no_action_count == 1 or no_action_count % 10 == 0:
+                logging.warning(
+                    "No policy or teleoperator provided, skipping action generation. "
+                    "This is likely to happen when resetting the environment without a teleop device. "
+                    "The robot won't be at its rest position at the start of the next episode."
+                )
             continue

         # Applies a pipeline to the action, default is IdentityProcessor
@@ -398,7 +417,14 @@ def record_loop(
         )

         dt_s = time.perf_counter() - start_loop_t
-        precise_sleep(max(1 / fps - dt_s, 0.0))
+
+        sleep_time_s: float = 1 / fps - dt_s
+        if sleep_time_s < 0:
+            logging.warning(
+                f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
+            )
+
+        precise_sleep(max(sleep_time_s, 0.0))

         timestamp = time.perf_counter() - start_episode_t
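To make the warning threshold concrete, the timing arithmetic with illustrative numbers:

```python
fps = 30
budget_s = 1 / fps               # ~0.0333 s available per loop iteration
dt_s = 0.050                     # example: one iteration took 50 ms
sleep_time_s = budget_s - dt_s   # ~-0.0167 s -> negative, so the warning fires
print(f"{1 / dt_s:.1f} Hz")      # "20.0 Hz", reported against the 30 Hz target
```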
@@ -445,6 +471,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
             root=cfg.dataset.root,
             batch_encoding_size=cfg.dataset.video_encoding_batch_size,
             vcodec=cfg.dataset.vcodec,
+            streaming_encoding=cfg.dataset.streaming_encoding,
+            encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
+            encoder_threads=cfg.dataset.encoder_threads,
         )

         if hasattr(robot, "cameras") and len(robot.cameras) > 0:
@@ -467,6 +496,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
             image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
             batch_encoding_size=cfg.dataset.video_encoding_batch_size,
             vcodec=cfg.dataset.vcodec,
+            streaming_encoding=cfg.dataset.streaming_encoding,
+            encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
+            encoder_threads=cfg.dataset.encoder_threads,
         )

     # Load pretrained policy
@@ -490,6 +522,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:

     listener, events = init_keyboard_listener()

+    if not cfg.dataset.streaming_encoding:
+        logging.info(
+            "Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
+        )
+
     with VideoEncodingManager(dataset):
         recorded_episodes = 0
         while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
@@ -22,7 +22,7 @@ lerobot-replay \
     --robot.type=so100_follower \
     --robot.port=/dev/tty.usbmodem58760431541 \
     --robot.id=black \
-    --dataset.repo_id=aliberts/record-test \
+    --dataset.repo_id=<USER>/record-test \
     --dataset.episode=0
 ```
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
     repo_id: str
     # Episode to replay.
     episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int = 30
@@ -152,6 +152,7 @@ def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
     )
     try:
         bus.send(disable_msg)
+        bus.recv(timeout=0.1)  # Clear any pending responses
     except Exception:
         print(f"Error sending message to motor 0x{motor_id:02X}")
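The added `bus.recv(timeout=0.1)` clears at most one queued frame. If several responses may be pending, the usual pattern is a drain loop built on python-can's contract that `recv` returns `None` on timeout (a sketch, not part of the diff):

```python
def drain_rx(bus, timeout: float = 0.1) -> None:
    # Keep reading until the receive queue is empty; python-can's recv()
    # returns None once the timeout elapses with no message.
    while bus.recv(timeout=timeout) is not None:
        pass
```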
@@ -43,6 +43,7 @@ from lerobot.teleoperators import (  # noqa: F401
     koch_leader,
     make_teleoperator_from_config,
     omx_leader,
+    openarm_mini,
     so_leader,
 )
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
     "koch_leader",
     "omx_follower",
     "omx_leader",
+    "openarm_mini",
     "so100_follower",
     "so100_leader",
     "so101_follower",
@@ -94,6 +94,7 @@ from lerobot.teleoperators import (  # noqa: F401
     make_teleoperator_from_config,
     omx_leader,
     openarm_leader,
+    openarm_mini,
     reachy2_teleoperator,
     so_leader,
     unitree_g1,
@@ -24,6 +24,7 @@ import torch
 from accelerate import Accelerator
 from termcolor import colored
 from torch.optim import Optimizer
+from tqdm import tqdm

 from lerobot.configs import parser
 from lerobot.configs.train import TrainPipelineConfig
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
     format_big_number,
     has_method,
     init_logging,
+    inside_slurm,
 )
@@ -378,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         "dataloading_s": AverageMeter("data_s", ":.3f"),
     }

-    # Use effective batch size for proper epoch calculation in distributed training
+    # Keep global batch size for logging; MetricsTracker handles world size internally.
     effective_batch_size = cfg.batch_size * accelerator.num_processes
     train_tracker = MetricsTracker(
-        effective_batch_size,
+        cfg.batch_size,
         dataset.num_frames,
         dataset.num_episodes,
         train_metrics,
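The batch-size fix is easiest to see with numbers (illustrative; per the new comment, MetricsTracker already scales by world size internally):

```python
batch_size = 8                                      # per-process, what MetricsTracker now receives
num_processes = 4
effective_batch_size = batch_size * num_processes   # 32, kept only for the startup log line
# Passing 32 to a tracker that also multiplies by world size would have
# counted each sample num_processes times in throughput metrics.
```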
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     )

     if is_main_process:
+        progbar = tqdm(
+            total=cfg.steps - step,
+            desc="Training",
+            unit="step",
+            disable=inside_slurm(),
+            position=0,
+            leave=True,
+        )
         logging.info(
             f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
         )
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
         # increment `step` here.
         step += 1
+        if is_main_process:
+            progbar.update(1)
         train_tracker.step()
         is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
         is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):

     accelerator.wait_for_everyone()

+    if is_main_process:
+        progbar.close()
+
     if eval_env:
         close_envs(eval_env)
@@ -306,7 +306,7 @@ def train_fast_tokenizer(

     # download the tokenizer source code (not pretrained weights)
     # we'll train a new tokenizer on our own data
-    base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
+    base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)

     # convert action_chunks array to list of arrays (expected by .fit())
     action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
Some files were not shown because too many files have changed in this diff