Compare commits

...

30 Commits

Author SHA1 Message Date
pepijn b2d3186011 Add chained SLURM mirror-and-double dataset script.
Provide a standalone DataTrove workflow that mirrors bimanual shards, aggregates mirrored output, builds a doubled dataset, and optionally pushes the final dataset to the Hub.

Made-with: Cursor
2026-02-27 11:13:17 +00:00
Steven Palma 5865170d36 chore(deps): bump ceil datasets (#2946) 2026-02-20 17:01:46 +01:00
Khalil 2dd366436e Fix gym-hil integration with the new LeRobot pipeline. (#2482)
* Add GymHILAdapterProcessorStep for gym-hil environment integration

* Fix action features in control loop for None teleop device with gym-hil

* Finalize dataset before pushing to hub for visualization on the hub

* Fix neutral action for gripper

* fix pre-commit
2026-02-19 14:35:02 +01:00
Steven Palma 5f15232271 chore: remove usernames + use entrypoints in docs, comments & sample commands (#2988) 2026-02-18 22:46:12 +01:00
Steven Palma bc38261321 feat(robots): use read_latest() camera (#2987)
* feat(robots): use read_latest() camera

* fix(test): add read_latest reachy cam mock
2026-02-18 20:05:15 +01:00
Caroline Pascal aaf3707058 fix(filtering): fixing episodes filtering in load_nested_dataset to always use .from_parquet() (#2982) 2026-02-18 19:16:53 +01:00
Steven Palma 89bd58a9a2 chore(scripts): warn if we don't respect the target FPS (#2986) 2026-02-18 18:22:35 +01:00
Steven Palma b22e0315b0 fix(utils): more conservative sleep_margin default value in precise_sleep (#2985) 2026-02-18 17:32:25 +01:00
HUANG TZU-CHUN fcbf550952 fix(docs): update environment variable name to HF_LEROBOT_HOME in docstring (#2973)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-18 11:27:40 +01:00
Sota Nakamura af036ce57e fix(scripts): serve grpc for a web viewer (#2881)
* serve grpc for a web viewer

* add help

* remove ip detection

* fix comment

* pass grpc_port

* fix(CLI): fixing CLI display-compressed-images argument 1/2

Co-authored-by: HUANG TZU-CHUN <tzu.chun.huang.tw@gmail.com>
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

* fix(CLI): fixing CLI display-compressed-images argument 2/2

Co-authored-by: HUANG TZU-CHUN <tzu.chun.huang.tw@gmail.com>
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Co-authored-by: HUANG TZU-CHUN <tzu.chun.huang.tw@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-18 01:05:51 +01:00
Vladislav Sovrasov 1c388c0002 (Chore) Bump upper bound for torch version (#2897)
* Bump upper torch version bound

* Apply suggestion from @Copilot

Signed-off-by: Vladislav Sovrasov <vladislav.sovrasov@intel.com>

* Update ref state dicts for schedulers

* Support older than 2.8 torch versions

* Fix precommit

---------

Signed-off-by: Vladislav Sovrasov <vladislav.sovrasov@intel.com>
2026-02-17 23:37:46 +01:00
masato-ka 51d3822d75 feat(datasets): Add info operation to lerobot-edit-dataset command (#2917)
* Add New featrue to lerobot_edit_datset.py that show dataset information.

* Fix to draccus error when happen give only --operation.type=info

* Updating test and documents regarding lerobot-edit-dataset info function.

* Updating documents regarding lerobot-edit-dataset extract function. option name in document is mistake.

* feat(datasets): Update to align formatting with pre-commit.(#2917)

Update to align formatting by pre-commit.

---------

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-02-17 20:09:42 +01:00
Pepijn 6600b60e7f always use degrees (#2968) 2026-02-13 13:49:01 +01:00
Caroline Pascal adebbcf090 fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949)
* fix(edit dataset operation): fixing dataset tools CLI operation type specification

* test(edit dataset operation): adding tests for dataset tools operation type specification

* chore(format): running pre-commit

* chore(backward compatibility): adding a type property in OperationConfig for backward compatibility

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-02-12 18:56:04 +01:00
taken-yjyoon 3615160d89 fix(typo): Fixing wrong argparse examples in the comments (using 'True' not 'true') (#1040)
Co-authored-by: juni <>
2026-02-12 18:13:51 +01:00
Steven Palma fc8a388a25 feat(cameras): make backend configurable to the CLI (#2945)
* feat(cameras): make backend configurable to the CLI

* chore(cameras): address feedback

* feat(Enum error messages): adding better instanciation error messages for Enum classes

* chore(Enum error messages): propagating Enum error messages to all camera classes

* chore(comments): removing superfluous comments

* chore(format): applying ruff checks

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-11 13:57:25 +01:00
Steven Palma 3c84d271d5 fix(motors): use decorator to fix precommit (#2951) 2026-02-10 18:40:50 +01:00
Steven Palma 1ba3975020 chore: use is_connected decorators (#2948)
* chore: use is_connected decorators

* chore(robots): add is_connected to bi setups too
2026-02-10 17:49:30 +01:00
Steven Palma 35363c5798 chore(linter): ensure motors module passes MyPy type checks (#2939)
* fix: ensure motors module passes MyPy type checks

This commit fixes 62 mypy type errors in the motors module by:

- Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead,
  GroupSyncWrite) to use class-level attribute declarations instead of
  __init__ body declarations
- Adding missing `broadcastPing` method to PacketHandler Protocol
- Fixing return type annotations (e.g., `_get_motor_model` returns str, not int)
- Fixing parameter types to use `Sequence` for covariant list parameters
- Fixing `Mapping` for covariant dict value types in `_normalize`
- Updating method signatures to be consistent across parent and child classes
  (disable_torque, enable_torque, _get_half_turn_homings)
- Adding explicit `int()` casts for MotorCalibration arguments
- Adding explicit `return None` for functions returning Optional types
- Adding type annotations for variables like `data_list: dict[int, int]`
- Using `# type: ignore[method-assign]` for intentional monkeypatch
- Fixing variable references (using `self.groups` instead of `groups`)

Fixes #1723

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* chore(style): pre-commit after main merge

* chore(linter): solve comments

* chore(linter): apply pre-commit fixes to damiao

* chore(linter): more fixes to damiao

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-10 17:35:39 +01:00
whats2000 778db19a17 [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911)
* fix(ci): prevent runner group error on fork pushes

Add repository check to unbound_deps_tests workflow to ensure
aws-general-8-plus runner group is only used on main repository,
preventing 'Required runner group not found' errors on forks.

* fix(ci): use gating job to prevent runner allocation on forks

The previous approach failed because GitHub evaluates runs-on before if conditions.
Now using a check-repo job that runs on ubuntu-latest first, and all jobs with
special runners depend on it and check its output before being scheduled.

* fix(ci): add gating job to full_tests to prevent runner allocation on forks

Apply the same gating pattern used in unbound_deps_tests to full_tests.yml
to prevent GitHub from trying to allocate custom runners when workflows
run on forks. The check-repo job runs first on ubuntu-latest and all jobs
with custom runners depend on it and check its output.

* fix(ci): add repository check to unbound_deps_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml

* fix(ci): add repository check to full_tests workflow

Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks

* refactor(ci): remove redundant check from gpu-tests job

gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped

* refactor(ci): remove unnecessary fork check from full-tests job

full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:21:40 +01:00
Jai Kumaar Ratadia d2d01399d6 docs: clarify installation steps are sequential, not optional (#2925)
* docs: clarify installation steps are sequential, not optional

Add intro paragraph noting conda is one path (not the only one) and
number the three sections as steps so readers understand miniforge and
environment setup are prerequisites, not independent choices.

* Update installation guide link for LeRobot

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

* Fix link formatting in installation guide again

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

---------

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:18:32 +01:00
Aoqun Jin 5eba4ce6f4 Change LIBERO init_state_id when reset. (#2899)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 16:39:17 +03:00
Stepan Feduniak cca0296cd6 fix(pipeline): use FeatureType for STATE features in Libero processor (#2888)
* fix the types

* pre-commit

---------

Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-10 15:55:11 +03:00
Steven Palma 489cb7b6b9 fix(scripts): correct can import check (#2937) 2026-02-09 16:58:32 +01:00
Reece O'Mahoney e14bdf57d0 Convert tensors to scalars (#2903)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-09 14:46:12 +01:00
Reece O'Mahoney 97e7e0f9ed feat(datasets): improve image transform support (#2885)
* improve image transform support

* add tests

* Add stricter transform check and extra test

* improve subclass check
2026-02-05 15:39:58 +01:00
jwang078 0f39248445 Small docstring fix in diffusion configuration (#2847) 2026-02-03 19:19:00 +01:00
Iori Yanokura a6370dd783 fix(wandb): truncate init tags to 64-character limit (#995) 2026-02-03 14:17:04 +01:00
Michel Aractingi 14a15f90e7 Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873)
* fix(RL) add missing config arguments

* respond to copilot review

* fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic.

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2026-02-02 22:14:03 +01:00
Hirokazu Ishida 9c24a09665 docs: update document in response to Simplify configs PR (#1596)
* docs: update document input/output_shapes -> input/output_features

* fix inconsistent quote (suggested by copilot reviewer)

* docs: shapes => PolicyFeature

* docs: relfect normalization_mapping and remove outdated
2026-02-02 20:05:58 +01:00
81 changed files with 1523 additions and 532 deletions
+5 -3
View File
@@ -101,9 +101,11 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:
+1
View File
@@ -91,6 +91,7 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:
+42 -42
View File
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
aliberts/aloha_mobile_shrimp_image \
lerobot/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
+1 -1
View File
@@ -185,7 +185,7 @@ echo $HF_USER
Use the standard recording command:
```bash
python src/lerobot/scripts/lerobot_record.py \
lerobot-record \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \
+5 -5
View File
@@ -224,7 +224,7 @@ lerobot-record \
--teleop.port=/dev/tty.usbmodem1201 \
--teleop.id=right \
--teleop.side=right \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.single_task="Hand recording test with video data" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
@@ -241,7 +241,7 @@ lerobot-replay \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--dataset.repo_id=nepyope/hand_record_test_with_camera \
--dataset.repo_id=<USER>/hand_record_test_with_camera \
--dataset.episode=0
```
@@ -249,13 +249,13 @@ lerobot-replay \
```bash
lerobot-train \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
--job_name=hopejr \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=nepyope/hand_test_policy
--policy.repo_id=<USER>/hand_test_policy
```
### Evaluate
@@ -270,7 +270,7 @@ lerobot-record \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--display_data=false \
--dataset.repo_id=nepyope/eval_hopejr \
--dataset.repo_id=<USER>/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
+5 -3
View File
@@ -1,13 +1,15 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Install LeRobot 🤗
## Step 3: Install LeRobot 🤗
### From Source
+1 -1
View File
@@ -60,7 +60,7 @@ policy.type=pi0
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
+1 -1
View File
@@ -56,7 +56,7 @@ policy.type=pi05
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
python src/lerobot/scripts/lerobot_train.py\
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
+4 -4
View File
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
Train with **no annotations** - uses linear progress from 0 to 1:
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=single_stage \
@@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \
Train with **dense annotations only** (sparse auto-generated):
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dense_only \
@@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \
Train with **both sparse and dense annotations**:
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dual \
@@ -468,7 +468,7 @@ This script:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
+2 -2
View File
@@ -216,7 +216,7 @@ lerobot-teleoperate \
### Record Dataset in Simulation
```bash
python -m lerobot.scripts.lerobot_record \
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -266,7 +266,7 @@ lerobot-teleoperate \
### Record Dataset on Real Robot
```bash
python -m lerobot.scripts.lerobot_record \
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
+25
View File
@@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets:
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -156,6 +157,30 @@ lerobot-edit-dataset \
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Show the information of datasets
Show the information of datasets such as number of episode, number of frame, File size and so on.
No change will be made to the dataset
```bash
# Show dataset information without feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
# Show dataset information with feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
```
**Parameters:**
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
### Push to Hub
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
+1 -1
View File
@@ -45,7 +45,7 @@ policy.type=wall_x
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
```bash
python src/lerobot/scripts/lerobot_train.py \
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=wall_x \
--output_dir=./outputs/wallx_training \
+1 -1
View File
@@ -154,7 +154,7 @@ lerobot-train \
```bash
lerobot-train \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
+1 -1
View File
@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
--dataset.repo_id=<USER>/record-test \
--dataset.episode=2
```
"""
@@ -0,0 +1,726 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mirror a bimanual dataset in parallel with DataTrove + SLURM, then double it.
Workflow:
1) Split source episodes across `num_shards` ranks and mirror each shard in parallel.
2) Aggregate mirrored shards into one mirrored dataset.
3) Aggregate [original, mirrored] into a final doubled dataset.
Example:
python examples/port_datasets/slurm_mirror_dataset.py \
--repo-id=pepijn/openarm_bimanual \
--output-repo-id=pepijn/openarm_bimanual_doubled \
--partition=hopper-cpu \
--num-shards=256 \
--workers=64 \
--cpus-per-task=8 \
--mem-per-cpu=4G
"""
import argparse
import copy
import logging
import shutil
from pathlib import Path
from typing import Any
import numpy as np
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
logger = logging.getLogger(__name__)
OPENARM_MIRRORING_MASK = {
"joint_1": -1,
"joint_2": -1,
"joint_3": -1,
"joint_4": 1,
"joint_5": -1,
"joint_6": -1,
"joint_7": -1,
"gripper": 1,
}
def get_mirroring_mask(robot_type: str | None) -> dict[str, int]:
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
return OPENARM_MIRRORING_MASK
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
def swap_left_right_name(name: str) -> str:
value = name.replace("left_", "LEFT_PLACEHOLDER_")
value = value.replace("right_", "left_")
value = value.replace("LEFT_PLACEHOLDER_", "right_")
return value
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
mirrored_names = [swap_left_right_name(n) for n in names]
old_to_new_idx = {}
for old_idx, old_name in enumerate(names):
new_name = swap_left_right_name(old_name)
new_idx = mirrored_names.index(new_name)
old_to_new_idx[old_idx] = new_idx
return mirrored_names, old_to_new_idx
def _get_axis_names(feature: dict[str, Any]) -> list[str] | None:
names = feature.get("names")
if isinstance(names, list):
return names
if isinstance(names, dict):
axes = names.get("axes")
if isinstance(axes, list):
return axes
return None
def _to_numpy(value: Any) -> Any:
if isinstance(value, np.ndarray):
return value
if hasattr(value, "detach"):
return value.detach().cpu().numpy()
if hasattr(value, "cpu") and hasattr(value, "numpy"):
return value.cpu().numpy()
if hasattr(value, "numpy"):
return value.numpy()
return value
def apply_mirroring_mask(value: float, axis_name: str, mirroring_mask: dict[str, int]) -> float:
if axis_name.startswith("left_") or axis_name.startswith("right_"):
axis_name = axis_name.split("_", 1)[1]
joint_name = axis_name.split(".")[0]
return value * mirroring_mask.get(joint_name, 1)
def mirror_vector_feature(
value: Any,
feature: dict[str, Any],
mirroring_mask: dict[str, int],
) -> Any:
array = _to_numpy(value)
if not isinstance(array, np.ndarray) or array.ndim != 1:
return array
names = _get_axis_names(feature)
if names is None or len(names) != len(array):
return array
mirrored_names, index_mapping = mirror_feature_names(names)
mirrored = np.zeros_like(array)
for old_idx, new_idx in index_mapping.items():
mirrored[new_idx] = apply_mirroring_mask(array[old_idx], mirrored_names[new_idx], mirroring_mask)
return mirrored
def flip_horizontal(value: Any, expected_shape: list[int] | tuple[int, ...]) -> Any:
array = _to_numpy(value)
if not isinstance(array, np.ndarray) or array.ndim != 3:
return array
expected_shape = tuple(expected_shape)
if array.shape == expected_shape:
return np.flip(array, axis=1).copy() # HWC
if len(expected_shape) == 3:
c, h, w = expected_shape
if array.shape == (c, h, w):
return np.flip(array, axis=2).copy() # CHW
# Conservative fallback for unexpected layouts.
return np.flip(array, axis=-1).copy()
def build_mirrored_features(features: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
mirrored = {}
for key, feature in features.items():
new_key = swap_left_right_name(key)
new_feature = copy.deepcopy(feature)
names = new_feature.get("names")
if isinstance(names, list):
new_feature["names"] = [swap_left_right_name(name) for name in names]
elif isinstance(names, dict) and isinstance(names.get("axes"), list):
new_feature["names"]["axes"] = [swap_left_right_name(name) for name in names["axes"]]
mirrored[new_key] = new_feature
return mirrored
def build_mirrored_frame(
item: dict[str, Any],
source_features: dict[str, dict[str, Any]],
mirroring_mask: dict[str, int],
) -> dict[str, Any]:
frame = {}
for key, feature in source_features.items():
if key in DEFAULT_FEATURES:
continue
value = item[key]
if key in {"action", "observation.state"}:
value = mirror_vector_feature(value, feature, mirroring_mask)
elif feature["dtype"] in {"video", "image"}:
value = flip_horizontal(value, feature["shape"])
else:
value = _to_numpy(value)
frame[swap_left_right_name(key)] = value
frame["task"] = item["task"]
if "timestamp" in item:
ts = _to_numpy(item["timestamp"])
frame["timestamp"] = float(ts.item() if hasattr(ts, "item") else ts)
return frame
def _resolve_source_root(repo_id: str, root: Path | None) -> Path:
source_meta = LeRobotDatasetMetadata(repo_id=repo_id, root=root)
return source_meta.root
def _get_work_dir(output_repo_id: str, work_dir: Path | None) -> Path:
if work_dir is not None:
return work_dir
safe_name = output_repo_id.replace("/", "__")
return HF_LEROBOT_HOME / "_mirror_work" / safe_name
def _get_shard_root(work_dir: Path, world_size: int, rank: int) -> Path:
return work_dir / "mirrored_shards" / f"world_{world_size}_rank_{rank}"
def _is_valid_dataset_root(root: Path) -> bool:
return (root / "meta" / "info.json").exists()
def mirror_shard(
repo_id: str,
source_root: Path,
mirrored_repo_id: str,
shard_root: Path,
rank: int,
world_size: int,
vcodec: str,
overwrite: bool,
) -> None:
source_dataset = LeRobotDataset(repo_id=repo_id, root=source_root)
selected_episodes = list(range(rank, source_dataset.meta.total_episodes, world_size))
if len(selected_episodes) == 0:
logger.info("Rank %s has no episodes assigned. Skipping.", rank)
return
if shard_root.exists():
if overwrite:
shutil.rmtree(shard_root)
elif _is_valid_dataset_root(shard_root):
logger.info("Rank %s shard already exists at %s. Skipping.", rank, shard_root)
return
else:
raise RuntimeError(
f"Shard root {shard_root} exists but is not a valid dataset. Use --overwrite to recreate."
)
mirroring_mask = get_mirroring_mask(source_dataset.meta.robot_type)
mirrored_features = build_mirrored_features(source_dataset.meta.features)
shard_repo_name = f"{mirrored_repo_id}_world_{world_size}_rank_{rank}"
mirrored_dataset = LeRobotDataset.create(
repo_id=shard_repo_name,
root=shard_root,
fps=source_dataset.meta.fps,
features=mirrored_features,
robot_type=source_dataset.meta.robot_type,
use_videos=len(source_dataset.meta.video_keys) > 0,
vcodec=vcodec,
)
mirrored_dataset.meta.update_chunk_settings(
chunks_size=source_dataset.meta.chunks_size,
data_files_size_in_mb=source_dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=source_dataset.meta.video_files_size_in_mb,
)
logger.info(
"Rank %s processing %s episodes into shard %s",
rank,
len(selected_episodes),
shard_root,
)
for source_ep_idx in selected_episodes:
episode = source_dataset.meta.episodes[source_ep_idx]
start_idx = int(episode["dataset_from_index"])
end_idx = int(episode["dataset_to_index"])
for frame_idx in range(start_idx, end_idx):
item = source_dataset[frame_idx]
mirrored_frame = build_mirrored_frame(
item=item,
source_features=source_dataset.meta.features,
mirroring_mask=mirroring_mask,
)
mirrored_dataset.add_frame(mirrored_frame)
mirrored_dataset.save_episode()
mirrored_dataset.finalize()
class MirrorDatasetShards(PipelineStep):
def __init__(
self,
repo_id: str,
source_root: Path,
mirrored_repo_id: str,
work_dir: Path,
vcodec: str,
overwrite: bool,
):
super().__init__()
self.repo_id = repo_id
self.source_root = source_root
self.mirrored_repo_id = mirrored_repo_id
self.work_dir = work_dir
self.vcodec = vcodec
self.overwrite = overwrite
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
shard_root = _get_shard_root(self.work_dir, world_size, rank)
mirror_shard(
repo_id=self.repo_id,
source_root=self.source_root,
mirrored_repo_id=self.mirrored_repo_id,
shard_root=shard_root,
rank=rank,
world_size=world_size,
vcodec=self.vcodec,
overwrite=self.overwrite,
)
def make_mirror_executor(
repo_id: str,
source_root: Path,
mirrored_repo_id: str,
work_dir: Path,
logs_dir: Path,
job_name: str,
num_shards: int,
workers: int,
partition: str,
cpus_per_task: int,
mem_per_cpu: str,
time_limit: str,
vcodec: str,
overwrite: bool,
slurm: bool,
):
kwargs = {
"pipeline": [
MirrorDatasetShards(
repo_id=repo_id,
source_root=source_root,
mirrored_repo_id=mirrored_repo_id,
work_dir=work_dir,
vcodec=vcodec,
overwrite=overwrite,
),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
if partition is None:
raise ValueError("`--partition` is required when `--slurm 1`.")
kwargs.update(
{
"job_name": job_name,
"tasks": num_shards,
"workers": workers,
"time": time_limit,
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": num_shards, "workers": 1})
return LocalPipelineExecutor(**kwargs)
class AggregateMirroredShardsStep(PipelineStep):
def __init__(
self,
mirrored_repo_id: str,
mirrored_root: Path,
work_dir: Path,
num_shards: int,
overwrite: bool,
):
super().__init__()
self.mirrored_repo_id = mirrored_repo_id
self.mirrored_root = mirrored_root
self.work_dir = work_dir
self.num_shards = num_shards
self.overwrite = overwrite
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
if rank != 0:
logger.info("Skipping rank %s for aggregate mirrored step", rank)
return
aggregate_mirrored_shards(
mirrored_repo_id=self.mirrored_repo_id,
mirrored_root=self.mirrored_root,
work_dir=self.work_dir,
num_shards=self.num_shards,
overwrite=self.overwrite,
)
class BuildDoubledDatasetStep(PipelineStep):
def __init__(
self,
source_repo_id: str,
source_root: Path,
mirrored_repo_id: str,
mirrored_root: Path,
output_repo_id: str,
output_root: Path,
overwrite: bool,
):
super().__init__()
self.source_repo_id = source_repo_id
self.source_root = source_root
self.mirrored_repo_id = mirrored_repo_id
self.mirrored_root = mirrored_root
self.output_repo_id = output_repo_id
self.output_root = output_root
self.overwrite = overwrite
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
if rank != 0:
logger.info("Skipping rank %s for build doubled step", rank)
return
build_doubled_dataset(
source_repo_id=self.source_repo_id,
source_root=self.source_root,
mirrored_repo_id=self.mirrored_repo_id,
mirrored_root=self.mirrored_root,
output_repo_id=self.output_repo_id,
output_root=self.output_root,
overwrite=self.overwrite,
)
class PushDoubledDatasetStep(PipelineStep):
def __init__(
self,
output_repo_id: str,
output_root: Path,
):
super().__init__()
self.output_repo_id = output_repo_id
self.output_root = output_root
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
if rank != 0:
logger.info("Skipping rank %s for push step", rank)
return
logger.info("Pushing doubled dataset to hub: %s", self.output_repo_id)
LeRobotDataset(self.output_repo_id, root=self.output_root).push_to_hub()
def make_single_task_executor(
step: PipelineStep,
logs_dir: Path,
job_name: str,
partition: str | None,
cpus_per_task: int,
mem_per_cpu: str,
time_limit: str,
slurm: bool,
depends: SlurmPipelineExecutor | None = None,
):
kwargs = {"pipeline": [step], "logging_dir": str(logs_dir / job_name)}
if slurm:
if partition is None:
raise ValueError("`--partition` is required when `--slurm 1`.")
kwargs.update(
{
"job_name": job_name,
"tasks": 1,
"workers": 1,
"time": time_limit,
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
"depends": depends,
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": 1, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def aggregate_mirrored_shards(
mirrored_repo_id: str,
mirrored_root: Path,
work_dir: Path,
num_shards: int,
overwrite: bool,
):
if mirrored_root.exists():
if overwrite:
shutil.rmtree(mirrored_root)
elif _is_valid_dataset_root(mirrored_root):
logger.info("Mirrored dataset already exists at %s. Skipping aggregation.", mirrored_root)
return
else:
raise RuntimeError(
f"Mirrored root {mirrored_root} exists but is not a valid dataset. Use --overwrite to recreate."
)
shard_repo_ids = []
shard_roots = []
for rank in range(num_shards):
shard_root = _get_shard_root(work_dir, num_shards, rank)
if _is_valid_dataset_root(shard_root):
shard_repo_ids.append(f"{mirrored_repo_id}_world_{num_shards}_rank_{rank}")
shard_roots.append(shard_root)
if len(shard_repo_ids) == 0:
raise RuntimeError("No mirrored shards were produced. Nothing to aggregate.")
logger.info("Aggregating %s mirrored shards into %s", len(shard_repo_ids), mirrored_root)
aggregate_datasets(
repo_ids=shard_repo_ids,
roots=shard_roots,
aggr_repo_id=mirrored_repo_id,
aggr_root=mirrored_root,
)
def build_doubled_dataset(
source_repo_id: str,
source_root: Path,
mirrored_repo_id: str,
mirrored_root: Path,
output_repo_id: str,
output_root: Path,
overwrite: bool,
):
if output_root.exists():
if overwrite:
shutil.rmtree(output_root)
elif _is_valid_dataset_root(output_root):
logger.info("Doubled dataset already exists at %s. Skipping final aggregation.", output_root)
return
else:
raise RuntimeError(
f"Output root {output_root} exists but is not a valid dataset. Use --overwrite to recreate."
)
logger.info("Aggregating source + mirrored into doubled dataset at %s", output_root)
aggregate_datasets(
repo_ids=[source_repo_id, mirrored_repo_id],
roots=[source_root, mirrored_root],
aggr_repo_id=output_repo_id,
aggr_root=output_root,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo id.")
parser.add_argument("--output-repo-id", type=str, required=True, help="Final doubled dataset repo id.")
parser.add_argument("--root", type=Path, default=None, help="Root path of source dataset.")
parser.add_argument(
"--output-root",
type=Path,
default=None,
help="Root path where final doubled dataset is written.",
)
parser.add_argument(
"--work-dir",
type=Path,
default=None,
help="Intermediate directory for mirrored shards and mirrored aggregate dataset.",
)
parser.add_argument("--logs-dir", type=Path, required=True, help="DataTrove logs path.")
parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name.")
parser.add_argument("--num-shards", type=int, default=256, help="Number of DataTrove tasks/ranks.")
parser.add_argument(
"--workers",
type=int,
default=64,
help="Max concurrent DataTrove workers on SLURM.",
)
parser.add_argument("--partition", type=str, default=None, help="SLURM partition (e.g. hopper-cpu).")
parser.add_argument("--cpus-per-task", type=int, default=8, help="CPU count per SLURM task.")
parser.add_argument("--mem-per-cpu", type=str, default="4G", help="Memory per CPU for SLURM task.")
parser.add_argument("--time", type=str, default="24:00:00", help="SLURM time limit.")
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec for output videos.")
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Use SLURM executor. Set 0 for local sequential debugging.",
)
parser.add_argument("--overwrite", action="store_true", help="Delete existing intermediate/final outputs.")
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push final doubled dataset to Hugging Face Hub after completion.",
)
args = parser.parse_args()
init_logging()
slurm = args.slurm == 1
source_root = _resolve_source_root(args.repo_id, args.root)
output_root = args.output_root if args.output_root is not None else HF_LEROBOT_HOME / args.output_repo_id
work_dir = _get_work_dir(args.output_repo_id, args.work_dir)
mirrored_repo_id = f"{args.output_repo_id}_mirrored"
mirrored_root = work_dir / "mirrored_aggregate"
work_dir.mkdir(parents=True, exist_ok=True)
args.logs_dir.mkdir(parents=True, exist_ok=True)
mirror_executor = make_mirror_executor(
repo_id=args.repo_id,
source_root=source_root,
mirrored_repo_id=mirrored_repo_id,
work_dir=work_dir,
logs_dir=args.logs_dir,
job_name=args.job_name,
num_shards=args.num_shards,
workers=args.workers,
partition=args.partition,
cpus_per_task=args.cpus_per_task,
mem_per_cpu=args.mem_per_cpu,
time_limit=args.time,
vcodec=args.vcodec,
overwrite=args.overwrite,
slurm=slurm,
)
if slurm:
aggregate_executor = make_single_task_executor(
step=AggregateMirroredShardsStep(
mirrored_repo_id=mirrored_repo_id,
mirrored_root=mirrored_root,
work_dir=work_dir,
num_shards=args.num_shards,
overwrite=args.overwrite,
),
logs_dir=args.logs_dir,
job_name=f"{args.job_name}_aggregate_mirrored",
partition=args.partition,
cpus_per_task=args.cpus_per_task,
mem_per_cpu=args.mem_per_cpu,
time_limit=args.time,
slurm=True,
depends=mirror_executor,
)
build_executor = make_single_task_executor(
step=BuildDoubledDatasetStep(
source_repo_id=args.repo_id,
source_root=source_root,
mirrored_repo_id=mirrored_repo_id,
mirrored_root=mirrored_root,
output_repo_id=args.output_repo_id,
output_root=output_root,
overwrite=args.overwrite,
),
logs_dir=args.logs_dir,
job_name=f"{args.job_name}_build_doubled",
partition=args.partition,
cpus_per_task=args.cpus_per_task,
mem_per_cpu=args.mem_per_cpu,
time_limit=args.time,
slurm=True,
depends=aggregate_executor,
)
final_executor: SlurmPipelineExecutor | LocalPipelineExecutor = build_executor
push_executor = None
if args.push_to_hub:
push_executor = make_single_task_executor(
step=PushDoubledDatasetStep(
output_repo_id=args.output_repo_id,
output_root=output_root,
),
logs_dir=args.logs_dir,
job_name=f"{args.job_name}_push",
partition=args.partition,
cpus_per_task=args.cpus_per_task,
mem_per_cpu=args.mem_per_cpu,
time_limit=args.time,
slurm=True,
depends=build_executor,
)
final_executor = push_executor
final_executor.run()
logger.info(
"Submitted SLURM chain. job_ids: mirror=%s aggregate=%s doubled=%s push=%s",
mirror_executor.job_id,
aggregate_executor.job_id,
build_executor.job_id,
push_executor.job_id if push_executor is not None else None,
)
return
mirror_executor.run()
aggregate_mirrored_shards(
mirrored_repo_id=mirrored_repo_id,
mirrored_root=mirrored_root,
work_dir=work_dir,
num_shards=args.num_shards,
overwrite=args.overwrite,
)
build_doubled_dataset(
source_repo_id=args.repo_id,
source_root=source_root,
mirrored_repo_id=mirrored_repo_id,
mirrored_root=mirrored_root,
output_repo_id=args.output_repo_id,
output_root=output_root,
overwrite=args.overwrite,
)
if args.push_to_hub:
logger.info("Pushing doubled dataset to hub: %s", args.output_repo_id)
LeRobotDataset(args.output_repo_id, root=output_root).push_to_hub()
if __name__ == "__main__":
main()
+10 -10
View File
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
@@ -58,16 +58,16 @@ Usage:
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--policy.path=<USER>/reuben_pi0 \
--dataset.repo_id=<USER>/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
@@ -75,8 +75,8 @@ Usage:
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
@@ -84,8 +84,8 @@ Usage:
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
+3 -3
View File
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
@@ -41,7 +41,7 @@ Usage:
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
@@ -53,7 +53,7 @@ Usage:
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.path=<USER>/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
+7 -7
View File
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<4.2.0",
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
@@ -76,9 +76,9 @@ dependencies = [
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
@@ -360,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"
+1 -1
View File
@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .utils import make_cameras_from_configs
+1 -1
View File
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
"""
pass
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
+23
View File
@@ -25,6 +25,10 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
+10 -15
View File
@@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
from ..utils import get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -117,7 +118,7 @@ class OpenCVCamera(Camera):
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = get_cv2_backend()
self.backend: int = config.backend
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -132,6 +133,7 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -148,8 +150,6 @@ class OpenCVCamera(Camera):
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -178,6 +178,7 @@ class OpenCVCamera(Camera):
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -197,8 +198,6 @@ class OpenCVCamera(Camera):
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -348,6 +347,7 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -374,9 +374,6 @@ class OpenCVCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -490,6 +487,7 @@ class OpenCVCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -512,8 +510,6 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -533,7 +529,8 @@ class OpenCVCamera(Camera):
return frame
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -548,8 +545,6 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
@CameraConfig.register_subclass("opencv")
@@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
@@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)
@@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -123,6 +124,7 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -136,9 +138,6 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -184,6 +183,7 @@ class Reachy2Camera(Camera):
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Same as read()
@@ -197,12 +197,11 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -219,8 +218,6 @@ class Reachy2Camera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
@@ -233,6 +230,7 @@ class Reachy2Camera(Camera):
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -240,8 +238,6 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()
@@ -30,7 +30,8 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -152,6 +153,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -169,8 +171,6 @@ class RealSenseCamera(Camera):
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -290,6 +290,7 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -299,8 +300,6 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -320,6 +319,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -345,9 +345,6 @@ class RealSenseCamera(Camera):
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -374,6 +371,7 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -403,9 +401,6 @@ class RealSenseCamera(Camera):
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -534,6 +529,7 @@ class RealSenseCamera(Camera):
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -556,8 +552,6 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -578,7 +572,8 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -593,8 +588,6 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):
-12
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
+6 -10
View File
@@ -34,7 +34,8 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -104,6 +105,7 @@ class ZMQCamera(Camera):
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server.
@@ -111,8 +113,6 @@ class ZMQCamera(Camera):
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
@@ -211,6 +211,7 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -228,9 +229,6 @@ class ZMQCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -301,6 +299,7 @@ class ZMQCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -317,8 +316,6 @@ class ZMQCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -335,6 +332,7 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -350,8 +348,6 @@ class ZMQCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
+1 -4
View File
@@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
+6 -6
View File
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
"""
n_obs_steps: int = 1
+1 -1
View File
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
+10 -9
View File
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
if cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
class ImageTransforms(Transform):
+3 -13
View File
@@ -122,19 +122,9 @@ def load_nested_dataset(
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -529,7 +529,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
+1
View File
@@ -205,6 +205,7 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
+7 -2
View File
@@ -112,6 +112,7 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -145,7 +146,9 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -295,7 +298,8 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -373,6 +377,7 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)
+6 -4
View File
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(groups)
self.group_names = list(self.groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,18 +230,20 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")
+41 -15
View File
@@ -23,6 +23,7 @@ from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
@@ -36,7 +37,6 @@ else:
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase):
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
@@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
@@ -211,6 +208,9 @@ class DamiaoMotorsBus(MotorsBusBase):
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
@@ -246,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase):
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
@@ -253,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
@@ -283,6 +282,10 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -341,6 +344,10 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -356,6 +363,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
@@ -394,10 +405,13 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
responses: dict[int, can.Message] = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
@@ -461,6 +475,9 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
@@ -488,6 +505,9 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
@@ -562,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase):
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
@@ -595,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase):
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
@@ -605,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
@@ -656,6 +674,10 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
@@ -678,10 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase):
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
@@ -690,6 +714,8 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
@@ -732,9 +758,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion(
self,
motors: NameOrID | list[NameOrID] | None = None,
motors: str | list[str] | None = None,
display_values: bool = True,
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
) -> tuple[dict[str, Value], dict[str, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
+8 -8
View File
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
return {id_: data[0] for id_, data in data_list.items()}
+9 -9
View File
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
)
return calibration
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings = {}
half_turn_homings: dict[NameOrID, Value] = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
data_list: dict[int, int] = {}
status_length = 6
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return None
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
+93 -90
View File
@@ -23,6 +23,7 @@ from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@@ -179,15 +180,16 @@ class Motor:
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def openPort(self): ...
def closePort(self): ...
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> int:
def _get_motor_model(self, motor: NameOrID) -> str:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors.copy()
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: int | str | list[str] | None = None):
def torque_disabled(self, motors: str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
for motor in motors:
for motor in motor_names:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
dict[str, Value]: Mapping *motor name → written homing offset*.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
motor_names = self._get_motors_list(motors)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False)
positions = self.sync_read("Present_Position", motor_names, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motors:
for motor in motor_names:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3)
move_cursor_up(len(motor_names) + 3)
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return
return None
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return
return None
return model_number
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
decoded = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
id_value = self._normalize(id_value)
normalized = self._normalize(decoded)
return normalized[id_]
return id_value[id_]
return decoded[id_]
def _read(
self,
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
) -> tuple[int, int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
value = self._unnormalize({id_: value})[id_]
int_value = self._unnormalize({id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
motors: NameOrID | Sequence[NameOrID] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
raw_ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
decoded = self._decode_sign(data_name, raw_ids_values)
if normalize and data_name in self.normalized_data:
ids_values = self._normalize(ids_values)
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
return {self._id_to_name(id_): value for id_, value in decoded.items()}
def _sync_read(
self,
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(ids_values)
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
def _sync_write(
self,
+7 -16
View File
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- Either:
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
@@ -27,18 +27,18 @@ Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
--reward-model-path <USER>/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5
@@ -714,12 +714,12 @@ Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4
--reward-model-path <USER>/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -40,7 +40,7 @@ and an action expert.
```bash
lerobot-train \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
loss_dict["losses_after_forward"] = losses.clone().mean().item()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Those are: `input_features` and `output_features`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
current step and additional steps going back).
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
action_chunk_size: Action chunk size of each action prediction token.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
+2
View File
@@ -44,6 +44,7 @@ from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
@@ -87,6 +88,7 @@ __all__ = [
"DoneProcessorStep",
"EnvAction",
"EnvTransition",
"GymHILAdapterProcessorStep",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"IdentityProcessorStep",
+4 -6
View File
@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
if ft != FeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
key=OBS_STATE,
type=FeatureType.STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
new_features[FeatureType.STATE] = state_feats
return new_features
@@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import to_tensor
from .core import EnvAction, EnvTransition, PolicyAction
from .hil_processor import TELEOP_ACTION_KEY
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
@@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
new_transition[TransitionKey.ACTION] = torch_action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in complementary_data:
teleop_action = complementary_data[TELEOP_ACTION_KEY]
if isinstance(teleop_action, EnvAction):
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def transform_features(
+43 -10
View File
@@ -312,9 +312,40 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return features
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
class GymHILAdapterProcessorStep(ProcessorStep):
"""
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
transition[TransitionKey.INFO] = info
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -329,26 +360,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data: dict) -> dict:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
transition: The incoming environment transition.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
The modified transition with the penalty added to complementary data.
"""
action = self.transition.get(TransitionKey.ACTION)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return complementary_data
return new_transition
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -364,11 +396,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Create new complementary data with penalty info
# Update complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""
+1 -1
View File
@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
Args:
save_directory: The directory where the pipeline will be saved. If None, saves to
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
card_kwargs: Additional arguments passed to the card template to customize the card.
config_filename: The name of the JSON configuration file. If None, a name is
+23 -4
View File
@@ -36,6 +36,7 @@ from lerobot.processor import (
DeviceProcessorStep,
EnvTransition,
GripperPenaltyProcessorStep,
GymHILAdapterProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
MapDeltaActionToRobotActionStep,
@@ -379,6 +380,7 @@ def make_processors(
]
env_pipeline_steps = [
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
@@ -412,7 +414,10 @@ def make_processors(
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
if kinematics_solver is not None:
add_ee_pose = (
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
)
if kinematics_solver is not None and add_ee_pose:
env_pipeline_steps.append(
ForwardKinematicsJointsToEEObservation(
kinematics=kinematics_solver,
@@ -435,7 +440,12 @@ def make_processors(
)
# Add gripper penalty processor if gripper config exists and enabled
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
# Only add if max_gripper_pos is explicitly configured (required for normalization)
if (
cfg.processor.gripper is not None
and cfg.processor.gripper.use_gripper
and cfg.processor.max_gripper_pos is not None
):
env_pipeline_steps.append(
GripperPenaltyProcessorStep(
penalty=cfg.processor.gripper.gripper_penalty,
@@ -600,7 +610,14 @@ def control_loop(
dataset = None
if cfg.mode == "record":
action_features = teleop_device.action_features
if teleop_device:
action_features = teleop_device.action_features
else:
action_features = {
"dtype": "float32",
"shape": (4,),
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
}
features = {
ACTION: action_features,
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
@@ -648,7 +665,7 @@ def control_loop(
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
@@ -717,6 +734,8 @@ def control_loop(
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")
dataset.finalize()
logging.info("Pushing dataset to hub")
dataset.push_to_hub()
+17 -2
View File
@@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
def _maybe_truncate(tag: str) -> str:
"""Truncate tag to max_tag_length characters if required.
wandb rejects tags longer than 64 characters.
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
"""
if len(tag) <= max_tag_length:
return tag
return tag[:max_tag_length]
lst = [
f"policy:{cfg.policy.type}",
f"seed:{cfg.seed}",
@@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
lst.append(f"dataset:{cfg.dataset.repo_id}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
if truncate_tags:
lst = [_maybe_truncate(tag) for tag in lst]
return lst if return_list else "-".join(lst)
@@ -83,7 +98,7 @@ class WandBLogger:
entity=self.cfg.entity,
name=self.job_name,
notes=self.cfg.notes,
tags=cfg_to_group(cfg, return_list=True),
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
dir=self.log_dir,
config=cfg.to_dict(),
# TODO(rcadene): try set to True
@@ -19,6 +19,7 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
@@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -19,6 +19,7 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_so_follower import BiSOFollowerConfig
@@ -96,6 +97,7 @@ class BiSOFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -116,6 +118,7 @@ class BiSOFollower(Robot):
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -129,6 +132,7 @@ class BiSOFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
# Remove "left_" prefix
left_action = {
@@ -148,6 +152,7 @@ class BiSOFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
+1 -1
View File
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+1 -1
View File
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@@ -193,7 +193,7 @@ class KochFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+1 -1
View File
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -119,6 +119,7 @@ class OpenArmFollower(Robot):
"""Check if robot is connected."""
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the robot and optionally calibrate.
@@ -126,8 +127,6 @@ class OpenArmFollower(Robot):
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -219,6 +218,7 @@ class OpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""
Get current observation from robot including position, velocity, and torque.
@@ -228,9 +228,6 @@ class OpenArmFollower(Robot):
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict: dict[str, Any] = {}
states = self.bus.sync_read_all_states()
@@ -244,7 +241,7 @@ class OpenArmFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
@@ -253,6 +250,7 @@ class OpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -272,8 +270,6 @@ class OpenArmFollower(Robot):
Returns:
The action actually sent (potentially clipped)
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -333,10 +329,9 @@ class OpenArmFollower(Robot):
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
"""Disconnect from robot."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
self.bus.disconnect(self.config.disable_torque_on_disconnect)
+1 -1
View File
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
return obs_dict
@@ -40,7 +40,7 @@ class SOFollowerConfig:
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
use_degrees: bool = True
@RobotConfig.register_subclass("so101_follower")
@@ -187,7 +187,7 @@ class SOFollower(Robot):
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+1 -1
View File
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
# Cameras - read images from ZMQ cameras
for cam_name, cam in self._cameras.items():
obs[cam_name] = cam.async_read()
obs[cam_name] = cam.read_latest()
return obs
+25 -12
View File
@@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--ws-port 9087
--grpc-port 9876
local$ rerun ws://localhost:9087
local$ rerun rerun+http://IP:GRPC_PORT/proxy
```
"""
@@ -75,6 +73,7 @@ import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
@@ -93,10 +92,11 @@ def visualize_dataset(
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
ws_port: int = 9087,
grpc_port: int = 9876,
save: bool = False,
output_dir: Path | None = None,
display_compressed_images: bool = False,
**kwargs,
) -> Path | None:
if save:
assert output_dir is not None, (
@@ -126,7 +126,9 @@ def visualize_dataset(
gc.collect()
if mode == "distant":
rr.serve_web_viewer(open_browser=False, web_port=web_port)
server_uri = rr.serve_grpc(grpc_port=grpc_port)
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
logging.info("Logging to Rerun")
@@ -226,7 +228,7 @@ def main():
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
),
)
parser.add_argument(
@@ -238,8 +240,13 @@ def main():
parser.add_argument(
"--ws-port",
type=int,
default=9087,
help="Web socket port for rerun.io when `--mode distant` is set.",
help="deprecated, please use --grpc-port instead.",
)
parser.add_argument(
"--grpc-port",
type=int,
default=9876,
help="gRPC port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--save",
@@ -265,9 +272,7 @@ def main():
parser.add_argument(
"--display-compressed-images",
type=bool,
required=True,
default=False,
action="store_true",
help="If set, display compressed images in Rerun instead of uncompressed ones.",
)
@@ -277,6 +282,14 @@ def main():
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
if kwargs["ws_port"] is not None:
logging.warning(
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
)
logging.warning("Setting grpc_port to ws_port value.")
kwargs["grpc_port"] = kwargs.pop("ws_port")
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
+104 -38
View File
@@ -24,96 +24,112 @@ When new_repo_id is specified, creates a new dataset.
Usage Examples:
Delete episodes 0, 2, and 5 from a dataset:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Delete episodes and save to a new dataset:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_filtered \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Split dataset by fractions:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.8, "val": 0.2}'
Split dataset by episode indices:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
Split into more than two splits:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
Merge multiple datasets:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
Remove camera feature:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Pick up the cube and place it"
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Default task" \
--operation.episode_tasks '{"5": "Special task for episode 5"}'
Convert image dataset to video format and save locally:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset to video format and save with new repo_id:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video
Convert image dataset to video format and push to hub:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_image_to_video \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
Show dataset information without feature details:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features false
Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
lerobot-edit-dataset \
--config_path path/to/edit_config.json
"""
import abc
import logging
import shutil
import sys
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
@@ -129,39 +145,46 @@ from lerobot.utils.utils import init_logging
@dataclass
class DeleteEpisodesConfig:
type: str = "delete_episodes"
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@OperationConfig.register_subclass("delete_episodes")
@dataclass
class DeleteEpisodesConfig(OperationConfig):
episode_indices: list[int] | None = None
@OperationConfig.register_subclass("split")
@dataclass
class SplitConfig:
type: str = "split"
class SplitConfig(OperationConfig):
splits: dict[str, float | list[int]] | None = None
@OperationConfig.register_subclass("merge")
@dataclass
class MergeConfig:
type: str = "merge"
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
@OperationConfig.register_subclass("remove_feature")
@dataclass
class RemoveFeatureConfig:
type: str = "remove_feature"
class RemoveFeatureConfig(OperationConfig):
feature_names: list[str] | None = None
@OperationConfig.register_subclass("modify_tasks")
@dataclass
class ModifyTasksConfig:
type: str = "modify_tasks"
class ModifyTasksConfig(OperationConfig):
new_task: str | None = None
episode_tasks: dict[str, str] | None = None
@OperationConfig.register_subclass("convert_image_to_video")
@dataclass
class ConvertImageToVideoConfig:
type: str = "convert_image_to_video"
class ConvertImageToVideoConfig(OperationConfig):
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
@@ -174,17 +197,17 @@ class ConvertImageToVideoConfig:
max_frames_per_batch: int | None = None
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
type: str = "info"
show_features: bool = False
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -433,6 +456,49 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
logging.info("Dataset saved locally (not pushed to hub)")
def _get_dataset_size(repo_path):
import os
total = 0
with os.scandir(repo_path) as it:
for entry in it:
if entry.is_file():
total += entry.stat().st_size
elif entry.is_dir():
total += _get_dataset_size(entry.path)
return total
def handle_info(cfg: EditDatasetConfig):
if not isinstance(cfg.operation, InfoConfig):
raise ValueError("Operation config must be InfoConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
sys.stdout.write(
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
)
sys.stdout.write(
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
)
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
total_file_size = _get_dataset_size(dataset.root)
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
if cfg.operation.show_features:
import json
feature_dump_str = json.dumps(
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
)
sys.stdout.write("Features:\n")
sys.stdout.write(f"{feature_dump_str}\n")
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
@@ -449,11 +515,11 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "info":
handle_info(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
)
available = ", ".join(OperationConfig.get_known_choices())
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
def main() -> None:
+8 -1
View File
@@ -398,7 +398,14 @@ def record_loop(
)
dt_s = time.perf_counter() - start_loop_t
precise_sleep(max(1 / fps - dt_s, 0.0))
sleep_time_s: float = 1 / fps - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
+1 -1
View File
@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
--dataset.repo_id=<USER>/record-test \
--dataset.episode=0
```
+2 -2
View File
@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
import draccus
from lerobot.utils.import_utils import is_package_available
from lerobot.utils.import_utils import _can_available
MOTOR_NAMES = {
0x01: "joint_1",
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
@draccus.wrap()
def setup_can(cfg: CANSetupConfig):
if not is_package_available("can"):
if not _can_available:
print("Error: python-can not installed. Install with: pip install python-can")
sys.exit(1)
@@ -19,6 +19,7 @@ from functools import cached_property
from lerobot.processor import RobotAction
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader
from ..teleoperator import Teleoperator
@@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -18,7 +18,7 @@ import logging
from functools import cached_property
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..so_leader import SOLeader
from ..teleoperator import Teleoperator
@@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -21,7 +21,7 @@ from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..teleoperator import Teleoperator
from .config_openarm_leader import OpenArmLeaderConfig
@@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator):
"""Check if teleoperator is connected."""
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the teleoperator.
@@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator):
For manual control, we disable torque after connecting so the
arm can be moved by hand.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
"""
Get current action from the leader arm.
@@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator):
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict: dict[str, Any] = {}
@@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
@check_if_not_connected
def disconnect(self) -> None:
"""Disconnect from teleoperator."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
# For manual control, ensure torque is disabled before disconnecting
@@ -28,7 +28,7 @@ class SOLeaderConfig:
port: str
# Whether to use degrees for angles
use_degrees: bool = False
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("so101_leader")
+2 -2
View File
@@ -16,14 +16,14 @@ import platform
import time
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
"""
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
Parameters:
- seconds: duration to wait
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
Note:
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
+24
View File
@@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Resize)
output = tf(img_tensor)
assert output.shape[-2:] == (32, 32)
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Identity)
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_make_transform_from_config_invalid_type():
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
with pytest.raises(ValueError, match="not valid"):
make_transform_from_config(tf_cfg)
def test_save_all_transforms(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)
+14
View File
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging.version import Version
from torch.optim.lr_scheduler import LambdaLR
from lerobot.optim.schedulers import (
@@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict
@@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict
@@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
"last_epoch": 1,
"lr_lambdas": [None],
}
if Version(torch.__version__) >= Version("2.8"):
expected_state_dict["_is_initial"] = False
assert scheduler.state_dict() == expected_state_dict
+1
View File
@@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs):
cam.connect = MagicMock()
cam.disconnect = MagicMock()
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
return cam
@@ -0,0 +1,74 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import draccus
import pytest
from lerobot.scripts.lerobot_edit_dataset import (
ConvertImageToVideoConfig,
DeleteEpisodesConfig,
EditDatasetConfig,
InfoConfig,
MergeConfig,
ModifyTasksConfig,
OperationConfig,
RemoveFeatureConfig,
SplitConfig,
)
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
"""Helper to parse CLI args into an EditDatasetConfig via draccus."""
return draccus.parse(EditDatasetConfig, args=cli_args)
class TestOperationTypeParsing:
"""Test that --operation.type correctly selects the right config subclass."""
@pytest.mark.parametrize(
"type_name, expected_cls",
[
("delete_episodes", DeleteEpisodesConfig),
("split", SplitConfig),
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
("info", InfoConfig),
],
)
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
assert isinstance(cfg.operation, expected_cls), (
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
)
@pytest.mark.parametrize(
"type_name, expected_cls",
[
("delete_episodes", DeleteEpisodesConfig),
("split", SplitConfig),
("merge", MergeConfig),
("remove_feature", RemoveFeatureConfig),
("modify_tasks", ModifyTasksConfig),
("convert_image_to_video", ConvertImageToVideoConfig),
("info", InfoConfig),
],
)
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
assert resolved_name == type_name