mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 23:29:52 +00:00
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0394fae446 | |||
| 602b8e66a6 | |||
| ab4dce6fed | |||
| 40f4386e4a | |||
| 87a91b4b08 | |||
| fadb900c36 | |||
| de0663226a | |||
| 0ca9d66cae | |||
| 2222f25da3 | |||
| acae8417aa | |||
| 2697f65cf6 | |||
| 74f42f218e | |||
| ca9d49e305 | |||
| 6705876d47 | |||
| aadbd27675 | |||
| 5221647b5e | |||
| 9c981300dd | |||
| f5b27aad1b | |||
| 75f1285507 | |||
| 33cedc2f71 | |||
| aa32e6c4ab | |||
| f906270ec4 | |||
| 733b6d84db | |||
| 8abc9037a3 | |||
| e4d4ac0bda | |||
| e79b2a439b | |||
| f9ae78ca74 | |||
| e1ced538e3 | |||
| 2a98602ad6 | |||
| a2f5b3571e | |||
| cecf2eff4f | |||
| 7e6b598a51 | |||
| 4fa41ba806 | |||
| 1de2b87a92 | |||
| e3c511db67 | |||
| aed4130d39 | |||
| d26349c692 | |||
| a9bce4732b | |||
| 86d69e3c1d |
@@ -173,8 +173,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
+42
-42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -29,8 +29,6 @@
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
|
||||
@@ -88,8 +88,5 @@ lerobot-record \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -185,16 +185,13 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.fps=10 \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -120,12 +120,9 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
|
||||
+5
-11
@@ -224,15 +224,12 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -244,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -252,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -273,11 +270,8 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
@@ -185,10 +185,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
@@ -518,9 +515,6 @@ lerobot-record \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -40,13 +40,6 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
> [!NOTE]
|
||||
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command:
|
||||
>
|
||||
> ```bash
|
||||
> conda install evdev -c conda-forge
|
||||
> ```
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
@@ -41,10 +41,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
|
||||
See the [recording guide](./il_robots#record-a-dataset) for more details.
|
||||
|
||||
@@ -66,13 +66,12 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase
|
||||
|
||||
All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
|
||||
|
||||
Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`.
|
||||
Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
|
||||
|
||||
- Run this example to teleoperate:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python teleoperate.py
|
||||
python examples/phone_to_so100/teleoperate.py
|
||||
```
|
||||
|
||||
After running the example:
|
||||
@@ -85,22 +84,19 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python record.py
|
||||
python examples/phone_to_so100/record.py
|
||||
```
|
||||
|
||||
- Run this example to replay recorded episodes:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python replay.py
|
||||
python examples/phone_to_so100/replay.py
|
||||
```
|
||||
|
||||
- Run this example to evaluate a pretrained policy:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python evaluate.py
|
||||
python examples/phone_to_so100/evaluate.py
|
||||
```
|
||||
|
||||
### Important pipeline steps and options
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -159,9 +159,6 @@ lerobot-record \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -201,9 +198,6 @@ lerobot-record \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ lerobot-train \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ lerobot-train \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -106,9 +106,6 @@ lerobot-record \
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
# Streaming Video Encoding Guide
|
||||
|
||||
## 1. Overview
|
||||
|
||||
Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of:
|
||||
|
||||
1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's
|
||||
|
||||
Frames can be encoded in real-time during capture:
|
||||
|
||||
1. Capture frame -> queue to encoder thread -> encode to MP4 directly
|
||||
|
||||
This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes.
|
||||
|
||||
## 2. Tuning Parameters
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between:
|
||||
|
||||
- **Control loop** (reading cameras, control the robot, writing non-video data)
|
||||
- **Encoder threads** (one pool per camera)
|
||||
- **Rerun visualization** (if enabled)
|
||||
- **OS and other processes**
|
||||
|
||||
### Resolution & Number of Cameras Impact
|
||||
|
||||
| Setup | Throughput (px/sec) | CPU Encoding Load | Notes |
|
||||
| ------------------------- | ------------------- | ----------------- | ------------------------------ |
|
||||
| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems |
|
||||
| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems |
|
||||
| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU |
|
||||
|
||||
### `encoder_threads` Tuning
|
||||
|
||||
This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
|
||||
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
|
||||
- **`None` (default)**: Lets the codec decide. Information available in the codec logs.
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"`
|
||||
4. At episode end, total dropped frames per camera are reported
|
||||
|
||||
### Symptoms of Encoder Falling Behind
|
||||
|
||||
- **System feels laggy and freezes**: all CPUs are at 100%
|
||||
- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset
|
||||
- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected
|
||||
- **Accumulated rerun lag**: Visualization falls behind real-time
|
||||
|
||||
## 4. Hardware-Accelerated Encoding
|
||||
|
||||
### When to Use
|
||||
|
||||
Use HW encoding when:
|
||||
|
||||
- CPU is the bottleneck (dropped frames, choppy robot, rerun lag)
|
||||
- You have compatible hardware (GPU or dedicated encoder)
|
||||
- You're recording at high throughput (high resolution or with many cameras)
|
||||
|
||||
### Choosing a Codec
|
||||
|
||||
| Codec | CPU Usage | File Size | Quality | Notes |
|
||||
| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- |
|
||||
| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive |
|
||||
| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU |
|
||||
| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems |
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
|
||||
> [!NOTE]
|
||||
> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time.
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually.
|
||||
|
||||
### High-End Systems: modern 12+ cores (24+ threads)
|
||||
|
||||
A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available.
|
||||
|
||||
```bash
|
||||
# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
|
||||
# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
|
||||
lerobot-record --dataset.encoder_threads=5 ...
|
||||
|
||||
# 3camsx 1920x1080x3 @30fps: Might require some tuning.
|
||||
```
|
||||
|
||||
### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon
|
||||
|
||||
A throughput between ~80-300M px/sec should be possible in CPU.
|
||||
|
||||
```bash
|
||||
# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
|
||||
# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
|
||||
lerobot-record --dataset.encoder_threads=2 ...
|
||||
|
||||
# 2camsx 1920x1080x3 @30fps: Might require some tuning.
|
||||
```
|
||||
|
||||
### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5
|
||||
|
||||
On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding.
|
||||
|
||||
```bash
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
|
||||
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -229,10 +229,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
@@ -269,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -282,10 +279,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show the information of datasets such as number of episode, number of frame, File size and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
+5
-10
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
@@ -98,13 +98,11 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["lerobot[can-dep]"]
|
||||
robstride = ["lerobot[can-dep]"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
@@ -214,9 +212,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -7,13 +7,6 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -37,7 +37,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
@@ -567,22 +567,20 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -624,10 +622,9 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -637,20 +634,21 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(frame_ranges):
|
||||
if range_idx >= len(time_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -663,7 +661,6 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -752,17 +749,15 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
@@ -1527,6 +1522,122 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
|
||||
dataset: LeRobotDataset,
|
||||
skip_image_video: bool = True,
|
||||
delta_action: bool = False,
|
||||
delta_exclude_joints: list[str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset to recompute stats for.
|
||||
skip_image_video: If True (default), only recompute stats for numeric features
|
||||
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||
delta_action: If True, compute action stats as delta (action - state).
|
||||
Useful when training with use_delta_actions=True so normalization matches.
|
||||
delta_exclude_joints: Joint names to exclude from delta conversion when
|
||||
delta_action=True. These dims keep absolute stats. Uses dataset's
|
||||
action feature names to build the mask. Default: ["gripper"].
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
features = dataset.meta.features
|
||||
numeric_features = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
if skip_image_video:
|
||||
features_to_compute = numeric_features
|
||||
else:
|
||||
features_to_compute = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] != "string"
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
# Build delta mask if delta_action is enabled
|
||||
delta_mask = None
|
||||
if delta_action and "action" in features and "observation.state" in features:
|
||||
if delta_exclude_joints is None:
|
||||
delta_exclude_joints = ["gripper"]
|
||||
action_names = features["action"].get("names")
|
||||
if action_names is not None:
|
||||
exclude = set(delta_exclude_joints)
|
||||
delta_mask = [n not in exclude for n in action_names]
|
||||
else:
|
||||
action_dim = features["action"]["shape"][0]
|
||||
delta_mask = [True] * action_dim
|
||||
# Only recompute action stats when delta is enabled — state stays unchanged
|
||||
features_to_compute = {"action": features["action"]}
|
||||
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
|
||||
else:
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
# Also need state for delta computation even though we don't recompute state stats
|
||||
needs_state = delta_mask is not None
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
for ep_idx in sorted(df["episode_index"].unique()):
|
||||
ep_df = df[df["episode_index"] == ep_idx]
|
||||
episode_data = {}
|
||||
for key in numeric_keys:
|
||||
if key in ep_df.columns:
|
||||
values = ep_df[key].values
|
||||
if hasattr(values[0], "__len__"):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
# Apply delta conversion to actions before computing stats
|
||||
if delta_mask is not None and "action" in episode_data:
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
# Load state for delta even if we're not computing state stats
|
||||
if needs_state and "observation.state" in ep_df.columns:
|
||||
state_values = ep_df["observation.state"].values
|
||||
if hasattr(state_values[0], "__len__"):
|
||||
states = np.stack(state_values)
|
||||
else:
|
||||
states = np.array(state_values)
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(states).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
if not all_episode_stats:
|
||||
logging.warning("No episode stats computed")
|
||||
return dataset
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats)
|
||||
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
write_stats(new_stats, dataset.root)
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -68,7 +68,6 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
@@ -76,11 +75,11 @@ from lerobot.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -546,19 +545,12 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -578,9 +570,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -667,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
@@ -694,17 +683,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -716,8 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
self.vcodec = vcodec
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -725,7 +708,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
self._streaming_encoder = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -747,7 +729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -767,19 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Initialize streaming encoder for resumed recording
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
self._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=self.meta.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
writer = getattr(self, "writer", None)
|
||||
@@ -839,7 +808,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -1135,8 +1104,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.close()
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
@@ -1191,13 +1158,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
|
||||
|
||||
# Start streaming encoder on first frame of episode (once, before iterating keys)
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self.meta.video_keys),
|
||||
temp_dir=self.root,
|
||||
)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key not in self.features:
|
||||
@@ -1205,10 +1165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.feed_frame(key, frame[key])
|
||||
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
@@ -1269,38 +1226,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_streaming = self._streaming_encoder is not None and has_video_keys
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if use_streaming:
|
||||
# Compute stats for non-video features only (video stats come from encoder)
|
||||
non_video_buffer = {
|
||||
k: v
|
||||
for k, v in episode_buffer.items()
|
||||
if self.features.get(k, {}).get("dtype") not in ("video",)
|
||||
}
|
||||
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
|
||||
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
|
||||
else:
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if use_streaming:
|
||||
# Finish streaming encoding and collect results
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
elif has_video_keys and not use_batched_encoding:
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1314,7 +1246,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
@@ -1583,10 +1514,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
# Cancel streaming encoder if active
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.cancel_episode()
|
||||
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1634,9 +1561,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
|
||||
)
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1653,13 +1578,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1668,7 +1590,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
metadata_buffer_size=metadata_buffer_size,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
@@ -1678,7 +1599,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
@@ -1700,22 +1620,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
|
||||
# Initialize streaming encoder
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
obj._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
else:
|
||||
obj._streaming_encoder = None
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
@@ -122,9 +122,19 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -13,106 +13,25 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
@@ -227,17 +146,16 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -249,11 +167,7 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -358,16 +272,15 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -396,13 +309,14 @@ def encode_video_frames(
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
# Check encoder availability
|
||||
if vcodec not in ["h264", "hevc", "libsvtav1"]:
|
||||
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -433,22 +347,21 @@ def encode_video_frames(
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
video_options = {}
|
||||
|
||||
if g is not None:
|
||||
video_options["g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
if vcodec == "libsvtav1":
|
||||
video_options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -567,348 +480,6 @@ def concatenate_video_files(
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
class _CameraEncoderThread(threading.Thread):
|
||||
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
||||
|
||||
One instance is created per camera per episode. Frames are received as numpy arrays
|
||||
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
||||
encoding), and written to disk. Stats are computed incrementally using
|
||||
RunningQuantileStats and returned via result_queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
while True:
|
||||
try:
|
||||
frame_data = self.frame_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
continue
|
||||
|
||||
if frame_data is None:
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
if output_stream is not None:
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
if container is not None:
|
||||
container.close()
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
self.result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
class StreamingVideoEncoder:
|
||||
"""Manages per-camera encoder threads for real-time video encoding during recording.
|
||||
|
||||
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
|
||||
this class streams frames directly to encoder threads, eliminating the
|
||||
PNG round-trip and making save_episode() near-instant.
|
||||
|
||||
Uses threading instead of multiprocessing to avoid the overhead of pickling large
|
||||
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
|
||||
so encoding runs in parallel with the main recording loop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
self._threads: dict[str, _CameraEncoderThread] = {}
|
||||
self._stop_events: dict[str, threading.Event] = {}
|
||||
self._video_paths: dict[str, Path] = {}
|
||||
self._dropped_frames: dict[str, int] = {}
|
||||
self._episode_active = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
self._frame_queues[video_key] = frame_queue
|
||||
self._result_queues[video_key] = result_queue
|
||||
self._threads[video_key] = encoder_thread
|
||||
self._stop_events[video_key] = stop_event
|
||||
self._video_paths[video_key] = video_path
|
||||
|
||||
self._episode_active = True
|
||||
|
||||
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
|
||||
"""Feed a frame to the encoder for a specific camera.
|
||||
|
||||
A copy of the image is made before enqueueing to prevent race conditions
|
||||
with camera drivers that may reuse buffers. If the encoder queue is full
|
||||
(encoder can't keep up), the frame is dropped with a warning instead of
|
||||
crashing the recording session.
|
||||
|
||||
Args:
|
||||
video_key: The video feature key
|
||||
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the encoder thread has crashed
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode. Call start_episode() first.")
|
||||
|
||||
thread = self._threads[video_key]
|
||||
if not thread.is_alive():
|
||||
# Check for error
|
||||
try:
|
||||
status, msg = self._result_queues[video_key].get_nowait()
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
|
||||
|
||||
try:
|
||||
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
|
||||
except queue.Full:
|
||||
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
|
||||
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
|
||||
"""Finish encoding the current episode.
|
||||
|
||||
Sends sentinel values, waits for encoder threads to complete,
|
||||
and collects results.
|
||||
|
||||
Returns:
|
||||
Dict mapping video_key to (mp4_path, stats_dict_or_None)
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode to finish.")
|
||||
|
||||
results = {}
|
||||
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
self._frame_queues[video_key].put(None)
|
||||
|
||||
# Wait for all threads and collect results
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
continue
|
||||
|
||||
try:
|
||||
status, data = self._result_queues[video_key].get(timeout=5)
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
return results
|
||||
|
||||
def cancel_episode(self) -> None:
|
||||
"""Cancel the current episode, stopping encoder threads and cleaning up."""
|
||||
if not self._episode_active:
|
||||
return
|
||||
|
||||
# Signal all threads to stop
|
||||
for video_key in self._stop_events:
|
||||
self._stop_events[video_key].set()
|
||||
|
||||
# Wait for threads to finish
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=5)
|
||||
|
||||
# Clean up temp MP4 files
|
||||
video_path = self._video_paths.get(video_key)
|
||||
if video_path is not None and video_path.exists():
|
||||
shutil.rmtree(str(video_path.parent), ignore_errors=True)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the encoder, canceling any in-progress episode."""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up queues and thread tracking dicts."""
|
||||
for q in self._frame_queues.values():
|
||||
with contextlib.suppress(Exception):
|
||||
while not q.empty():
|
||||
q.get_nowait()
|
||||
self._frame_queues.clear()
|
||||
self._result_queues.clear()
|
||||
self._threads.clear()
|
||||
self._stop_events.clear()
|
||||
self._video_paths.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
|
||||
@@ -943,7 +514,7 @@ with warnings.catch_warnings():
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
@@ -975,7 +546,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting video stream information
|
||||
video_info = {}
|
||||
@@ -1061,15 +632,8 @@ class VideoEncodingManager:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
|
||||
|
||||
if streaming_encoder is not None:
|
||||
# Handle streaming encoder cleanup
|
||||
if exc_type is not None:
|
||||
streaming_encoder.cancel_episode()
|
||||
streaming_encoder.close()
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if self.dataset.episodes_since_last_encoding > 0:
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
@@ -1086,8 +650,8 @@ class VideoEncodingManager:
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted (only for non-streaming mode)
|
||||
if exc_type is not None and streaming_encoder is None:
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
for key in self.dataset.meta.video_keys:
|
||||
img_dir = self.dataset._get_image_file_path(
|
||||
@@ -1101,12 +665,14 @@ class VideoEncodingManager:
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Check for any remaining PNG files
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Only remove the images directory if no PNG files remain
|
||||
if img_dir.exists():
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .robstride import RobstrideMotorsBus
|
||||
from .tables import *
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
O0 = 0
|
||||
O1 = 1
|
||||
O2 = 2
|
||||
O3 = 3
|
||||
O4 = 4
|
||||
O5 = 5
|
||||
ELO5 = 6
|
||||
O6 = 7
|
||||
|
||||
|
||||
class CommMode(IntEnum):
|
||||
PrivateProtocole = 0
|
||||
CANopen = 1
|
||||
MIT = 2
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 0
|
||||
POS_VEL = 1
|
||||
VEL = 2
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = {
|
||||
MotorType.O0: (12.57, 33, 14),
|
||||
MotorType.O1: (12.57, 44, 17),
|
||||
MotorType.O2: (12.57, 33, 20),
|
||||
MotorType.O3: (12.57, 33, 60),
|
||||
MotorType.O4: (12.57, 33, 120),
|
||||
MotorType.O5: (12.57, 50, 5.5),
|
||||
MotorType.ELO5: (12.57, 50, 6),
|
||||
MotorType.O6: (112.5, 50, 36),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.O0: "O0",
|
||||
MotorType.O1: "O1",
|
||||
MotorType.O2: "O2",
|
||||
MotorType.O3: "O3",
|
||||
MotorType.O4: "O4",
|
||||
MotorType.O5: "O5",
|
||||
MotorType.ELO5: "ELO5",
|
||||
MotorType.O6: "O6",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"O0": 65536,
|
||||
"O1": 65536,
|
||||
"O2": 65536,
|
||||
"O3": 65536,
|
||||
"O4": 65536,
|
||||
"O5": 65536,
|
||||
"ELO5": 65536,
|
||||
"O6": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Robstride motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
1000000, # 4: 1 mbps (default)
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s
|
||||
|
||||
|
||||
# Data that should be normalized
|
||||
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_CLEAR_FAULT = 0xFB
|
||||
|
||||
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
@@ -139,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
|
||||
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
|
||||
@@ -470,6 +470,13 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names for delta_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ import torch
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
delta_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
delta_step,
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 256
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||
)
|
||||
|
||||
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
continuous_actions = to_absolute_actions(
|
||||
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
|
||||
load_vlm_weights: bool = False # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
|
||||
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
@@ -28,7 +28,14 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
@@ -44,7 +51,6 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -88,7 +94,6 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
@@ -99,6 +104,8 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeltaActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -128,6 +135,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_delta_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,54 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert absolute actions to delta: delta = action - state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] -= state_offset
|
||||
return actions
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert delta actions back to absolute: absolute = delta + state (for masked dims).
|
||||
|
||||
Args:
|
||||
actions: (B, T, action_dim) or (B, action_dim).
|
||||
state: (B, state_dim). Broadcast across time dimension.
|
||||
mask: Which dims to convert. Can be shorter than action_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
|
||||
dims = mask_t.shape[0]
|
||||
state_offset = state[..., :dims] * mask_t
|
||||
if actions.ndim == 3:
|
||||
state_offset = state_offset.unsqueeze(-2)
|
||||
actions = actions.clone()
|
||||
actions[..., :dims] += state_offset
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class DeltaActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
|
||||
trains on relative offsets instead of absolute positions.
|
||||
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
|
||||
the conversion during postprocessing.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the delta conversion.
|
||||
exclude_joints: Joint names to keep absolute (not converted to delta).
|
||||
action_names: Action dimension names from dataset metadata, used to build
|
||||
the mask from exclude_joints. If None, all dims are converted.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
action_names: list[str] | None = None
|
||||
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
self._last_state = state
|
||||
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None or state is None:
|
||||
return new_transition
|
||||
|
||||
mask = self._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
|
||||
|
||||
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
|
||||
predicted deltas are converted back to absolute positions for execution.
|
||||
Reads the cached state from its paired DeltaActionsProcessorStep.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the absolute conversion.
|
||||
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
if self.delta_step is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
|
||||
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
if self.delta_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
)
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
mask = self.delta_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.delta_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -90,13 +89,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,37 +312,6 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -331,11 +331,9 @@ class _NormalizationMixin:
|
||||
)
|
||||
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
return tensor * (std + 1e-6) + mean
|
||||
return (tensor - mean) / (std + 1e-6)
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min", None)
|
||||
@@ -367,11 +365,7 @@ class _NormalizationMixin:
|
||||
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
|
||||
)
|
||||
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
denom = q99 - q01 + 1e-6
|
||||
if inverse:
|
||||
return (tensor + 1.0) * denom / 2.0 + q01
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
@@ -36,7 +36,6 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -380,7 +379,6 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -610,14 +608,7 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -665,7 +656,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -734,8 +725,6 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KochFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class OpenArmFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class SOFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.read_latest()
|
||||
obs[cam_name] = cam.async_read()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
@@ -47,14 +47,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--grpc-port 9876
|
||||
--ws-port 9087
|
||||
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
local$ rerun ws://localhost:9087
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -73,7 +75,6 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -92,11 +93,10 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
grpc_port: int = 9876,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,9 +126,7 @@ def visualize_dataset(
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -228,7 +226,7 @@ def main():
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -240,13 +238,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
help="deprecated, please use --grpc-port instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpc-port",
|
||||
type=int,
|
||||
default=9876,
|
||||
help="gRPC port for rerun.io when `--mode distant` is set.",
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
@@ -272,7 +265,9 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
action="store_true",
|
||||
type=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
@@ -282,14 +277,6 @@ def main():
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
if kwargs["ws_port"] is not None:
|
||||
logging.warning(
|
||||
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
|
||||
@@ -24,107 +24,94 @@ When new_repo_id is specified, creates a new dataset.
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by episode indices:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||
|
||||
Split into more than two splits:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
Show dataset information without feature details:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -197,13 +184,6 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
@@ -456,49 +436,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
total = 0
|
||||
with os.scandir(repo_path) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _get_dataset_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
|
||||
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
|
||||
sys.stdout.write(
|
||||
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(
|
||||
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
|
||||
|
||||
total_file_size = _get_dataset_size(dataset.root)
|
||||
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
|
||||
if cfg.operation.show_features:
|
||||
import json
|
||||
|
||||
feature_dump_str = json.dumps(
|
||||
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
|
||||
)
|
||||
sys.stdout.write("Features:\n")
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -515,8 +452,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
available = ", ".join(OperationConfig.get_known_choices())
|
||||
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||
|
||||
@@ -26,10 +26,8 @@ lerobot-record \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--display_data=true
|
||||
# <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# --dataset.vcodec=h264 \
|
||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||
# --teleop.type=so100_leader \
|
||||
@@ -60,10 +58,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
|
||||
--dataset.num_episodes=25 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -184,19 +179,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
|
||||
# Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -413,14 +398,7 @@ def record_loop(
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
@@ -467,9 +445,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
@@ -492,9 +467,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
@@ -518,11 +490,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
|
||||
@@ -152,7 +152,6 @@ def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
bus.recv(timeout=0.1) # Clear any pending responses
|
||||
except Exception:
|
||||
print(f"Error sending message to motor 0x{motor_id:02X}")
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -24,7 +24,6 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -38,7 +37,6 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.scripts.lerobot_eval import eval_policy_all
|
||||
from lerobot.teleoperators import openarm_mini # noqa: F401
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -53,7 +51,6 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -178,8 +175,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
@@ -214,16 +209,98 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
delta_action_stats = None
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Compute delta action stats BEFORE distributed sync to avoid NCCL timeout
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
|
||||
|
||||
chunk_size = cfg.policy.chunk_size
|
||||
hf = dataset.hf_dataset
|
||||
total_frames = len(hf)
|
||||
sample_upper_bound = total_frames - chunk_size
|
||||
if sample_upper_bound <= 0:
|
||||
raise ValueError(
|
||||
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
|
||||
)
|
||||
|
||||
max_samples = min(100_000, sample_upper_bound)
|
||||
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
|
||||
|
||||
action_names = dataset.meta.features.get("action", {}).get("names")
|
||||
delta_mask_step = DeltaActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
|
||||
action_names=action_names,
|
||||
)
|
||||
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
|
||||
logging.info(
|
||||
f"use_delta_actions is enabled — computing delta action stats "
|
||||
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
||||
)
|
||||
|
||||
all_delta_actions = []
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
ep_idx = episode_indices[idx]
|
||||
end_idx = min(idx + chunk_size, total_frames)
|
||||
if end_idx > idx and episode_indices[end_idx - 1] != ep_idx:
|
||||
continue
|
||||
|
||||
chunk_data = hf[idx:end_idx]
|
||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
|
||||
all_delta_actions.append(delta.numpy())
|
||||
|
||||
if not all_delta_actions:
|
||||
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
|
||||
|
||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
delta_action_stats = delta_stats
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
norm_type = "UNKNOWN"
|
||||
if hasattr(cfg.policy, "normalization_mapping"):
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
||||
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
||||
|
||||
excluded_dims = len(delta_mask) - sum(delta_mask)
|
||||
logging.info(
|
||||
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
||||
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
|
||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
||||
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
||||
)
|
||||
if norm_type == "QUANTILES":
|
||||
q_range = (delta_stats['q99'] - delta_stats['q01']).mean()
|
||||
logging.info(f" Quantile range (q99-q01): {q_range:.4f}")
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Ensure all ranks use the exact same delta action stats.
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
|
||||
stats_list = [delta_action_stats]
|
||||
torch.distributed.broadcast_object_list(stats_list, src=0)
|
||||
delta_action_stats = stats_list[0]
|
||||
if delta_action_stats is not None:
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
@@ -249,10 +326,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_delta_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_delta_actions=true with pretrained processors can skip delta transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -260,7 +349,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -282,7 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
@@ -393,14 +482,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -408,7 +489,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
|
||||
# Debug logging for first few steps and periodically
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None and state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] PRE-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f} | "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
|
||||
)
|
||||
|
||||
batch = preprocessor(batch)
|
||||
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f}"
|
||||
)
|
||||
if state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
|
||||
)
|
||||
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
@@ -425,8 +535,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -520,9 +628,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
@@ -1,296 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
@@ -16,14 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
|
||||
@@ -31,6 +31,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
VALID_VIDEO_CODECS,
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
_encode_video_worker,
|
||||
@@ -44,7 +45,6 @@ from lerobot.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
@@ -393,7 +393,7 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds_mixed = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2, streaming_encoding=False
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||
)
|
||||
ds_mixed.add_frame(
|
||||
{
|
||||
@@ -1450,10 +1450,7 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -1,730 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for streaming video encoding and hardware-accelerated encoding."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
StreamingVideoEncoder,
|
||||
_CameraEncoderThread,
|
||||
_get_codec_options,
|
||||
detect_available_hw_encoders,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# ─── _get_codec_options tests ───
|
||||
|
||||
|
||||
class TestGetCodecOptions:
|
||||
def test_libsvtav1_defaults(self):
|
||||
opts = _get_codec_options("libsvtav1")
|
||||
assert opts["g"] == "2"
|
||||
assert opts["crf"] == "30"
|
||||
assert opts["preset"] == "12"
|
||||
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
opts = _get_codec_options("libsvtav1", preset=8)
|
||||
assert opts["preset"] == "8"
|
||||
|
||||
def test_h264_options(self):
|
||||
opts = _get_codec_options("h264", g=10, crf=23)
|
||||
assert opts["g"] == "10"
|
||||
assert opts["crf"] == "23"
|
||||
assert "preset" not in opts
|
||||
|
||||
def test_videotoolbox_options(self):
|
||||
opts = _get_codec_options("h264_videotoolbox", g=2, crf=30)
|
||||
assert opts["g"] == "2"
|
||||
# CRF 30 maps to quality = max(1, min(100, 100 - 30*2)) = 40
|
||||
assert opts["q:v"] == "40"
|
||||
assert "crf" not in opts
|
||||
|
||||
def test_nvenc_options(self):
|
||||
opts = _get_codec_options("h264_nvenc", g=2, crf=25)
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == "25"
|
||||
assert "crf" not in opts
|
||||
# NVENC doesn't support g
|
||||
assert "g" not in opts
|
||||
|
||||
def test_vaapi_options(self):
|
||||
opts = _get_codec_options("h264_vaapi", crf=28)
|
||||
assert opts["qp"] == "28"
|
||||
|
||||
def test_qsv_options(self):
|
||||
opts = _get_codec_options("h264_qsv", crf=25)
|
||||
assert opts["global_quality"] == "25"
|
||||
|
||||
def test_no_g_no_crf(self):
|
||||
opts = _get_codec_options("h264", g=None, crf=None)
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
|
||||
# ─── HW encoder detection tests ───
|
||||
|
||||
|
||||
class TestHWEncoderDetection:
|
||||
def test_detect_available_hw_encoders_returns_list(self):
|
||||
result = detect_available_hw_encoders()
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_detect_available_hw_encoders_only_valid(self):
|
||||
from lerobot.datasets.video_utils import HW_ENCODERS
|
||||
|
||||
result = detect_available_hw_encoders()
|
||||
for encoder in result:
|
||||
assert encoder in HW_ENCODERS
|
||||
|
||||
def test_resolve_vcodec_passthrough(self):
|
||||
assert resolve_vcodec("libsvtav1") == "libsvtav1"
|
||||
assert resolve_vcodec("h264") == "h264"
|
||||
|
||||
def test_resolve_vcodec_auto_fallback(self):
|
||||
"""When no HW encoders are available, auto should fall back to libsvtav1."""
|
||||
with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]):
|
||||
assert resolve_vcodec("auto") == "libsvtav1"
|
||||
|
||||
def test_resolve_vcodec_auto_picks_hw(self):
|
||||
"""When a HW encoder is available, auto should pick it."""
|
||||
with patch(
|
||||
"lerobot.datasets.video_utils.detect_available_hw_encoders",
|
||||
return_value=["h264_videotoolbox"],
|
||||
):
|
||||
assert resolve_vcodec("auto") == "h264_videotoolbox"
|
||||
|
||||
def test_resolve_vcodec_auto_returns_valid(self):
|
||||
"""Test that resolve_vcodec('auto') returns a known valid codec."""
|
||||
result = resolve_vcodec("auto")
|
||||
assert result in VALID_VIDEO_CODECS
|
||||
|
||||
def test_hw_encoder_names_accepted_in_validation(self):
|
||||
"""Test that HW encoder names pass validation in VALID_VIDEO_CODECS."""
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_resolve_vcodec_invalid_raises(self):
|
||||
"""Test that resolve_vcodec raises ValueError for invalid codecs."""
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
resolve_vcodec("not_a_real_codec")
|
||||
|
||||
|
||||
# ─── _CameraEncoderThread tests ───
|
||||
|
||||
|
||||
class TestCameraEncoderThread:
|
||||
def test_encodes_valid_mp4(self, tmp_path):
|
||||
"""Test that the encoder thread creates a valid MP4 file with correct frame count."""
|
||||
num_frames = 30
|
||||
height, width = 64, 96
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_output" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed frames (HWC uint8)
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
# Send sentinel
|
||||
frame_queue.put(None)
|
||||
encoder_thread.join(timeout=60)
|
||||
assert not encoder_thread.is_alive()
|
||||
|
||||
# Check result
|
||||
status, data = result_queue.get(timeout=5)
|
||||
assert status == "ok"
|
||||
assert data is not None # Stats should be returned
|
||||
assert "mean" in data
|
||||
assert "std" in data
|
||||
assert "min" in data
|
||||
assert "max" in data
|
||||
assert "count" in data
|
||||
|
||||
# Verify the MP4 file is valid
|
||||
assert video_path.exists()
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# The frame count should match
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
def test_handles_chw_input(self, tmp_path):
|
||||
"""Test that CHW format input is handled correctly."""
|
||||
num_frames = 5
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_chw" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed CHW frames
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (3, 64, 96), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
frame_queue.put(None)
|
||||
encoder_thread.join(timeout=60)
|
||||
|
||||
status, _ = result_queue.get(timeout=5)
|
||||
assert status == "ok"
|
||||
assert video_path.exists()
|
||||
|
||||
def test_stop_event_cancellation(self, tmp_path):
|
||||
"""Test that setting the stop event causes the thread to exit."""
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_cancel" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed a few frames
|
||||
for _ in range(3):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
# Signal stop instead of sending sentinel
|
||||
stop_event.set()
|
||||
encoder_thread.join(timeout=10)
|
||||
assert not encoder_thread.is_alive()
|
||||
|
||||
|
||||
# ─── StreamingVideoEncoder tests ───
|
||||
|
||||
|
||||
class TestStreamingVideoEncoder:
|
||||
def test_single_camera_episode(self, tmp_path):
|
||||
"""Test encoding a single camera episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 20
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.laptop", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
assert f"{OBS_IMAGES}.laptop" in results
|
||||
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.laptop"]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
# Verify frame count
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_multi_camera_episode(self, tmp_path):
|
||||
"""Test encoding multiple cameras simultaneously."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 15
|
||||
for _ in range(num_frames):
|
||||
frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(video_keys[0], frame0)
|
||||
encoder.feed_frame(video_keys[1], frame1)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
|
||||
for key in video_keys:
|
||||
assert key in results
|
||||
mp4_path, stats = results[key]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_sequential_episodes(self, tmp_path):
|
||||
"""Test that multiple sequential episodes work correctly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
|
||||
for ep in range(3):
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
num_frames = 10 + ep * 5
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
results = encoder.finish_episode()
|
||||
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_cancel_episode(self, tmp_path):
|
||||
"""Test that canceling an episode cleans up properly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
for _ in range(5):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
encoder.cancel_episode()
|
||||
|
||||
# Should be able to start a new episode after cancel
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
for _ in range(5):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
results = encoder.finish_episode()
|
||||
|
||||
assert f"{OBS_IMAGES}.cam" in results
|
||||
encoder.close()
|
||||
|
||||
def test_feed_without_start_raises(self, tmp_path):
|
||||
"""Test that feeding frames without starting an episode raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8))
|
||||
encoder.close()
|
||||
|
||||
def test_finish_without_start_raises(self, tmp_path):
|
||||
"""Test that finishing without starting raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.finish_episode()
|
||||
encoder.close()
|
||||
|
||||
def test_close_is_idempotent(self, tmp_path):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder.close()
|
||||
encoder.close() # Should not raise
|
||||
|
||||
def test_video_duration_matches_frame_count(self, tmp_path):
|
||||
"""Test that encoded video duration matches num_frames / fps."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 90 # 3 seconds at 30fps
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
mp4_path, _ = results[f"{OBS_IMAGES}.cam"]
|
||||
|
||||
expected_duration = num_frames / 30.0 # 3.0 seconds
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
if stream.duration is not None:
|
||||
actual_duration = float(stream.duration * stream.time_base)
|
||||
else:
|
||||
actual_duration = float(container.duration / av.time_base)
|
||||
|
||||
assert total_frames == num_frames
|
||||
# Allow small tolerance for duration due to codec framing
|
||||
assert abs(actual_duration - expected_duration) < 0.5, (
|
||||
f"Video duration {actual_duration:.2f}s != expected {expected_duration:.2f}s"
|
||||
)
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_multi_camera_start_episode_called_once(self, tmp_path):
|
||||
"""Test that with multiple cameras, no frames are lost due to double start_episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 30
|
||||
for _ in range(num_frames):
|
||||
frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(video_keys[0], frame0)
|
||||
encoder.feed_frame(video_keys[1], frame1)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
|
||||
# Both cameras should have all frames
|
||||
for key in video_keys:
|
||||
mp4_path, stats = results[key]
|
||||
assert mp4_path.exists()
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames, (
|
||||
f"Camera {key}: expected {num_frames} frames, got {total_frames}"
|
||||
)
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_encoder_threads_passed_to_thread(self, tmp_path):
|
||||
"""Test that encoder_threads is stored and passed through to encoder threads."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2
|
||||
)
|
||||
assert encoder.encoder_threads == 2
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Verify the thread received the encoder_threads value
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert thread.encoder_threads == 2
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_encoder_threads_none_by_default(self, tmp_path):
|
||||
"""Test that encoder_threads defaults to None (codec auto-detect)."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
assert encoder.encoder_threads is None
|
||||
encoder.close()
|
||||
|
||||
def test_graceful_frame_dropping(self, tmp_path):
|
||||
"""Test that full queue drops frames instead of crashing."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1
|
||||
)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Feed many frames quickly - with queue_maxsize=1, some will be dropped
|
||||
num_frames = 50
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
# Should not raise - frames are dropped gracefully
|
||||
results = encoder.finish_episode()
|
||||
assert f"{OBS_IMAGES}.cam" in results
|
||||
|
||||
mp4_path, _ = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
|
||||
# Some frames should have been dropped (queue was tiny)
|
||||
dropped = encoder._dropped_frames.get(f"{OBS_IMAGES}.cam", 0)
|
||||
# We can't guarantee drops but can verify no crash occurred
|
||||
assert dropped >= 0
|
||||
|
||||
encoder.close()
|
||||
|
||||
|
||||
# ─── Integration tests with LeRobotDataset ───
|
||||
|
||||
|
||||
class TestStreamingEncoderIntegration:
|
||||
def test_add_frame_save_episode_streaming(self, tmp_path):
|
||||
"""Full integration test: add_frame -> save_episode with streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/streaming",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "streaming_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is not None
|
||||
|
||||
num_frames = 20
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Verify dataset metadata
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == num_frames
|
||||
|
||||
# Verify stats exist for the video key
|
||||
assert dataset.meta.stats is not None
|
||||
assert "observation.images.cam" in dataset.meta.stats
|
||||
assert "action" in dataset.meta.stats
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_streaming_disabled_creates_pngs(self, tmp_path):
|
||||
"""Test that disabling streaming encoding falls back to PNG path."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/no_streaming",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "no_streaming_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=False,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is None
|
||||
|
||||
num_frames = 5
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# With streaming disabled, PNG files should be written
|
||||
images_dir = dataset.root / "images"
|
||||
assert images_dir.exists()
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
def test_multi_episode_streaming(self, tmp_path):
|
||||
"""Test recording multiple episodes with streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/multi_ep",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "multi_ep_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
for ep in range(3):
|
||||
num_frames = 10 + ep * 5
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": f"task_{ep}",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 3
|
||||
assert dataset.meta.total_frames == 10 + 15 + 20
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_clear_episode_buffer_cancels_streaming(self, tmp_path):
|
||||
"""Test that clearing episode buffer cancels streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/cancel",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "cancel_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
# Add some frames
|
||||
for _ in range(5):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Cancel and re-record
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# Record a new episode
|
||||
for _ in range(10):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == 10
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_multi_camera_add_frame_streaming(self, tmp_path):
|
||||
"""Test that start_episode is called once with multiple video keys."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam1": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.cam2": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/multi_cam",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "multi_cam_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
num_frames = 15
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam1": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"observation.images.cam2": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == num_frames
|
||||
|
||||
dataset.finalize()
|
||||
@@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.optim.schedulers import (
|
||||
@@ -40,10 +38,6 @@ def test_diffuser_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -62,10 +56,6 @@ def test_vqbet_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -86,10 +76,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
"""Tests for delta action transforms — full pipeline validation.
|
||||
|
||||
Tests the complete flow matching OpenPI:
|
||||
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
|
||||
|
||||
Uses real dataset: lerobot-data-collection/dagger_final_1_21
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dataset():
|
||||
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def action_dim(dataset):
|
||||
return dataset.meta.features["action"]["shape"][0]
|
||||
|
||||
|
||||
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||
"""Build action chunks from hf_dataset, like the training script does."""
|
||||
hf = dataset.hf_dataset
|
||||
total = len(hf)
|
||||
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||
chunks, states = [], []
|
||||
for i in range(total - chunk_size + 1):
|
||||
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||
continue
|
||||
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||
state = hf[i]["observation.state"].float()
|
||||
chunks.append(chunk_actions)
|
||||
states.append(state)
|
||||
if len(chunks) >= max_chunks:
|
||||
break
|
||||
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||
return torch.stack(chunks), torch.stack(states)
|
||||
|
||||
|
||||
def _compute_delta_chunk_stats(action_chunks, states, mask):
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta.numpy())
|
||||
all_delta = np.concatenate(all_deltas, axis=0)
|
||||
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
|
||||
|
||||
# --- Basic roundtrip tests ---
|
||||
|
||||
def test_roundtrip_3d(action_dim):
|
||||
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_roundtrip_2d(action_dim):
|
||||
actions = torch.randn(4, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_no_mutation(action_dim):
|
||||
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||
original = actions.clone()
|
||||
state = torch.randn(2, action_dim)
|
||||
to_delta_actions(actions, state, [True] * action_dim)
|
||||
torch.testing.assert_close(actions, original)
|
||||
|
||||
|
||||
def test_exclude_joints_supports_partial_name_matching():
|
||||
names = [
|
||||
"right_joint_1.pos",
|
||||
"right_gripper.pos",
|
||||
"left_joint_1.pos",
|
||||
"left_gripper.pos",
|
||||
]
|
||||
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||
assert step._build_mask(len(names)) == [True, False, True, False]
|
||||
|
||||
|
||||
# --- Chunk-level delta stats test ---
|
||||
|
||||
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Per-frame stats
|
||||
hf = dataset.hf_dataset
|
||||
n = min(500, len(hf))
|
||||
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
|
||||
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
|
||||
|
||||
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||
)
|
||||
|
||||
|
||||
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
|
||||
|
||||
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized_action = t2[TransitionKey.ACTION]
|
||||
assert normalized_action.abs().mean() < 10, (
|
||||
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered_actions = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
mean = torch.tensor(delta_stats["mean"]).float()
|
||||
std = torch.tensor(delta_stats["std"]).float()
|
||||
|
||||
all_normalized = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
normalized = (delta - mean) / (std + 1e-6)
|
||||
all_normalized.append(normalized)
|
||||
|
||||
all_normalized = torch.cat(all_normalized, dim=0)
|
||||
|
||||
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||
assert pct_in_range > 0.9, (
|
||||
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||
)
|
||||
|
||||
assert all_normalized.mean().abs() < 1.0, (
|
||||
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
|
||||
)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset, action_dim):
|
||||
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_actions = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
delta_transition = step(transition)
|
||||
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
|
||||
|
||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
|
||||
torch.testing.assert_close(recovered, original_actions)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||
"""enabled=False should be a no-op."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||
}
|
||||
original = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||
|
||||
|
||||
# --- Training batch shape validation ---
|
||||
|
||||
def test_delta_with_action_chunks(dataset, action_dim):
|
||||
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||
batch_states = states[:4] # (4, state_dim)
|
||||
|
||||
mask = [True] * action_dim
|
||||
delta = to_delta_actions(batch_actions, batch_states, mask)
|
||||
|
||||
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||
first_deltas = delta[:, 0, :] # (B, action_dim)
|
||||
assert first_deltas.abs().mean() < delta.abs().mean(), (
|
||||
f"First action in chunk should have smaller delta than average. "
|
||||
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Later actions should have larger deltas
|
||||
last_deltas = delta[:, -1, :] # (B, action_dim)
|
||||
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
|
||||
f"Last action in chunk should have >= delta than first. "
|
||||
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Roundtrip
|
||||
recovered = to_absolute_actions(delta, batch_states, mask)
|
||||
torch.testing.assert_close(recovered, batch_actions)
|
||||
|
||||
|
||||
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
|
||||
"""Verify computed stats match the actual delta distribution."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
# Compute stats like the training script does
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Also compute directly
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta)
|
||||
all_deltas_tensor = torch.cat(all_deltas, dim=0)
|
||||
|
||||
# Compare mean
|
||||
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
|
||||
|
||||
# Compare std
|
||||
actual_std = all_deltas_tensor.std(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
|
||||
|
||||
# Verify q01 < mean < q99
|
||||
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
|
||||
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
|
||||
|
||||
|
||||
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → quantile normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized = t2[TransitionKey.ACTION]
|
||||
# Most values should be in [-1, 1] with quantile normalization
|
||||
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||
assert pct_in_range > 0.5, (
|
||||
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_state_not_modified_by_delta(dataset, action_dim):
|
||||
"""State should never be modified by the delta processor."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_state = batch[OBS_STATE].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
result = step(transition)
|
||||
|
||||
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
torch.testing.assert_close(result_state, original_state)
|
||||
@@ -142,7 +142,6 @@ def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
ConvertImageToVideoConfig,
|
||||
DeleteEpisodesConfig,
|
||||
EditDatasetConfig,
|
||||
InfoConfig,
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
@@ -47,7 +46,6 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||
@@ -65,7 +63,6 @@ class TestOperationTypeParsing:
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||
|
||||
Reference in New Issue
Block a user